Compare commits

..

2 Commits

Author SHA1 Message Date
f2a164b82c main:新增 Worker 支持及任务管理优化
变更内容:
- 添加 Worker 进程模块,支持基于 Redis 的任务管理及分布式锁。
- 增加 `entrypoint.sh` 启动脚本,支持根据 `RUN_MODE` 自动运行 API 或 Worker。
- 优化 `docker-compose.yml` 配置,添加镜像及平台支持。
- 在 JobManager 中集成 `request_id` 上下文传递,改进日志追踪功能。
- 扩展单元测试,提升测试覆盖率。
2026-02-03 15:13:11 +08:00
bad3a34a82 main:支持 Worker 模式运行并优化任务管理
变更内容:
- 在 `Dockerfile` 和 `docker-compose.yml` 中添加 Worker 模式支持,包含运行模式 `RUN_MODE` 的配置。
- 更新 API 路由,改为将任务入队处理,并由 Worker 执行。
- 在 JobManager 中新增任务队列及分布式锁功能,支持任务的入队、出队、执行控制以及重试机制。
- 添加全局并发控制逻辑,避免任务超额运行。
- 扩展单元测试,覆盖任务队列、锁机制和并发控制的各类场景。
- 在 Serverless 配置中分别为 API 和 Worker 添加独立服务定义。

提升任务调度灵活性,增强系统可靠性与扩展性。
2026-02-03 13:29:32 +08:00
9 changed files with 795 additions and 27 deletions

View File

@@ -30,9 +30,15 @@ USER appuser
# 暴露端口 # 暴露端口
EXPOSE 8000 EXPOSE 8000
# 健康检查 # 运行模式api默认或 worker
HEALTHCHECK --interval=30s --timeout=3s --start-period=5s --retries=3 \ ENV RUN_MODE=api
CMD python -c "import urllib.request; urllib.request.urlopen('http://localhost:8000/healthz')"
# 启动命令 # 健康检查(仅对 API 模式有效)
CMD ["uvicorn", "functional_scaffold.main:app", "--host", "0.0.0.0", "--port", "8000"] HEALTHCHECK --interval=30s --timeout=3s --start-period=5s --retries=3 \
CMD if [ "$RUN_MODE" = "api" ]; then python -c "import urllib.request; urllib.request.urlopen('http://localhost:8000/healthz')"; else exit 0; fi
# 启动脚本
COPY --chown=appuser:appuser deployment/entrypoint.sh /app/entrypoint.sh
RUN chmod +x /app/entrypoint.sh
CMD ["/app/entrypoint.sh"]

View File

@@ -11,6 +11,7 @@ services:
- APP_ENV=development - APP_ENV=development
- LOG_LEVEL=INFO - LOG_LEVEL=INFO
- METRICS_ENABLED=true - METRICS_ENABLED=true
- RUN_MODE=api
# Redis 指标存储配置 # Redis 指标存储配置
- REDIS_HOST=redis - REDIS_HOST=redis
- REDIS_PORT=6379 - REDIS_PORT=6379
@@ -38,6 +39,38 @@ services:
retries: 3 retries: 3
start_period: 5s start_period: 5s
# Worker 服务 - 处理异步任务
worker:
build:
context: ..
dockerfile: deployment/Dockerfile
environment:
- APP_ENV=development
- LOG_LEVEL=INFO
- METRICS_ENABLED=true
- RUN_MODE=worker
# Redis 配置
- REDIS_HOST=redis
- REDIS_PORT=6379
- REDIS_DB=0
# Worker 配置
- WORKER_POLL_INTERVAL=1.0
- MAX_CONCURRENT_JOBS=10
- JOB_MAX_RETRIES=3
- JOB_EXECUTION_TIMEOUT=300
volumes:
- ../src:/app/src
- ../config:/app/config
labels:
logging: "promtail"
logging_jobname: "functional-scaffold-worker"
restart: unless-stopped
depends_on:
redis:
condition: service_healthy
deploy:
replicas: 2
# Redis - 用于集中式指标存储 # Redis - 用于集中式指标存储
redis: redis:
image: redis:7-alpine image: redis:7-alpine

12
deployment/entrypoint.sh Normal file
View File

@@ -0,0 +1,12 @@
#!/bin/bash
# Container entrypoint: start the job worker when RUN_MODE is "worker";
# for any other value (including unset) start the HTTP API.
set -e

case "$RUN_MODE" in
    worker)
        echo "启动 Worker 模式..."
        exec python -m functional_scaffold.worker
        ;;
    *)
        echo "启动 API 模式..."
        exec uvicorn functional_scaffold.main:app --host 0.0.0.0 --port 8000
        ;;
esac

View File

@@ -17,7 +17,7 @@ Resources:
prime-checker: prime-checker:
Type: 'Aliyun::Serverless::Function' Type: 'Aliyun::Serverless::Function'
Properties: Properties:
Description: '质数判断算法服务' Description: '质数判断算法服务API'
Runtime: custom-container Runtime: custom-container
MemorySize: 512 MemorySize: 512
Timeout: 60 Timeout: 60
@@ -25,11 +25,14 @@ Resources:
CAPort: 8000 CAPort: 8000
CustomContainerConfig: CustomContainerConfig:
Image: 'registry.cn-hangzhou.aliyuncs.com/your-namespace/functional-scaffold:latest' Image: 'registry.cn-hangzhou.aliyuncs.com/your-namespace/functional-scaffold:latest'
Command: '["uvicorn", "functional_scaffold.main:app", "--host", "0.0.0.0", "--port", "8000"]' Command: '["/app/entrypoint.sh"]'
EnvironmentVariables: EnvironmentVariables:
APP_ENV: production APP_ENV: production
LOG_LEVEL: INFO LOG_LEVEL: INFO
METRICS_ENABLED: 'true' METRICS_ENABLED: 'true'
RUN_MODE: api
REDIS_HOST: 'r-xxxxx.redis.rds.aliyuncs.com'
REDIS_PORT: '6379'
Events: Events:
httpTrigger: httpTrigger:
Type: HTTP Type: HTTP
@@ -38,3 +41,32 @@ Resources:
Methods: Methods:
- GET - GET
- POST - POST
job-worker:
Type: 'Aliyun::Serverless::Function'
Properties:
Description: '异步任务 Worker'
Runtime: custom-container
MemorySize: 512
Timeout: 900
InstanceConcurrency: 1
CustomContainerConfig:
Image: 'registry.cn-hangzhou.aliyuncs.com/your-namespace/functional-scaffold:latest'
Command: '["/app/entrypoint.sh"]'
EnvironmentVariables:
APP_ENV: production
LOG_LEVEL: INFO
METRICS_ENABLED: 'true'
RUN_MODE: worker
REDIS_HOST: 'r-xxxxx.redis.rds.aliyuncs.com'
REDIS_PORT: '6379'
WORKER_POLL_INTERVAL: '1.0'
MAX_CONCURRENT_JOBS: '5'
JOB_MAX_RETRIES: '3'
JOB_EXECUTION_TIMEOUT: '300'
Events:
timerTrigger:
Type: Timer
Properties:
CronExpression: '0 */1 * * * *'
Enable: true
Payload: '{}'

View File

@@ -1,6 +1,5 @@
"""API 路由""" """API 路由"""
import asyncio
from fastapi import APIRouter, HTTPException, Depends, status from fastapi import APIRouter, HTTPException, Depends, status
import time import time
import logging import logging
@@ -200,10 +199,10 @@ async def create_job(
# 获取任务信息 # 获取任务信息
job_data = await job_manager.get_job(job_id) job_data = await job_manager.get_job(job_id)
# 后台执行任务 # 任务入队,由 Worker 执行
asyncio.create_task(job_manager.execute_job(job_id)) await job_manager.enqueue_job(job_id)
logger.info(f"异步任务已创建: job_id={job_id}, request_id={request_id}") logger.info(f"异步任务已创建并入队: job_id={job_id}, request_id={request_id}")
return JobCreateResponse( return JobCreateResponse(
job_id=job_id, job_id=job_id,

View File

@@ -57,6 +57,14 @@ class Settings(BaseSettings):
webhook_timeout: int = 10 # Webhook 超时时间(秒) webhook_timeout: int = 10 # Webhook 超时时间(秒)
max_concurrent_jobs: int = 10 # 最大并发任务数 max_concurrent_jobs: int = 10 # 最大并发任务数
# Worker 配置
worker_poll_interval: float = 1.0 # Worker 轮询间隔(秒)
job_queue_key: str = "job:queue" # 任务队列 Redis Key
job_concurrency_key: str = "job:concurrency" # 全局并发计数器 Redis Key
job_lock_ttl: int = 300 # 任务锁 TTL
job_max_retries: int = 3 # 任务最大重试次数
job_execution_timeout: int = 300 # 任务执行超时(秒)
# 全局配置实例 # 全局配置实例
settings = Settings() settings = Settings()

View File

@@ -16,6 +16,7 @@ import redis.asyncio as aioredis
from ..algorithms.base import BaseAlgorithm from ..algorithms.base import BaseAlgorithm
from ..config import settings from ..config import settings
from ..core.metrics_unified import incr, observe from ..core.metrics_unified import incr, observe
from ..core.tracing import set_request_id
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@@ -176,6 +177,7 @@ class JobManager:
"job_id": job_id, "job_id": job_id,
"status": job_data.get("status", ""), "status": job_data.get("status", ""),
"algorithm": job_data.get("algorithm", ""), "algorithm": job_data.get("algorithm", ""),
"request_id": job_data.get("request_id") or None,
"created_at": job_data.get("created_at", ""), "created_at": job_data.get("created_at", ""),
"started_at": job_data.get("started_at") or None, "started_at": job_data.get("started_at") or None,
"completed_at": job_data.get("completed_at") or None, "completed_at": job_data.get("completed_at") or None,
@@ -223,6 +225,11 @@ class JobManager:
algorithm_name = job_data.get("algorithm", "") algorithm_name = job_data.get("algorithm", "")
webhook_url = job_data.get("webhook", "") webhook_url = job_data.get("webhook", "")
request_id = job_data.get("request_id", "")
# 设置 request_id 上下文,确保日志中包含 request_id
if request_id:
set_request_id(request_id)
# 解析参数 # 解析参数
try: try:
@@ -234,7 +241,9 @@ class JobManager:
async with self._semaphore: async with self._semaphore:
# 更新状态为 running # 更新状态为 running
started_at = self._get_timestamp() started_at = self._get_timestamp()
await self._redis_client.hset(key, mapping={"status": "running", "started_at": started_at}) await self._redis_client.hset(
key, mapping={"status": "running", "started_at": started_at}
)
logger.info(f"开始执行任务: job_id={job_id}, algorithm={algorithm_name}") logger.info(f"开始执行任务: job_id={job_id}, algorithm={algorithm_name}")
@@ -295,7 +304,9 @@ class JobManager:
incr("jobs_completed_total", {"algorithm": algorithm_name, "status": status}) incr("jobs_completed_total", {"algorithm": algorithm_name, "status": status})
observe("job_execution_duration_seconds", {"algorithm": algorithm_name}, elapsed_time) observe("job_execution_duration_seconds", {"algorithm": algorithm_name}, elapsed_time)
logger.info(f"任务执行完成: job_id={job_id}, status={status}, elapsed={elapsed_time:.3f}s") logger.info(
f"任务执行完成: job_id={job_id}, status={status}, elapsed={elapsed_time:.3f}s"
)
# 发送 Webhook 回调 # 发送 Webhook 回调
if webhook_url: if webhook_url:
@@ -372,6 +383,195 @@ class JobManager:
"""检查任务管理器是否可用""" """检查任务管理器是否可用"""
return self._redis_client is not None return self._redis_client is not None
async def enqueue_job(self, job_id: str) -> bool:
    """Push a job ID onto the Redis job queue.

    Args:
        job_id: ID of the job to enqueue.

    Returns:
        bool: True when the ID was pushed; False when no Redis client is
        set or the push raised.
    """
    client = self._redis_client
    if client:
        try:
            # LPUSH here pairs with the BRPOP in dequeue_job.
            await client.lpush(settings.job_queue_key, job_id)
            logger.info(f"任务已入队: job_id={job_id}")
            return True
        except Exception as exc:
            logger.error(f"任务入队失败: job_id={job_id}, error={exc}")
            return False

    logger.error(f"Redis 不可用,无法入队任务: {job_id}")
    return False
async def dequeue_job(self, timeout: int = 5) -> Optional[str]:
    """Pop the next job ID from the Redis job queue, blocking while it is empty.

    Args:
        timeout: Blocking timeout in seconds, forwarded to BRPOP.

    Returns:
        Optional[str]: The job ID, or None when BRPOP returned nothing, no
        Redis client is set, or the call raised.
    """
    client = self._redis_client
    if not client:
        return None

    try:
        popped = await client.brpop(settings.job_queue_key, timeout=timeout)
        # BRPOP yields a (key, value) pair; the value is the job ID.
        return popped[1] if popped else None
    except Exception as exc:
        logger.error(f"任务出队失败: error={exc}")
        return None
async def acquire_job_lock(self, job_id: str) -> bool:
    """Try to take the per-job execution lock in Redis.

    The lock is the key ``job:lock:<job_id>``, written with SET NX and an
    expiry of ``settings.job_lock_ttl`` seconds.

    Args:
        job_id: ID of the job to lock.

    Returns:
        bool: True when SET returned a non-None reply; False when it
        returned None, no Redis client is set, or the call raised.
    """
    client = self._redis_client
    if not client:
        return False

    try:
        reply = await client.set(
            f"job:lock:{job_id}", "locked", nx=True, ex=settings.job_lock_ttl
        )
        if reply:
            logger.debug(f"获取任务锁成功: job_id={job_id}")
        return reply is not None
    except Exception as exc:
        logger.error(f"获取任务锁失败: job_id={job_id}, error={exc}")
        return False
async def release_job_lock(self, job_id: str) -> bool:
    """Delete the per-job execution lock key ``job:lock:<job_id>``.

    The key is deleted unconditionally; no ownership check is performed.

    Args:
        job_id: ID of the job whose lock should be removed.

    Returns:
        bool: True when the DELETE call completed; False when no Redis
        client is set or the call raised.
    """
    client = self._redis_client
    if not client:
        return False

    try:
        await client.delete(f"job:lock:{job_id}")
        logger.debug(f"释放任务锁成功: job_id={job_id}")
        return True
    except Exception as exc:
        logger.error(f"释放任务锁失败: job_id={job_id}, error={exc}")
        return False
async def increment_concurrency(self) -> int:
    """Increment the global running-job counter in Redis.

    Returns:
        int: The counter value after the increment, or 0 when no Redis
        client is set or the call raised.
    """
    client = self._redis_client
    if not client:
        return 0

    try:
        return await client.incr(settings.job_concurrency_key)
    except Exception as exc:
        logger.error(f"增加并发计数失败: error={exc}")
        return 0
async def decrement_concurrency(self) -> int:
    """Decrement the global running-job counter in Redis.

    Returns:
        int: The counter value after the decrement, clamped to 0; also 0
        when no Redis client is set or the call raised.
    """
    client = self._redis_client
    if not client:
        return 0

    try:
        remaining = await client.decr(settings.job_concurrency_key)
        if remaining >= 0:
            return remaining
        # The counter went negative: reset the stored value to zero.
        await client.set(settings.job_concurrency_key, 0)
        return 0
    except Exception as exc:
        logger.error(f"减少并发计数失败: error={exc}")
        return 0
async def get_global_concurrency(self) -> int:
    """Read the global running-job counter from Redis.

    Returns:
        int: The stored counter, or 0 when the value is missing or empty,
        no Redis client is set, or the call raised.
    """
    client = self._redis_client
    if not client:
        return 0

    try:
        raw = await client.get(settings.job_concurrency_key)
        if not raw:
            return 0
        return int(raw)
    except Exception as exc:
        logger.error(f"获取并发计数失败: error={exc}")
        return 0
async def can_execute(self) -> bool:
    """Report whether the global concurrency limit leaves room for a job.

    Returns:
        bool: True when the current global counter is strictly below
        ``settings.max_concurrent_jobs``.
    """
    running = await self.get_global_concurrency()
    limit = settings.max_concurrent_jobs
    return running < limit
async def get_job_retry_count(self, job_id: str) -> int:
    """Read how many times a job has been retried.

    The count is the ``retry_count`` field of the ``job:<job_id>`` hash.

    Args:
        job_id: ID of the job.

    Returns:
        int: The stored retry count, or 0 when the field is missing or
        empty, no Redis client is set, or the lookup raised.
    """
    if not self._redis_client:
        return 0

    key = f"job:{job_id}"
    try:
        retry_count = await self._redis_client.hget(key, "retry_count")
        return int(retry_count) if retry_count else 0
    except Exception as e:
        # Still best-effort (returns 0), but log the failure like the sibling
        # Redis helpers do, instead of hiding it.
        logger.error(f"获取重试次数失败: job_id={job_id}, error={e}")
        return 0
async def increment_job_retry(self, job_id: str) -> int:
    """Increment a job's retry counter and return the stored value.

    The counter is the ``retry_count`` field of the ``job:<job_id>`` hash.

    Args:
        job_id: ID of the job.

    Returns:
        int: The retry count read back after the increment (1 when the
        read-back is empty), or 0 when no Redis client is set or a call
        raised.
    """
    client = self._redis_client
    if not client:
        return 0

    hash_key = f"job:{job_id}"
    try:
        await client.hincrby(hash_key, "retry_count", 1)
        stored = await client.hget(hash_key, "retry_count")
        if stored:
            return int(stored)
        return 1
    except Exception as exc:
        logger.error(f"增加重试次数失败: job_id={job_id}, error={exc}")
        return 0
def get_concurrency_status(self) -> Dict[str, int]: def get_concurrency_status(self) -> Dict[str, int]:
"""获取并发状态 """获取并发状态

View File

@@ -0,0 +1,197 @@
"""Worker 进程模块
基于 Redis 队列的任务 Worker,支持分布式锁和全局并发控制。
"""
import asyncio
import logging
import signal
import sys
from typing import Optional
from .config import settings
from .core.job_manager import JobManager
from .core.logging import setup_logging
from .core.tracing import set_request_id
logger = logging.getLogger(__name__)
class JobWorker:
    """Worker that pops job IDs from the Redis queue and executes them.

    For every dequeued job the worker:
    - takes a per-job distributed lock so the job is not executed twice,
    - checks the global concurrency limit before running,
    - re-queues failed or timed-out jobs until ``settings.job_max_retries``,
    - and lets the in-flight job finish when a shutdown is requested.
    """

    def __init__(self):
        self._job_manager: Optional[JobManager] = None
        self._running: bool = False
        # ID of the job currently being processed (None while idle).
        self._current_job_id: Optional[str] = None
        # ID of the job whose lock this worker currently holds, if any.
        self._locked_job_id: Optional[str] = None
        # True once the job manager has been shut down (makes shutdown idempotent).
        self._closed: bool = False

    async def initialize(self) -> None:
        """Create and initialize the JobManager used by this worker."""
        self._job_manager = JobManager()
        await self._job_manager.initialize()
        self._closed = False
        logger.info("Worker 初始化完成")

    async def shutdown(self) -> None:
        """Stop the main loop, wait for the in-flight job, then close the manager.

        Safe to call more than once (it is invoked from both the signal
        handler and ``main``'s ``finally``): the manager is closed only once.
        The wait for the current job is bounded by
        ``settings.job_execution_timeout`` seconds.
        """
        logger.info("Worker 正在关闭...")
        self._running = False

        if self._current_job_id:
            logger.info(f"等待当前任务完成: {self._current_job_id}")
            loop = asyncio.get_running_loop()
            deadline = loop.time() + settings.job_execution_timeout
            # Poll instead of closing Redis underneath the running job.
            while self._current_job_id and loop.time() < deadline:
                await asyncio.sleep(0.1)

        if self._job_manager and not self._closed:
            # Flip the flag before awaiting so a concurrent call does not
            # close the manager a second time.
            self._closed = True
            await self._job_manager.shutdown()

        logger.info("Worker 已关闭")

    async def run(self) -> None:
        """Run the main loop until ``shutdown()`` clears the running flag."""
        self._running = True
        logger.info(
            f"Worker 启动,轮询间隔: {settings.worker_poll_interval}s, "
            f"最大并发: {settings.max_concurrent_jobs}"
        )

        while self._running:
            try:
                await self._process_next_job()
            except Exception as e:
                logger.error(f"Worker 循环异常: {e}", exc_info=True)
                # Back off before retrying so a persistent error does not spin.
                await asyncio.sleep(settings.worker_poll_interval)

    async def _process_next_job(self) -> None:
        """Dequeue one job ID (if any) and process it."""
        manager = self._job_manager
        if not manager:
            logger.error("JobManager 未初始化")
            await asyncio.sleep(settings.worker_poll_interval)
            return

        # BRPOP treats a timeout of 0 as "block forever", which would leave the
        # loop unable to notice a shutdown request; a sub-second poll interval
        # must therefore not be truncated to 0.
        wait_seconds = max(1, int(settings.worker_poll_interval))
        job_id = await manager.dequeue_job(timeout=wait_seconds)
        if not job_id:
            return

        # Mark the job as in flight right away so shutdown() waits for it.
        self._current_job_id = job_id
        try:
            await self._process_job(manager, job_id)
        finally:
            self._current_job_id = None

    async def _process_job(self, manager: "JobManager", job_id: str) -> None:
        """Lock, concurrency-check and execute a single dequeued job."""
        # Use the job's stored request_id for log correlation, falling back to
        # the job ID when the job record or the field is missing.
        job_data = await manager.get_job(job_id)
        set_request_id((job_data or {}).get("request_id") or job_id)

        logger.info(f"从队列获取任务: {job_id}")

        if not await manager.acquire_job_lock(job_id):
            logger.warning(f"无法获取任务锁,任务可能正在被其他 Worker 执行: {job_id}")
            return
        self._locked_job_id = job_id

        at_capacity = False
        try:
            # NOTE(review): can_execute() followed by increment_concurrency()
            # is a check-then-act sequence; concurrent workers may briefly
            # exceed the limit.
            at_capacity = not await manager.can_execute()
            if not at_capacity:
                await manager.increment_concurrency()
                try:
                    await self._execute_with_retry(job_id)
                finally:
                    await manager.decrement_concurrency()
        finally:
            await self._release_lock(manager, job_id)

        if at_capacity:
            logger.info(f"达到并发限制,任务重新入队: {job_id}")
            # Re-queue only after the lock is released: a worker that pops the
            # job while the lock is still held would fail to lock it and drop it.
            await manager.enqueue_job(job_id)
            # Pause so a saturated system does not pop/push the same job in a
            # tight loop.
            await asyncio.sleep(settings.worker_poll_interval)

    async def _release_lock(self, manager: "JobManager", job_id: str) -> None:
        """Release the job lock if this worker still holds it (idempotent)."""
        if self._locked_job_id != job_id:
            return
        self._locked_job_id = None
        await manager.release_job_lock(job_id)

    async def _execute_with_retry(self, job_id: str) -> None:
        """Execute a job under a timeout and route failures to retry handling.

        Args:
            job_id: ID of the job to execute.
        """
        if not self._job_manager:
            return

        try:
            await asyncio.wait_for(
                self._job_manager.execute_job(job_id),
                timeout=settings.job_execution_timeout,
            )
        except asyncio.TimeoutError:
            logger.error(f"任务执行超时: {job_id}")
            await self._handle_job_failure(job_id, "任务执行超时")
        except Exception as e:
            logger.error(f"任务执行异常: {job_id}, error={e}", exc_info=True)
            await self._handle_job_failure(job_id, str(e))

    async def _handle_job_failure(self, job_id: str, error: str) -> None:
        """Re-queue a failed job, or mark it failed once retries are exhausted.

        Args:
            job_id: ID of the failed job.
            error: Human-readable failure reason stored on the job record.
        """
        manager = self._job_manager
        if not manager:
            return

        retry_count = await manager.increment_job_retry(job_id)

        if retry_count < settings.job_max_retries:
            logger.info(f"任务将重试 ({retry_count}/{settings.job_max_retries}): {job_id}")
            # Drop our lock first so the worker that picks up the retry can
            # acquire it instead of discarding the job.
            await self._release_lock(manager, job_id)
            await manager.enqueue_job(job_id)
        else:
            logger.error(f"任务达到最大重试次数,标记为失败: {job_id}")
            # Record the terminal failure on the job hash.
            if manager._redis_client:
                key = f"job:{job_id}"
                await manager._redis_client.hset(
                    key,
                    mapping={
                        "status": "failed",
                        "error": f"达到最大重试次数 ({settings.job_max_retries}): {error}",
                    },
                )
def setup_signal_handlers(worker: JobWorker, loop: asyncio.AbstractEventLoop) -> None:
    """Install SIGTERM/SIGINT handlers that schedule a worker shutdown.

    Args:
        worker: Worker whose ``shutdown()`` coroutine is scheduled on a signal.
        loop: Event loop on which the handlers are registered.
    """
    # The event loop keeps only weak references to tasks, so hold a strong
    # reference until the shutdown task finishes; otherwise it could be
    # garbage-collected mid-flight.
    pending_tasks: set = set()

    def signal_handler(sig: signal.Signals) -> None:
        logger.info(f"收到信号 {sig.name},准备关闭...")
        task = loop.create_task(worker.shutdown())
        pending_tasks.add(task)
        task.add_done_callback(pending_tasks.discard)

    for sig in (signal.SIGTERM, signal.SIGINT):
        loop.add_signal_handler(sig, signal_handler, sig)
async def main() -> None:
    """Configure logging, install signal handlers and run one worker.

    Exits the process with status 1 when initialization or the main loop
    raises; the worker's ``shutdown()`` is awaited on every exit path.
    """
    setup_logging(level=settings.log_level, format_type=settings.log_format)

    job_worker = JobWorker()
    setup_signal_handlers(job_worker, asyncio.get_running_loop())

    try:
        await job_worker.initialize()
        await job_worker.run()
    except Exception as exc:
        logger.error(f"Worker 异常退出: {exc}", exc_info=True)
        sys.exit(1)
    finally:
        await job_worker.shutdown()


if __name__ == "__main__":
    asyncio.run(main())

View File

@@ -1,17 +1,13 @@
"""异步任务管理器测试""" """异步任务管理器测试"""
import asyncio import asyncio
import json
import pytest import pytest
from unittest.mock import AsyncMock, MagicMock, patch from unittest.mock import AsyncMock, MagicMock, patch
from fastapi import status from fastapi import status
from functional_scaffold.core.job_manager import ( from functional_scaffold.core.job_manager import (
JobManager, JobManager,
get_job_manager,
shutdown_job_manager,
) )
from functional_scaffold.api.models import JobStatus
class TestJobManager: class TestJobManager:
@@ -188,6 +184,7 @@ class TestJobManagerWithMocks:
# 初始化 semaphore # 初始化 semaphore
import asyncio import asyncio
manager._semaphore = asyncio.Semaphore(10) manager._semaphore = asyncio.Semaphore(10)
await manager.execute_job("test-job-id") await manager.execute_job("test-job-id")
@@ -217,7 +214,7 @@ class TestJobsAPI:
"created_at": "2026-02-02T10:00:00+00:00", "created_at": "2026-02-02T10:00:00+00:00",
} }
) )
mock_manager.execute_job = AsyncMock() mock_manager.enqueue_job = AsyncMock(return_value=True)
mock_get_manager.return_value = mock_manager mock_get_manager.return_value = mock_manager
response = client.post( response = client.post(
@@ -486,14 +483,298 @@ class TestConcurrencyControl:
def test_concurrency_status_api(self, client): def test_concurrency_status_api(self, client):
"""测试并发状态 API 端点""" """测试并发状态 API 端点"""
response = client.get("/jobs/concurrency/status") with patch(
"functional_scaffold.api.routes.get_job_manager", new_callable=AsyncMock
) as mock_get_manager:
mock_manager = MagicMock()
mock_manager.is_available.return_value = True
mock_manager.get_concurrency_status.return_value = {
"max_concurrent": 10,
"available_slots": 8,
"running_jobs": 2,
}
mock_get_manager.return_value = mock_manager
assert response.status_code == status.HTTP_200_OK response = client.get("/jobs/concurrency/status")
data = response.json()
assert "max_concurrent" in data assert response.status_code == status.HTTP_200_OK
assert "available_slots" in data data = response.json()
assert "running_jobs" in data
assert isinstance(data["max_concurrent"], int) assert "max_concurrent" in data
assert isinstance(data["available_slots"], int) assert "available_slots" in data
assert isinstance(data["running_jobs"], int) assert "running_jobs" in data
assert isinstance(data["max_concurrent"], int)
assert isinstance(data["available_slots"], int)
assert isinstance(data["running_jobs"], int)
class TestJobQueue:
    """Tests for the JobManager queue operations."""

    @staticmethod
    def _manager(redis_client=None):
        """Build a JobManager, optionally wired to a fake Redis client."""
        manager = JobManager()
        if redis_client is not None:
            manager._redis_client = redis_client
        return manager

    @pytest.mark.asyncio
    async def test_enqueue_job(self):
        """enqueue_job pushes to Redis and reports success."""
        redis_client = AsyncMock(lpush=AsyncMock(return_value=1))
        manager = self._manager(redis_client)

        outcome = await manager.enqueue_job("test-job-id")

        assert outcome is True
        redis_client.lpush.assert_called_once()

    @pytest.mark.asyncio
    async def test_enqueue_job_without_redis(self):
        """enqueue_job returns False when no Redis client is set."""
        manager = self._manager()

        assert await manager.enqueue_job("test-job-id") is False

    @pytest.mark.asyncio
    async def test_dequeue_job(self):
        """dequeue_job returns the value part of the BRPOP reply."""
        redis_client = AsyncMock(
            brpop=AsyncMock(return_value=("job:queue", "test-job-id"))
        )
        manager = self._manager(redis_client)

        outcome = await manager.dequeue_job(timeout=5)

        assert outcome == "test-job-id"
        redis_client.brpop.assert_called_once()

    @pytest.mark.asyncio
    async def test_dequeue_job_timeout(self):
        """dequeue_job returns None when BRPOP times out."""
        redis_client = AsyncMock(brpop=AsyncMock(return_value=None))
        manager = self._manager(redis_client)

        assert await manager.dequeue_job(timeout=1) is None

    @pytest.mark.asyncio
    async def test_dequeue_job_without_redis(self):
        """dequeue_job returns None when no Redis client is set."""
        manager = self._manager()

        assert await manager.dequeue_job(timeout=1) is None
class TestDistributedLock:
    """Tests for the JobManager distributed job lock."""

    @staticmethod
    def _manager(redis_client=None):
        """Build a JobManager, optionally wired to a fake Redis client."""
        manager = JobManager()
        if redis_client is not None:
            manager._redis_client = redis_client
        return manager

    @pytest.mark.asyncio
    async def test_acquire_job_lock(self):
        """The lock is taken with SET NX and an expiry on the lock key."""
        redis_client = AsyncMock(set=AsyncMock(return_value=True))
        manager = self._manager(redis_client)

        outcome = await manager.acquire_job_lock("test-job-id")

        assert outcome is True
        redis_client.set.assert_called_once()
        args, kwargs = redis_client.set.call_args
        assert args[0] == "job:lock:test-job-id"
        assert kwargs["nx"] is True
        assert "ex" in kwargs

    @pytest.mark.asyncio
    async def test_acquire_job_lock_already_locked(self):
        """A None reply from SET (lock already exists) yields False."""
        redis_client = AsyncMock(set=AsyncMock(return_value=None))
        manager = self._manager(redis_client)

        assert await manager.acquire_job_lock("test-job-id") is False

    @pytest.mark.asyncio
    async def test_release_job_lock(self):
        """Releasing deletes the lock key and reports success."""
        redis_client = AsyncMock(delete=AsyncMock(return_value=1))
        manager = self._manager(redis_client)

        outcome = await manager.release_job_lock("test-job-id")

        assert outcome is True
        redis_client.delete.assert_called_once_with("job:lock:test-job-id")

    @pytest.mark.asyncio
    async def test_release_job_lock_without_redis(self):
        """release_job_lock returns False when no Redis client is set."""
        manager = self._manager()

        assert await manager.release_job_lock("test-job-id") is False
class TestGlobalConcurrency:
    """Tests for the JobManager global concurrency counter."""

    @staticmethod
    def _manager(redis_client):
        """Build a JobManager wired to a fake Redis client."""
        manager = JobManager()
        manager._redis_client = redis_client
        return manager

    @pytest.mark.asyncio
    async def test_increment_concurrency(self):
        """increment_concurrency returns the value reported by INCR."""
        redis_client = AsyncMock(incr=AsyncMock(return_value=5))
        manager = self._manager(redis_client)

        outcome = await manager.increment_concurrency()

        assert outcome == 5
        redis_client.incr.assert_called_once()

    @pytest.mark.asyncio
    async def test_decrement_concurrency(self):
        """decrement_concurrency returns the value reported by DECR."""
        redis_client = AsyncMock(decr=AsyncMock(return_value=4))
        manager = self._manager(redis_client)

        outcome = await manager.decrement_concurrency()

        assert outcome == 4
        redis_client.decr.assert_called_once()

    @pytest.mark.asyncio
    async def test_decrement_concurrency_prevent_negative(self):
        """A negative DECR result is reset to zero via SET."""
        redis_client = AsyncMock(
            decr=AsyncMock(return_value=-1),
            set=AsyncMock(),
        )
        manager = self._manager(redis_client)

        outcome = await manager.decrement_concurrency()

        assert outcome == 0
        redis_client.set.assert_called_once()

    @pytest.mark.asyncio
    async def test_get_global_concurrency(self):
        """The stored counter string is returned as an int."""
        redis_client = AsyncMock(get=AsyncMock(return_value="7"))
        manager = self._manager(redis_client)

        assert await manager.get_global_concurrency() == 7

    @pytest.mark.asyncio
    async def test_get_global_concurrency_empty(self):
        """A missing counter is reported as 0."""
        redis_client = AsyncMock(get=AsyncMock(return_value=None))
        manager = self._manager(redis_client)

        assert await manager.get_global_concurrency() == 0

    @pytest.mark.asyncio
    async def test_can_execute(self):
        """can_execute is True while the counter is below the limit."""
        redis_client = AsyncMock(get=AsyncMock(return_value="5"))
        manager = self._manager(redis_client)

        with patch(
            "functional_scaffold.core.job_manager.settings", max_concurrent_jobs=10
        ):
            outcome = await manager.can_execute()

        assert outcome is True

    @pytest.mark.asyncio
    async def test_can_execute_at_limit(self):
        """can_execute is False once the counter reaches the limit."""
        redis_client = AsyncMock(get=AsyncMock(return_value="10"))
        manager = self._manager(redis_client)

        with patch(
            "functional_scaffold.core.job_manager.settings", max_concurrent_jobs=10
        ):
            outcome = await manager.can_execute()

        assert outcome is False
class TestJobRetry:
    """Tests for the JobManager retry counter."""

    @staticmethod
    def _manager(redis_client):
        """Build a JobManager wired to a fake Redis client."""
        manager = JobManager()
        manager._redis_client = redis_client
        return manager

    @pytest.mark.asyncio
    async def test_get_job_retry_count(self):
        """The stored retry_count field is returned as an int."""
        redis_client = AsyncMock(hget=AsyncMock(return_value="2"))
        manager = self._manager(redis_client)

        outcome = await manager.get_job_retry_count("test-job-id")

        assert outcome == 2
        redis_client.hget.assert_called_once_with("job:test-job-id", "retry_count")

    @pytest.mark.asyncio
    async def test_get_job_retry_count_empty(self):
        """A missing retry_count field is reported as 0."""
        redis_client = AsyncMock(hget=AsyncMock(return_value=None))
        manager = self._manager(redis_client)

        assert await manager.get_job_retry_count("test-job-id") == 0

    @pytest.mark.asyncio
    async def test_increment_job_retry(self):
        """The counter is bumped with HINCRBY and the stored value returned."""
        redis_client = AsyncMock(
            hincrby=AsyncMock(),
            hget=AsyncMock(return_value="3"),
        )
        manager = self._manager(redis_client)

        outcome = await manager.increment_job_retry("test-job-id")

        assert outcome == 3
        redis_client.hincrby.assert_called_once_with("job:test-job-id", "retry_count", 1)