- Implemented `policy_utils.py` with helper functions for action selection, including epsilon-greedy support.
- Relaxed the PyTorch version constraint in `requirements.txt` for better GPU compatibility.
- Added detailed GPU setup instructions, new device-fallback options, and command examples to `README.md`.
- Added a new script, `plot_model_max_x_trend.py`, that visualizes training trends and generates HTML/Markdown reports.
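The epsilon-greedy helper from `policy_utils.py` is called later in this diff as `select_action(model=..., obs=..., deterministic=..., epsilon=..., epsilon_random_mode=..., env=...)`. A minimal sketch of what such a helper can look like; the `model`/`env` interfaces here are stand-ins, not the repo's exact implementation:

```python
import random

def select_action(model, obs, deterministic=True, epsilon=0.0,
                  epsilon_random_mode="uniform", env=None):
    """Epsilon-greedy wrapper around a policy's predict() call.

    With probability epsilon, take a random action instead of the greedy
    one: either a uniform sample from the action space ("uniform") or a
    stochastic sample from the policy itself ("policy").
    """
    if epsilon > 0.0 and random.random() < epsilon:
        if epsilon_random_mode == "uniform":
            return env.action_space.sample()
        # "policy": draw from the policy distribution instead of argmax.
        action, _ = model.predict(obs, deterministic=False)
        return int(action)
    action, _ = model.predict(obs, deterministic=deterministic)
    return int(action)
```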
.gitignore (vendored): 1 line changed
@@ -20,6 +20,7 @@ env/
# Logs / test / tooling
*.log
.cache/
.pytest_cache/
.mypy_cache/
.ruff_cache/
README.md

@@ -11,6 +11,7 @@ mario-rl-mvp/
train_ppo.py
eval.py
record_video.py
plot_model_max_x_trend.py
utils.py
artifacts/
models/
@@ -32,6 +33,44 @@ python -m pip install --upgrade pip setuptools wheel
pip install -r requirements.txt
```

## 2.1 Environment setup (WSL / Ubuntu)

If the system Python lacks `venv/pip`, the recommended approach is to create the environment and install dependencies directly with `uv`:

```bash
cd /home/roog/super-mario/mario-rl-mvp
uv venv .venv -p /usr/bin/python3.10
uv pip install --python .venv/bin/python -r requirements.txt
```

If you prefer the system `venv`, install it first:

```bash
sudo apt-get update
sudo apt-get install -y python3.10-venv python3-pip
```

### RTX 50-series GPUs (e.g. RTX 5080)

If you see something like:

```text
... CUDA capability sm_120 is not compatible with the current PyTorch installation ...
```

the current torch wheel does not include `sm_120` kernels. Upgrade directly to the `cu128` nightly:

```bash
cd /home/roog/super-mario/mario-rl-mvp
uv pip install --python .venv/bin/python --upgrade --pre torch --index-url https://download.pytorch.org/whl/nightly/cu128
```

Verify the GPU:

```bash
.venv/bin/python -c "import torch; print(torch.__version__); print(torch.version.cuda); print(torch.cuda.is_available()); print(torch.cuda.get_device_name(0) if torch.cuda.is_available() else 'N/A'); print(torch.cuda.get_device_capability(0) if torch.cuda.is_available() else 'N/A')"
```

Optional system dependencies (for ffmpeg transcoding and potential SDL compatibility):

```bash
@@ -40,12 +79,19 @@ brew install ffmpeg sdl2

## 3. Start training with one command

Default CPU training (if a usable and stable MPS device is detected it is tried automatically, otherwise it falls back to CPU):
Default `--device auto` training (prefers CUDA, then MPS, then CPU):

```bash
python -m src.train_ppo
```

When `--device cuda` or `--device mps` is specified explicitly and that device is unavailable, the script now errors out by default (avoiding a silent fallback to CPU).
If you explicitly accept the fallback, add:

```bash
python -m src.train_ppo --device cuda --allow-device-fallback
```
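The device-resolution order described above can be sketched as a small pure function; `cuda_ok`/`mps_ok` stand in for the real `torch.cuda.is_available()` / `torch.backends.mps.is_available()` checks, and `allow_fallback` mirrors `--allow-device-fallback`. This is a sketch of the policy, not the script's actual code:

```python
def resolve_device(requested: str, cuda_ok: bool, mps_ok: bool,
                   allow_fallback: bool = False) -> str:
    """Pick a torch device string following auto -> cuda -> mps -> cpu."""
    if requested == "auto":
        if cuda_ok:
            return "cuda"
        if mps_ok:
            return "mps"
        return "cpu"
    # Explicit device: fail loudly unless fallback is allowed.
    available = {"cuda": cuda_ok, "mps": mps_ok, "cpu": True}
    if available.get(requested, False):
        return requested
    if allow_fallback:
        return "cpu"
    raise RuntimeError(f"requested device '{requested}' is unavailable")
```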

Common override parameters:

```bash
@@ -78,6 +124,30 @@ python -m src.train_ppo \
  --total-timesteps 300000
```

My current parameters:

```bash
python -m src.train_ppo \
  --init-model-path artifacts/models/latest_model.zip \
  --n-envs 16 \
  --allow-partial-init \
  --reward-mode progress \
  --movement simple \
  --ent-coef 0.001 \
  --learning-rate 2e-5 \
  --n-steps 2048 \
  --gamma 0.99 \
  --death-penalty -50 \
  --stall-penalty 0.05 \
  --stall-steps 40 \
  --backward-penalty-scale 0.01 \
  --milestone-bonus 2.0 \
  --no-progress-terminate-steps 300 \
  --no-progress-terminate-penalty 10 \
  --time-penalty -0.01 \
  --total-timesteps 1200000
```

### 3.1 Continue training from an existing model (`--init-model-path`)

- Purpose: load existing `.zip` weights and continue training; suited to adjusting exploration parameters without interrupting the experiment's goal.
@@ -143,13 +213,35 @@ tensorboard --logdir artifacts/logs --port 6006

Load the latest model, run N episodes, and print the averaged metrics:

```bash
python -m src.eval --episodes 5 --stochastic
python -m src.eval \
  --model-path artifacts/models/latest_model.zip \
  --episodes 20 \
  --movement simple \
  --reward-mode progress \
  --no-progress-terminate-steps 300 \
  --no-progress-terminate-penalty 10 \
  --death-penalty -50 \
  --stall-penalty 0.05 \
  --stall-steps 40 \
  --time-penalty -0.01 \
  --epsilon 0.08
```

To specify a model:

```bash
python -m src.eval --model-path artifacts/models/latest_model.zip --episodes 10 --stochastic
python -m src.eval \
  --model-path artifacts/models/latest_model.zip \
  --episodes 20 \
  --movement simple \
  --reward-mode progress \
  --no-progress-terminate-steps 300 \
  --no-progress-terminate-penalty 10 \
  --death-penalty -50 \
  --stall-penalty 0.05 \
  --stall-steps 40 \
  --time-penalty -0.01 \
  --stochastic
```

Note: `eval.py` defaults to `--movement auto`, which matches `right_only/simple` automatically from the model's action dimension, avoiding a `KeyError` caused by a mismatched action space.
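A minimal sketch of that auto-matching rule. The sizes assume the button-combo lists in `gym_super_mario_bros.actions` (`RIGHT_ONLY` has 5 combos, `SIMPLE_MOVEMENT` has 7), and `n_actions` stands in for the loaded model's `action_space.n`:

```python
def resolve_movement(requested: str, n_actions: int) -> str:
    """Map a model's action-space size back to the movement set it was trained with."""
    if requested != "auto":
        return requested
    # Sizes of the button-combo lists in gym_super_mario_bros.actions.
    sizes = {5: "right_only", 7: "simple"}
    try:
        return sizes[n_actions]
    except KeyError:
        raise ValueError(f"cannot infer movement set from {n_actions} actions")
```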

@@ -174,13 +266,41 @@ _total_timesteps = 150000

By default, records roughly 10 seconds of mp4 to `artifacts/videos/`:

```bash
python -m src.record_video --duration-sec 10 --fps 30 --stochastic
python -m src.record_video \
  --model-path artifacts/models/latest_model.zip \
  --movement simple \
  --reward-mode progress \
  --no-progress-terminate-steps 300 \
  --no-progress-terminate-penalty 10 \
  --death-penalty -50 \
  --stall-penalty 0.05 \
  --stall-steps 40 \
  --time-penalty -0.01 \
  --epsilon 0.08 \
  --duration-sec 30
```

You can specify an output path:
Or a more stable variant:

```bash
python -m src.record_video --output artifacts/videos/demo.mp4 --stochastic --duration-sec 10
python -m src.record_video \
  --model-path artifacts/models/latest_model.zip \
  --movement simple \
  --reward-mode progress \
  --no-progress-terminate-steps 300 \
  --no-progress-terminate-penalty 10 \
  --death-penalty -50 \
  --stall-penalty 0.05 \
  --stall-steps 40 \
  --time-penalty -0.01 \
  --epsilon 0.08 \
  --epsilon-random-mode uniform \
  --max-steps 6000
```

Optional:

```bash
--output artifacts/videos/mario_eps008.mp4
```

Note: `record_video.py` defaults to `--movement auto` and matches the action space to the model automatically.

@@ -190,6 +310,47 @@ python -m src.record_video --output artifacts/videos/demo.mp4 --stochastic --dur

- Outputs mp4 via `imageio + ffmpeg` by default
- If the mp4 write fails, it automatically falls back to saving a PNG frame sequence and prints the ffmpeg transcode command

## 5.1 Model trend visualization (HTML / Markdown)

Visualizes the key-metric trends of the models under `artifacts/models/` over the course of training, producing a Chinese-language HTML or Markdown report.

Default command:

```bash
python -m src.plot_model_max_x_trend
```

Default output:

- `artifacts/reports/model_max_x_trend.html`

To write a Markdown report:

```bash
python -m src.plot_model_max_x_trend --format markdown
```

Markdown default output:

- `artifacts/reports/model_max_x_trend.md`

Optional parameters (custom directories/output):

```bash
uv run python -m src.plot_model_max_x_trend \
  --models-dir artifacts/models \
  --logs-dir artifacts/logs \
  --format markdown \
  --output artifacts/reports/model_max_x_trend.md
```

Report contents:

- Main trend: `max_x` (maximum forward distance)
- Multi-metric trends: mean return, mean episode length, clear rate, no-progress termination rate, death termination rate, timeout termination rate, hard-stuck termination rate
- Model detail table: metric values, matched step, and source TensorBoard tag for each checkpoint/final model
- Glossary: Run, Checkpoint, model_step, matched_step, TensorBoard Tag, and other terms
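The detail table's matched step pairs each checkpoint with the last TensorBoard scalar logged at or before that checkpoint's training step. A minimal sketch of that matching rule, with plain `(step, value)` tuples standing in for the script's scalar points:

```python
def value_at_or_before(points, target_step):
    """points: list of (step, value) sorted by step.

    Return (value, step) of the last point whose step <= target_step,
    falling back to the first point when every logged step is later.
    """
    best = points[0]
    for step, value in points:
        if step <= target_step:
            best = (step, value)
        else:
            break
    return best[1], best[0]
```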

## 6. Action-space selection notes

Defaults to `RIGHT_ONLY`, because:

@@ -231,22 +392,23 @@ python -m src.train_ppo --reward-mode clip

```bash
python -m src.train_ppo \
  --init-model-path artifacts/models/latest_model.zip \
  --init-model-path /home/roog/super-mario/mario-rl-mvp/artifacts/models/ppo_SuperMarioBros-1-1-v0_20260212_205220/ppo_mario_ckpt_100000_steps.zip \
  --n-envs 16 \
  --allow-partial-init \
  --reward-mode progress \
  --movement simple \
  --ent-coef 0.04 \
  --ent-coef 0.01 \
  --learning-rate 1e-4 \
  --n-steps 512 \
  --gamma 0.995 \
  --death-penalty -120 \
  --stall-penalty 0.2 \
  --stall-steps 20 \
  --backward-penalty-scale 0.03 \
  --n-steps 1024 \
  --gamma 0.99 \
  --death-penalty -50 \
  --stall-penalty 0.05 \
  --stall-steps 40 \
  --backward-penalty-scale 0.01 \
  --milestone-bonus 2.0 \
  --no-progress-terminate-steps 80 \
  --no-progress-terminate-penalty 30 \
  --total-timesteps 150000
  --no-progress-terminate-steps 300 \
  --no-progress-terminate-penalty 10 \
  --total-timesteps 300000
```
## 8. Common troubleshooting

@@ -312,6 +474,30 @@ python -m src.train_ppo --init-model-path artifacts/models/latest_model.zip --mo

3) Or simply skip loading the old model and train the new action space from scratch.

### 8.6 `cudaGetDeviceCount ... Error 304` (CUDA initialization failure under WSL)

If you see this at training startup:

```text
[device] cpu | CUDA unavailable, using CPU.
[device_diag] ... torch.cuda.is_available()=False ... Error 304 ...
```

this does not mean the `device` argument was dropped; it means the CUDA runtime failed to initialize in the current environment.

First, confirm two things:

```bash
nvidia-smi --query-gpu=name,driver_version,compute_cap --format=csv,noheader
.venv/bin/python -c "import torch; print(torch.__version__); print(torch.version.cuda); print(torch.cuda.is_available())"
```

The common cause is an unhealthy WSL GPU stack/driver state, not the PPO code itself. If you just need to get an experiment running for now, use the CPU explicitly:

```bash
python -m src.train_ppo --device cpu
```

## 9. Minimal smoke test (run in order)

```bash
requirements.txt

@@ -1,4 +1,5 @@
torch==2.5.1
# Keep torch unpinned to avoid forcing old wheels on new GPUs (e.g. RTX 50xx).
torch>=2.5.1
stable-baselines3==2.3.2
gym==0.26.2
gymnasium==0.29.1

mario-rl-mvp/src/env.py

@@ -212,6 +212,7 @@ class ProgressRewardEnv(gym.Wrapper):
                truncated = True
                shaped_reward -= self.no_progress_terminate_penalty
                info["terminated_by_stall"] = True
                info["done_reason"] = "no_progress"

        if terminated or truncated:
            if bool(info.get("flag_get", False)):
@@ -222,6 +223,71 @@ class ProgressRewardEnv(gym.Wrapper):
        return obs, shaped_reward, terminated, truncated, info


class TimePenaltyHardStuckEnv(gym.Wrapper):
    """Optional living cost and hard-stuck truncation based on x_pos movement."""

    def __init__(
        self,
        env: gym.Env,
        time_penalty: float = 0.0,
        hard_stuck_steps: int = 0,
        hard_stuck_epsilon: float = 1.0,
        hard_stuck_penalty: float = 5.0,
    ):
        super().__init__(env)
        if time_penalty > 0.0:
            raise ValueError(f"time_penalty must be <= 0.0, got {time_penalty}")
        if hard_stuck_steps < 0:
            raise ValueError(f"hard_stuck_steps must be >= 0, got {hard_stuck_steps}")
        if hard_stuck_epsilon < 0.0:
            raise ValueError(f"hard_stuck_epsilon must be >= 0.0, got {hard_stuck_epsilon}")
        if hard_stuck_penalty < 0.0:
            raise ValueError(f"hard_stuck_penalty must be >= 0.0, got {hard_stuck_penalty}")

        self.time_penalty = float(time_penalty)
        self.hard_stuck_steps = int(hard_stuck_steps)
        self.hard_stuck_epsilon = float(hard_stuck_epsilon)
        self.hard_stuck_penalty = float(hard_stuck_penalty)
        self._last_x_pos: Optional[float] = None
        self._hard_stuck_count = 0

    def reset(self, *, seed: Optional[int] = None, options: Optional[dict] = None):
        del options
        obs, info = reset_compat(self.env, seed=seed)
        self._last_x_pos = float(info.get("x_pos", 0.0))
        self._hard_stuck_count = 0
        return obs, info

    def step(self, action: Any):
        obs, reward, terminated, truncated, info = step_compat(self.env, action)
        x_pos = float(info.get("x_pos", 0.0))
        if self._last_x_pos is None:
            delta_x = 0.0
        else:
            delta_x = x_pos - self._last_x_pos
        self._last_x_pos = x_pos

        shaped_reward = float(reward)
        if self.time_penalty != 0.0:
            shaped_reward += self.time_penalty

        if self.hard_stuck_steps > 0 and not terminated and not truncated:
            if abs(delta_x) < self.hard_stuck_epsilon:
                self._hard_stuck_count += 1
            else:
                self._hard_stuck_count = 0

            if self._hard_stuck_count >= self.hard_stuck_steps:
                shaped_reward -= self.hard_stuck_penalty
                truncated = True
                info["terminated_by_hard_stuck"] = True
                info["done_reason"] = "hard_stuck"
        else:
            self._hard_stuck_count = 0

        return obs, shaped_reward, terminated, truncated, info


def get_action_set(name: str):
    name = name.lower().strip()
    if name == "simple":
@@ -249,6 +315,10 @@ def make_mario_env(
    milestone_bonus: float = 1.0,
    no_progress_terminate_steps: int = 120,
    no_progress_terminate_penalty: float = 20.0,
    time_penalty: float = 0.0,
    hard_stuck_steps: int = 0,
    hard_stuck_epsilon: float = 1.0,
    hard_stuck_penalty: float = 5.0,
) -> gym.Env:
    kwargs: Dict[str, Any] = {}
    if render_mode is not None:
@@ -284,6 +354,15 @@
    elif mode != "raw":
        raise ValueError(f"Unsupported reward_mode='{reward_mode}'. Use one of: raw, clip, progress")

    if time_penalty != 0.0 or hard_stuck_steps > 0:
        env = TimePenaltyHardStuckEnv(
            env=env,
            time_penalty=time_penalty,
            hard_stuck_steps=hard_stuck_steps,
            hard_stuck_epsilon=hard_stuck_epsilon,
            hard_stuck_penalty=hard_stuck_penalty,
        )

    env = PreprocessFrame(env, width=84, height=84)
    env = ChannelLastFrameStack(env, num_stack=4)
    env = TransposeObservation(env)
@@ -312,6 +391,10 @@ def make_env_fn(
    milestone_bonus: float,
    no_progress_terminate_steps: int,
    no_progress_terminate_penalty: float,
    time_penalty: float,
    hard_stuck_steps: int,
    hard_stuck_epsilon: float,
    hard_stuck_penalty: float,
) -> Callable[[], gym.Env]:
    def _thunk() -> gym.Env:
        return make_mario_env(
@@ -332,6 +415,10 @@ def make_env_fn(
            milestone_bonus=milestone_bonus,
            no_progress_terminate_steps=no_progress_terminate_steps,
            no_progress_terminate_penalty=no_progress_terminate_penalty,
            time_penalty=time_penalty,
            hard_stuck_steps=hard_stuck_steps,
            hard_stuck_epsilon=hard_stuck_epsilon,
            hard_stuck_penalty=hard_stuck_penalty,
        )

    return _thunk
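The hard-stuck rule above truncates an episode after `hard_stuck_steps` consecutive near-stationary frames. A standalone sketch of just that counter (no gym dependency), which can help sanity-check threshold choices before a run:

```python
def hard_stuck_trace(x_positions, hard_stuck_steps=3, hard_stuck_epsilon=1.0):
    """Return the frame index at which hard-stuck truncation would trigger, or None.

    Mirrors the wrapper's counter: |delta_x| < epsilon increments it,
    any real movement resets it to zero.
    """
    count = 0
    last_x = x_positions[0]
    for i, x in enumerate(x_positions[1:], start=1):
        if abs(x - last_x) < hard_stuck_epsilon:
            count += 1
        else:
            count = 0
        last_x = x
        if count >= hard_stuck_steps:
            return i
    return None
```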


mario-rl-mvp/src/eval.py

@@ -6,6 +6,11 @@ from pathlib import Path
from statistics import mean

import numpy as np
from src.policy_utils import select_action
from src.runtime import configure_runtime_env

configure_runtime_env()

from stable_baselines3 import PPO

from src.env import get_action_set, make_mario_env, reset_compat, step_compat
@@ -32,8 +37,51 @@ def parse_args() -> argparse.Namespace:
    parser.add_argument("--milestone-bonus", type=float, default=1.0)
    parser.add_argument("--no-progress-terminate-steps", type=int, default=120)
    parser.add_argument("--no-progress-terminate-penalty", type=float, default=20.0)
    parser.add_argument(
        "--time-penalty",
        type=float,
        default=0.0,
        help="Per-step living cost added to reward (<=0.0). 0.0 disables.",
    )
    parser.add_argument(
        "--hard-stuck-steps",
        type=int,
        default=0,
        help="Consecutive near-stationary steps before truncation. 0 disables.",
    )
    parser.add_argument(
        "--hard-stuck-epsilon",
        type=float,
        default=1.0,
        help="Treat |x_pos delta| < epsilon as no movement for hard-stuck detection.",
    )
    parser.add_argument(
        "--hard-stuck-penalty",
        type=float,
        default=5.0,
        help="Extra penalty applied when hard-stuck truncation triggers.",
    )
    parser.add_argument("--clip-reward", action="store_true")
    parser.add_argument("--stochastic", action="store_true", help="Use stochastic policy (deterministic=False).")
    parser.add_argument(
        "--epsilon",
        type=float,
        default=0.0,
        help="Epsilon-greedy probability for deterministic policy. Ignored when --stochastic is set.",
    )
    parser.add_argument(
        "--epsilon-random-mode",
        type=str,
        default="uniform",
        choices=["uniform", "policy"],
        help="Random action source for epsilon-greedy: uniform=action_space.sample, policy=model stochastic sample.",
    )
    parser.add_argument(
        "--random-noops",
        type=int,
        default=0,
        help="Random no-op steps after reset (0 disables). Uses action 0 as NOOP.",
    )
    return parser.parse_args()


@@ -87,6 +135,8 @@ def resolve_model_path(user_path: str) -> Path:

def main() -> None:
    args = parse_args()
    if args.epsilon < 0.0 or args.epsilon > 1.0:
        raise ValueError(f"--epsilon must be in [0, 1], got {args.epsilon}")
    seed_everything(args.seed)
    reward_mode = "clip" if args.clip_reward and args.reward_mode == "raw" else args.reward_mode

@@ -94,7 +144,12 @@ def main() -> None:
    print(f"[eval] model={model_path}")
    model = PPO.load(str(model_path))
    movement = resolve_movement(args.movement, model)
    print(f"[eval] movement={movement} reward_mode={reward_mode}")
    print(
        f"[eval] movement={movement} reward_mode={reward_mode} random_noops={args.random_noops} "
        f"time_penalty={args.time_penalty} hard_stuck_steps={args.hard_stuck_steps} "
        f"hard_stuck_epsilon={args.hard_stuck_epsilon} hard_stuck_penalty={args.hard_stuck_penalty} "
        f"epsilon={args.epsilon} epsilon_random_mode={args.epsilon_random_mode}"
    )

    env = make_mario_env(
        env_id=args.env_id,
@@ -114,6 +169,10 @@ def main() -> None:
        milestone_bonus=args.milestone_bonus,
        no_progress_terminate_steps=args.no_progress_terminate_steps,
        no_progress_terminate_penalty=args.no_progress_terminate_penalty,
        time_penalty=args.time_penalty,
        hard_stuck_steps=args.hard_stuck_steps,
        hard_stuck_epsilon=args.hard_stuck_epsilon,
        hard_stuck_penalty=args.hard_stuck_penalty,
    )

    rewards = []
@@ -122,6 +181,12 @@ def main() -> None:

    for ep in range(1, args.episodes + 1):
        obs, info = reset_compat(env, seed=args.seed + ep)
        if args.random_noops > 0:
            noop_steps = np.random.randint(0, args.random_noops + 1)
            for _ in range(noop_steps):
                obs, _, terminated, truncated, info = step_compat(env, 0)
                if terminated or truncated:
                    obs, info = reset_compat(env, seed=args.seed + ep + 1000)
        done = False
        ep_reward = 0.0
        ep_max_x = float(info.get("x_pos", 0.0))
@@ -129,9 +194,14 @@ def main() -> None:
        step_count = 0

        while not done and step_count < args.max_steps:
            action, _ = model.predict(obs, deterministic=not args.stochastic)
            if isinstance(action, np.ndarray):
                action = int(action.item())
            action = select_action(
                model=model,
                obs=obs,
                deterministic=not args.stochastic,
                epsilon=args.epsilon,
                epsilon_random_mode=args.epsilon_random_mode,
                env=env,
            )
            obs, reward, terminated, truncated, info = step_compat(env, action)
            ep_reward += float(reward)
            ep_max_x = max(ep_max_x, float(info.get("x_pos", 0.0)))

mario-rl-mvp/src/plot_model_max_x_trend.py (new file, 677 lines)

@@ -0,0 +1,677 @@
from __future__ import annotations

import argparse
import datetime as dt
import html
import json
import re
from dataclasses import dataclass
from pathlib import Path
from typing import Dict, List, Optional, Sequence, Tuple

from tensorboard.backend.event_processing.event_accumulator import EventAccumulator

from src.utils import ensure_artifact_paths

RUN_TS_PATTERN = re.compile(r"(\d{8}_\d{6})$")
CKPT_STEPS_PATTERN = re.compile(r"ppo_mario_ckpt_(\d+)_steps\.zip$")


@dataclass(frozen=True)
class ScalarPoint:
    step: int
    value: float


@dataclass(frozen=True)
class MetricSpec:
    key: str
    label: str
    tag_candidates: Tuple[str, ...]
    meaning: str
    scale: float = 1.0
    decimals: int = 2
    suffix: str = ""


METRIC_SPECS: List[MetricSpec] = [
    MetricSpec(
        key="max_x",
        label="最大前进距离",
        tag_candidates=("rollout/episode_max_x_pos", "episode_end/episode_max_x_pos"),
        meaning="马里奥在单回合内到达的最远 x 坐标,越大通常代表策略能跑得更远。",
        decimals=1,
    ),
    MetricSpec(
        key="ep_rew",
        label="平均回报",
        tag_candidates=("rollout/ep_rew_mean",),
        meaning="训练窗口内的平均 episode reward,会随 reward shaping 配置变化。",
        decimals=2,
    ),
    MetricSpec(
        key="ep_len",
        label="平均回合步数",
        tag_candidates=("rollout/ep_len_mean",),
        meaning="训练窗口内平均每回合步数;若步数高但 max_x 不涨,常见于卡住。",
        decimals=1,
    ),
    MetricSpec(
        key="clear_rate",
        label="通关率",
        tag_candidates=("rollout/flag_get", "episode_end/flag_get"),
        meaning="到达旗帜终点的比例。0 表示从未通关,1 表示全部通关。",
        scale=100.0,
        decimals=1,
        suffix="%",
    ),
    MetricSpec(
        key="no_progress_rate",
        label="无进展终止率",
        tag_candidates=("rollout/done_reason_no_progress", "episode_end/done_reason_no_progress"),
        meaning="因 no-progress 规则提前结束的比例,越高说明越容易卡住。",
        scale=100.0,
        decimals=1,
        suffix="%",
    ),
    MetricSpec(
        key="death_rate",
        label="死亡终止率",
        tag_candidates=("rollout/done_reason_death", "episode_end/done_reason_death"),
        meaning="因死亡结束 episode 的比例。",
        scale=100.0,
        decimals=1,
        suffix="%",
    ),
    MetricSpec(
        key="timeout_rate",
        label="超时终止率",
        tag_candidates=("rollout/done_reason_timeout", "episode_end/done_reason_timeout"),
        meaning="因时间耗尽或 TimeLimit 截断的比例。",
        scale=100.0,
        decimals=1,
        suffix="%",
    ),
    MetricSpec(
        key="hard_stuck_rate",
        label="硬卡死终止率",
        tag_candidates=("rollout/done_reason_hard_stuck", "episode_end/done_reason_hard_stuck"),
        meaning="触发 hard-stuck 截断的比例(仅在启用 hard_stuck 后有数据)。",
        scale=100.0,
        decimals=1,
        suffix="%",
    ),
]
METRIC_BY_KEY: Dict[str, MetricSpec] = {spec.key: spec for spec in METRIC_SPECS}


@dataclass(frozen=True)
class ModelItem:
    run_name: str
    run_time: dt.datetime
    model_path: Path
    model_kind: str
    model_step: Optional[int]


@dataclass(frozen=True)
class ModelMetric:
    index: int
    run_name: str
    model_file: str
    model_path: str
    model_kind: str
    model_step: Optional[int]
    metric_values: Dict[str, Optional[float]]
    metric_steps: Dict[str, Optional[int]]
    metric_tags: Dict[str, Optional[str]]


def parse_args() -> argparse.Namespace:
    parser = argparse.ArgumentParser(
        description="可视化 artifacts/models 中模型在多维指标上的变化趋势,输出中文 HTML 或 Markdown 报告。"
    )
    parser.add_argument("--models-dir", type=str, default="", help="模型目录,默认 artifacts/models")
    parser.add_argument("--logs-dir", type=str, default="", help="日志目录,默认 artifacts/logs")
    parser.add_argument(
        "--format",
        type=str,
        default="html",
        choices=("html", "markdown"),
        help="输出格式:html 或 markdown",
    )
    parser.add_argument(
        "--output",
        type=str,
        default="",
        help="输出报告路径。默认 html: artifacts/reports/model_max_x_trend.html,markdown: artifacts/reports/model_max_x_trend.md",
    )
    return parser.parse_args()


def _parse_run_time(run_name: str, fallback: float) -> dt.datetime:
    match = RUN_TS_PATTERN.search(run_name)
    if match:
        try:
            return dt.datetime.strptime(match.group(1), "%Y%m%d_%H%M%S")
        except ValueError:
            pass
    return dt.datetime.fromtimestamp(fallback)


def _list_model_items(models_dir: Path) -> List[ModelItem]:
    items: List[ModelItem] = []
    for run_dir in sorted(path for path in models_dir.iterdir() if path.is_dir()):
        run_name = run_dir.name
        run_time = _parse_run_time(run_name, run_dir.stat().st_mtime)

        zip_files = sorted(path for path in run_dir.glob("*.zip") if path.is_file())
        staged: List[Tuple[int, str, Path, Optional[int], str]] = []
        for zip_path in zip_files:
            name = zip_path.name
            ckpt_match = CKPT_STEPS_PATTERN.match(name)
            if ckpt_match:
                step = int(ckpt_match.group(1))
                staged.append((0, f"{step:012d}", zip_path, step, "checkpoint"))
                continue
            if name == "final_model.zip":
                staged.append((1, "final_model", zip_path, None, "final"))
                continue
            staged.append((2, name, zip_path, None, "other"))

        for _, _, model_path, model_step, model_kind in sorted(staged, key=lambda row: (row[0], row[1])):
            items.append(
                ModelItem(
                    run_name=run_name,
                    run_time=run_time,
                    model_path=model_path,
                    model_kind=model_kind,
                    model_step=model_step,
                )
            )

    items.sort(
        key=lambda item: (
            item.run_time,
            10**18 if item.model_step is None and item.model_kind == "final" else (item.model_step or 0),
            item.model_path.name,
        )
    )
    return items


def _load_scalar_map_from_dir(event_dir: Path) -> Dict[str, List[ScalarPoint]]:
    if not event_dir.exists():
        return {}
    event_files = list(event_dir.glob("events.out.tfevents.*"))
    if not event_files:
        return {}

    accumulator = EventAccumulator(str(event_dir), size_guidance={"scalars": 0})
    try:
        accumulator.Reload()
    except Exception:
        return {}

    scalar_map: Dict[str, List[ScalarPoint]] = {}
    for tag in accumulator.Tags().get("scalars", []):
        try:
            points = [ScalarPoint(step=int(s.step), value=float(s.value)) for s in accumulator.Scalars(tag)]
        except Exception:
            continue
        if points:
            scalar_map[tag] = points
    return scalar_map


def _load_run_scalar_map(logs_dir: Path, run_name: str) -> Dict[str, List[ScalarPoint]]:
    run_dir = logs_dir / run_name
    merged: Dict[str, List[ScalarPoint]] = {}

    for sub in (("tb", "ppo_1"), ("tb", "episode_end")):
        event_dir = run_dir.joinpath(*sub)
        part = _load_scalar_map_from_dir(event_dir)
        for tag, points in part.items():
            if tag not in merged or len(points) > len(merged[tag]):
                merged[tag] = points

    return merged


def _value_at_or_before(points: Sequence[ScalarPoint], target_step: int) -> Tuple[float, int]:
    best = points[0]
    for point in points:
        if point.step <= target_step:
            best = point
        else:
            break
    return best.value, best.step


def _pick_series_for_metric(
    scalar_map: Dict[str, List[ScalarPoint]],
    spec: MetricSpec,
) -> Tuple[Optional[str], Optional[List[ScalarPoint]]]:
    for tag in spec.tag_candidates:
        points = scalar_map.get(tag)
        if points:
            return tag, points
    return None, None


def _resolve_model_metric(item: ModelItem, scalar_map: Dict[str, List[ScalarPoint]], index: int) -> ModelMetric:
    metric_values: Dict[str, Optional[float]] = {}
    metric_steps: Dict[str, Optional[int]] = {}
    metric_tags: Dict[str, Optional[str]] = {}

    for spec in METRIC_SPECS:
        tag, points = _pick_series_for_metric(scalar_map, spec)
        metric_tags[spec.key] = tag
        if not points:
            metric_values[spec.key] = None
            metric_steps[spec.key] = None
            continue

        if item.model_step is None:
            point = points[-1]
            metric_values[spec.key] = point.value
            metric_steps[spec.key] = point.step
            continue

        value, matched_step = _value_at_or_before(points, item.model_step)
        metric_values[spec.key] = value
        metric_steps[spec.key] = matched_step

    return ModelMetric(
        index=index,
        run_name=item.run_name,
        model_file=item.model_path.name,
        model_path=str(item.model_path.resolve()),
        model_kind=item.model_kind,
        model_step=item.model_step,
        metric_values=metric_values,
        metric_steps=metric_steps,
        metric_tags=metric_tags,
    )
|
||||
|
||||
|
||||
def _scaled_value(metric: ModelMetric, key: str) -> Optional[float]:
|
||||
spec = METRIC_BY_KEY[key]
|
||||
raw = metric.metric_values.get(key)
|
||||
if raw is None:
|
||||
return None
|
||||
return raw * spec.scale
|
||||
|
||||
|
||||
def _format_value(metric: ModelMetric, key: str) -> str:
|
||||
spec = METRIC_BY_KEY[key]
|
||||
scaled = _scaled_value(metric, key)
|
||||
if scaled is None:
|
||||
return ""
|
||||
return f"{scaled:.{spec.decimals}f}{spec.suffix}"
|
||||
|
||||
|
||||
def _render_metric_svg(metrics: Sequence[ModelMetric], key: str, width: int, height: int) -> str:
    spec = METRIC_BY_KEY[key]
    valid: List[Tuple[int, float]] = []
    for idx, metric in enumerate(metrics):
        value = _scaled_value(metric, key)
        if value is not None:
            valid.append((idx, value))

    if not valid:
        return f'<div class="chart-empty">{html.escape(spec.label)}:暂无可用数据</div>'

    margin_left = 58
    margin_right = 22
    margin_top = 18
    margin_bottom = 48
    plot_w = width - margin_left - margin_right
    plot_h = height - margin_top - margin_bottom

    y_values = [v for _, v in valid]
    y_min = min(0.0, min(y_values))
    y_max = max(y_values)
    if y_max <= y_min:
        y_max = y_min + 1.0

    def sx(index: int) -> float:
        if len(metrics) == 1:
            return margin_left + plot_w / 2.0
        return margin_left + (index / (len(metrics) - 1)) * plot_w

    def sy(value: float) -> float:
        ratio = (value - y_min) / (y_max - y_min)
        return margin_top + (1.0 - ratio) * plot_h

    y_grid_count = 4
    y_grid: List[str] = []
    for i in range(y_grid_count + 1):
        ratio = i / y_grid_count
        y = margin_top + ratio * plot_h
        value = y_max - ratio * (y_max - y_min)
        y_grid.append(
            f'<line x1="{margin_left}" y1="{y:.2f}" x2="{margin_left + plot_w}" y2="{y:.2f}" '
            'stroke="#e5e7eb" stroke-width="1" />'
        )
        y_grid.append(
            f'<text x="{margin_left - 6}" y="{y + 4:.2f}" text-anchor="end" font-size="10" fill="#4b5563">'
            f"{value:.1f}{html.escape(spec.suffix)}</text>"
        )

    points_attr = " ".join(f"{sx(idx):.2f},{sy(value):.2f}" for idx, value in valid)
    polyline = (
        f'<polyline points="{points_attr}" fill="none" stroke="#2563eb" stroke-width="2.2" '
        'stroke-linejoin="round" stroke-linecap="round" />'
    )

    point_nodes: List[str] = []
    for idx, value in valid:
        metric = metrics[idx]
        point_nodes.append(
            f'<circle cx="{sx(idx):.2f}" cy="{sy(value):.2f}" r="3" fill="#1d4ed8">'
            f"<title>{html.escape(spec.label)}={value:.2f}{spec.suffix} | 序号={metric.index} | {metric.model_file}</title>"
            "</circle>"
        )

    return "\n".join(
        [
            f'<div class="chart-title">{html.escape(spec.label)}趋势</div>',
            f'<svg width="{width}" height="{height}" viewBox="0 0 {width} {height}">',
            "<rect x=\"0\" y=\"0\" width=\"100%\" height=\"100%\" fill=\"#ffffff\" />",
            *y_grid,
            f'<line x1="{margin_left}" y1="{margin_top + plot_h}" x2="{margin_left + plot_w}" y2="{margin_top + plot_h}" '
            'stroke="#111827" stroke-width="1.2" />',
            f'<line x1="{margin_left}" y1="{margin_top}" x2="{margin_left}" y2="{margin_top + plot_h}" '
            'stroke="#111827" stroke-width="1.2" />',
            polyline,
            *point_nodes,
            f'<text x="{margin_left + plot_w / 2:.2f}" y="{height - 14}" text-anchor="middle" font-size="11" fill="#111827">模型序号(按时间排序)</text>',
            "</svg>",
        ]
    )

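The nested `sx`/`sy` helpers above are a plain linear map from data space into SVG pixel space, with the y axis flipped so larger values sit higher on the chart. A minimal standalone sketch (`scale_y` here is a hypothetical free function mirroring `sy`, not part of the script):

```python
def scale_y(value: float, y_min: float, y_max: float, margin_top: float, plot_h: float) -> float:
    """Linear map: y_min lands at the bottom of the plot area, y_max at the top."""
    ratio = (value - y_min) / (y_max - y_min)
    # Flip the axis: SVG y grows downward, chart values grow upward.
    return margin_top + (1.0 - ratio) * plot_h
```

With the default margins used above (`margin_top=18`, `plot_h=height-18-48`), the minimum value maps to `margin_top + plot_h` and the maximum to `margin_top`.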
def _metric_payload(metric: ModelMetric) -> Dict[str, object]:
    payload: Dict[str, object] = {
        "index": metric.index,
        "run_name": metric.run_name,
        "model_file": metric.model_file,
        "model_path": metric.model_path,
        "model_kind": metric.model_kind,
        "model_step": metric.model_step,
    }
    for spec in METRIC_SPECS:
        payload[spec.key] = metric.metric_values.get(spec.key)
        payload[f"{spec.key}_matched_step"] = metric.metric_steps.get(spec.key)
        payload[f"{spec.key}_source_tag"] = metric.metric_tags.get(spec.key)
    return payload

def _render_table(metrics: Sequence[ModelMetric]) -> str:
    header = (
        "<tr>"
        "<th>序号</th><th>训练Run</th><th>模型文件</th><th>类型</th><th>checkpoint步数</th>"
        "<th>最大前进距离</th><th>平均回报</th><th>平均回合步数</th><th>通关率</th>"
        "<th>无进展率</th><th>死亡率</th><th>超时率</th><th>硬卡死率</th>"
        "<th>max_x匹配步</th><th>max_x来源Tag</th><th>模型路径</th>"
        "</tr>"
    )

    rows: List[str] = []
    for metric in metrics:
        rows.append(
            "<tr>"
            f"<td>{metric.index}</td>"
            f"<td>{html.escape(metric.run_name)}</td>"
            f"<td>{html.escape(metric.model_file)}</td>"
            f"<td>{html.escape(metric.model_kind)}</td>"
            f"<td>{'' if metric.model_step is None else metric.model_step}</td>"
            f"<td>{html.escape(_format_value(metric, 'max_x'))}</td>"
            f"<td>{html.escape(_format_value(metric, 'ep_rew'))}</td>"
            f"<td>{html.escape(_format_value(metric, 'ep_len'))}</td>"
            f"<td>{html.escape(_format_value(metric, 'clear_rate'))}</td>"
            f"<td>{html.escape(_format_value(metric, 'no_progress_rate'))}</td>"
            f"<td>{html.escape(_format_value(metric, 'death_rate'))}</td>"
            f"<td>{html.escape(_format_value(metric, 'timeout_rate'))}</td>"
            f"<td>{html.escape(_format_value(metric, 'hard_stuck_rate'))}</td>"
            f"<td>{'' if metric.metric_steps.get('max_x') is None else metric.metric_steps.get('max_x')}</td>"
            f"<td>{html.escape(metric.metric_tags.get('max_x') or '')}</td>"
            f"<td><code>{html.escape(metric.model_path)}</code></td>"
            "</tr>"
        )

    return f"<table>{header}{''.join(rows)}</table>"

def _glossary_items() -> List[Tuple[str, str]]:
    return [
        ("Run", "一次完整训练实验,目录名通常形如 ppo_..._YYYYMMDD_HHMMSS。"),
        ("Checkpoint", "训练中间保存的模型快照,例如 ppo_mario_ckpt_100000_steps.zip。"),
        ("Final Model", "训练结束时保存的最终模型(final_model.zip)。"),
        ("model_step", "从 checkpoint 文件名解析出的训练步数。final_model 没有固定 step。"),
        ("matched_step", "从 TensorBoard 曲线中用于取值的实际步数(不超过 model_step 的最近点)。"),
        ("TensorBoard Tag", "标量指标的键名,例如 rollout/episode_max_x_pos。"),
        ("max_x", "最远横坐标,表示策略在关卡中的推进深度。"),
        ("ep_rew_mean", "平均回报,受 reward shaping 配置影响。"),
        ("ep_len_mean", "平均回合步数。"),
        ("flag_get / clear_rate", "通关比例(0~1,页面按百分比显示)。"),
        ("done_reason_*", "不同终止原因的比例(例如 no_progress、death、timeout、hard_stuck)。"),
    ]


def _render_glossary() -> str:
    rows = "".join(f"<li><b>{html.escape(name)}</b>:{html.escape(desc)}</li>" for name, desc in _glossary_items())
    return f"<ul>{rows}</ul>"


def _render_glossary_markdown() -> str:
    return "\n".join(f"- **{name}**:{desc}" for name, desc in _glossary_items())

def _summarize_metrics(
    metrics: Sequence[ModelMetric],
) -> Tuple[int, int, Optional[float], Optional[ModelMetric], Optional[ModelMetric]]:
    total = len(metrics)
    with_max_x = sum(1 for metric in metrics if metric.metric_values.get("max_x") is not None)

    best_max_x = None
    best_entry = None
    for metric in metrics:
        value = metric.metric_values.get("max_x")
        if value is None:
            continue
        if best_max_x is None or value > best_max_x:
            best_max_x = value
            best_entry = metric

    latest_with_max = next((metric for metric in reversed(metrics) if metric.metric_values.get("max_x") is not None), None)
    return total, with_max_x, best_max_x, best_entry, latest_with_max

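The best/latest bookkeeping in `_summarize_metrics` reduces to a single scan over optional values. A minimal sketch of that core (hypothetical `summarize_max_x`, operating on bare floats instead of the real `ModelMetric` records):

```python
from typing import Optional, Sequence, Tuple


def summarize_max_x(values: Sequence[Optional[float]]) -> Tuple[int, int, Optional[float], Optional[int]]:
    """Return (total, count_with_data, best_value, best_index); None entries are skipped."""
    total = len(values)
    with_data = sum(1 for v in values if v is not None)
    best_value: Optional[float] = None
    best_index: Optional[int] = None
    for idx, v in enumerate(values):
        if v is None:
            continue
        # Strict ">" keeps the earliest model on ties, matching a first-wins scan.
        if best_value is None or v > best_value:
            best_value = v
            best_index = idx
    return total, with_data, best_value, best_index
```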
def _build_best_text(best_max_x: Optional[float], best_entry: Optional[ModelMetric]) -> str:
    if best_entry is None or best_max_x is None:
        return "无"
    return f"{best_max_x:.1f}(序号#{best_entry.index},{best_entry.model_file})"


def _build_latest_text(latest_with_max: Optional[ModelMetric]) -> str:
    if latest_with_max is None:
        return "无"
    return f"{_format_value(latest_with_max, 'max_x')}(序号#{latest_with_max.index},{latest_with_max.model_file})"


def _md_escape(text: str) -> str:
    return text.replace("|", "\\|").replace("\n", " ")

def build_html(metrics: Sequence[ModelMetric], models_dir: Path, logs_dir: Path) -> str:
    total, with_max_x, best_max_x, best_entry, latest_with_max = _summarize_metrics(metrics)

    best_text = _build_best_text(best_max_x=best_max_x, best_entry=best_entry)
    latest_text = _build_latest_text(latest_with_max=latest_with_max)

    main_chart = _render_metric_svg(metrics, key="max_x", width=1260, height=390)
    extra_keys = ["ep_rew", "ep_len", "clear_rate", "no_progress_rate", "death_rate", "timeout_rate"]
    extra_charts = "".join(
        f'<div class="mini-chart card">{_render_metric_svg(metrics, key=key, width=610, height=250)}</div>' for key in extra_keys
    )

    table = _render_table(metrics)
    glossary = _render_glossary()
    payload = json.dumps([_metric_payload(metric) for metric in metrics], ensure_ascii=False, indent=2)

    return f"""<!doctype html>
<html lang=\"zh-CN\">
<head>
<meta charset=\"utf-8\" />
<meta name=\"viewport\" content=\"width=device-width, initial-scale=1\" />
<title>模型趋势报告(max_x + 多维指标)</title>
<style>
body {{ font-family: -apple-system, BlinkMacSystemFont, "Segoe UI", "PingFang SC", "Microsoft YaHei", sans-serif; margin: 24px; color: #111827; background: #f8fafc; }}
h1 {{ margin: 0 0 12px; }}
h2 {{ margin: 6px 0 10px; font-size: 18px; }}
.meta {{ color: #374151; margin-bottom: 14px; line-height: 1.6; }}
.card {{ border: 1px solid #dbe2ea; border-radius: 12px; padding: 12px; background: #fff; margin-bottom: 14px; overflow-x: auto; }}
.chart-title {{ font-size: 14px; font-weight: 600; margin-bottom: 6px; color: #0f172a; }}
.chart-empty {{ padding: 20px; color: #64748b; }}
.grid {{ display: grid; grid-template-columns: repeat(auto-fit, minmax(620px, 1fr)); gap: 14px; }}
table {{ border-collapse: collapse; width: 100%; margin-top: 12px; font-size: 12px; }}
th, td {{ border: 1px solid #e2e8f0; padding: 6px 8px; text-align: left; vertical-align: top; }}
th {{ background: #f8fafc; position: sticky; top: 0; z-index: 1; }}
code {{ font-size: 11px; }}
details {{ margin-top: 12px; }}
ul {{ margin: 8px 0 0 20px; line-height: 1.6; }}
</style>
</head>
<body>
<h1>模型趋势报告(max_x + 多维指标)</h1>
<div class=\"meta\">
<div>模型目录:<code>{html.escape(str(models_dir.resolve()))}</code></div>
<div>日志目录:<code>{html.escape(str(logs_dir.resolve()))}</code></div>
<div>模型总数:<b>{total}</b>;有 max_x 数据:<b>{with_max_x}</b></div>
<div>历史最佳 max_x:<b>{html.escape(best_text)}</b></div>
<div>最近模型 max_x:<b>{html.escape(latest_text)}</b></div>
<div>说明:模型序号按时间排序(run 时间 + checkpoint 步数)。鼠标悬停图中点可看细节。</div>
</div>

<div class=\"card\">
<h2>主趋势:最大前进距离(max_x)</h2>
{main_chart}
</div>

<div class=\"grid\">{extra_charts}</div>

<div class=\"card\">
<h2>模型明细表</h2>
{table}
</div>

<div class=\"card\">
<h2>术语解释(专有名词)</h2>
{glossary}
</div>

<details class=\"card\">
<summary>原始数据(JSON)</summary>
<pre>{html.escape(payload)}</pre>
</details>
</body>
</html>
"""

def build_markdown(metrics: Sequence[ModelMetric], models_dir: Path, logs_dir: Path) -> str:
    total, with_max_x, best_max_x, best_entry, latest_with_max = _summarize_metrics(metrics)
    best_text = _build_best_text(best_max_x=best_max_x, best_entry=best_entry)
    latest_text = _build_latest_text(latest_with_max=latest_with_max)
    glossary = _render_glossary_markdown()
    payload = json.dumps([_metric_payload(metric) for metric in metrics], ensure_ascii=False, indent=2)

    lines: List[str] = []
    lines.append("# 模型趋势报告(max_x + 多维指标)")
    lines.append("")
    lines.append("## 概览")
    lines.append(f"- 模型目录:`{models_dir.resolve()}`")
    lines.append(f"- 日志目录:`{logs_dir.resolve()}`")
    lines.append(f"- 模型总数:**{total}**")
    lines.append(f"- 含 max_x 数据:**{with_max_x}**")
    lines.append(f"- 历史最佳 max_x:**{best_text}**")
    lines.append(f"- 最近模型 max_x:**{latest_text}**")
    lines.append("- 说明:模型序号按时间排序(run 时间 + checkpoint 步数)。")
    lines.append("")
    lines.append("## 指标说明")
    for spec in METRIC_SPECS:
        lines.append(f"- **{spec.label}**(`{spec.key}`):{spec.meaning}")
    lines.append("")
    lines.append("## 模型明细")
    lines.append(
        "| 序号 | Run | 模型文件 | 类型 | checkpoint步数 | 最大前进距离 | 平均回报 | 平均回合步数 | 通关率 | 无进展率 | 死亡率 | 超时率 | 硬卡死率 | max_x匹配步 | max_x来源Tag |"
    )
    lines.append("|---:|---|---|---|---:|---:|---:|---:|---:|---:|---:|---:|---:|---:|---|")
    for metric in metrics:
        lines.append(
            "| "
            f"{metric.index} | "
            f"{_md_escape(metric.run_name)} | "
            f"{_md_escape(metric.model_file)} | "
            f"{metric.model_kind} | "
            f"{'' if metric.model_step is None else metric.model_step} | "
            f"{_format_value(metric, 'max_x')} | "
            f"{_format_value(metric, 'ep_rew')} | "
            f"{_format_value(metric, 'ep_len')} | "
            f"{_format_value(metric, 'clear_rate')} | "
            f"{_format_value(metric, 'no_progress_rate')} | "
            f"{_format_value(metric, 'death_rate')} | "
            f"{_format_value(metric, 'timeout_rate')} | "
            f"{_format_value(metric, 'hard_stuck_rate')} | "
            f"{'' if metric.metric_steps.get('max_x') is None else metric.metric_steps.get('max_x')} | "
            f"{_md_escape(metric.metric_tags.get('max_x') or '')} |"
        )
    lines.append("")
    lines.append("## 术语解释(专有名词)")
    lines.append(glossary)
    lines.append("")
    lines.append("## 原始数据(JSON)")
    lines.append("```json")
    lines.append(payload)
    lines.append("```")
    return "\n".join(lines)

def main() -> None:
    args = parse_args()
    paths = ensure_artifact_paths()

    models_dir = Path(args.models_dir).expanduser().resolve() if args.models_dir else paths.models.resolve()
    logs_dir = Path(args.logs_dir).expanduser().resolve() if args.logs_dir else paths.logs.resolve()
    default_name = "model_max_x_trend.md" if args.format == "markdown" else "model_max_x_trend.html"
    output_path = (
        Path(args.output).expanduser().resolve()
        if args.output
        else (paths.root / "reports" / default_name).resolve()
    )
    output_path.parent.mkdir(parents=True, exist_ok=True)

    model_items = _list_model_items(models_dir)
    scalar_cache: Dict[str, Dict[str, List[ScalarPoint]]] = {}
    metrics: List[ModelMetric] = []

    for idx, item in enumerate(model_items, start=1):
        if item.run_name not in scalar_cache:
            scalar_cache[item.run_name] = _load_run_scalar_map(logs_dir=logs_dir, run_name=item.run_name)
        metrics.append(_resolve_model_metric(item=item, scalar_map=scalar_cache[item.run_name], index=idx))

    if args.format == "markdown":
        content = build_markdown(metrics=metrics, models_dir=models_dir, logs_dir=logs_dir)
    else:
        content = build_html(metrics=metrics, models_dir=models_dir, logs_dir=logs_dir)
    output_path.write_text(content, encoding="utf-8")

    with_max_x = sum(1 for metric in metrics if metric.metric_values.get("max_x") is not None)
    print(f"[趋势] 报告已生成:{output_path} (format={args.format})")
    print(f"[趋势] 模型总数={len(metrics)},含max_x数据={with_max_x}")


if __name__ == "__main__":
    main()
mario-rl-mvp/src/policy_utils.py (new file, 53 lines)
@@ -0,0 +1,53 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import Any, Optional
|
||||
|
||||
import numpy as np
|
||||
|
||||
|
||||
def _as_discrete_action(action: Any) -> int:
|
||||
if isinstance(action, np.ndarray):
|
||||
return int(action.item())
|
||||
if isinstance(action, np.generic):
|
||||
return int(action.item())
|
||||
try:
|
||||
return int(action)
|
||||
except (TypeError, ValueError) as exc:
|
||||
raise ValueError(f"Cannot convert action to int: {action!r}") from exc
|
||||
|
||||
|
||||
def select_action(
|
||||
model: Any,
|
||||
obs: Any,
|
||||
deterministic: bool,
|
||||
epsilon: float = 0.0,
|
||||
epsilon_random_mode: str = "uniform",
|
||||
env: Optional[Any] = None,
|
||||
rng: Optional[np.random.Generator] = None,
|
||||
) -> int:
|
||||
"""Select action with optional epsilon-greedy for deterministic inference.
|
||||
|
||||
When deterministic=False (stochastic policy), epsilon is ignored by design.
|
||||
"""
|
||||
|
||||
if epsilon < 0.0 or epsilon > 1.0:
|
||||
raise ValueError(f"epsilon must be in [0, 1], got {epsilon}")
|
||||
|
||||
mode = epsilon_random_mode.lower().strip()
|
||||
if mode not in {"uniform", "policy"}:
|
||||
raise ValueError(f"Unsupported epsilon_random_mode={epsilon_random_mode!r}. Use 'uniform' or 'policy'.")
|
||||
|
||||
effective_epsilon = epsilon if deterministic else 0.0
|
||||
if effective_epsilon > 0.0:
|
||||
draw = float(rng.random()) if rng is not None else float(np.random.random())
|
||||
if draw < effective_epsilon:
|
||||
if mode == "uniform":
|
||||
if env is None or not hasattr(env, "action_space"):
|
||||
raise ValueError("uniform epsilon_random_mode requires env.action_space.sample().")
|
||||
return _as_discrete_action(env.action_space.sample())
|
||||
action, _ = model.predict(obs, deterministic=False)
|
||||
return _as_discrete_action(action)
|
||||
|
||||
action, _ = model.predict(obs, deterministic=deterministic)
|
||||
return _as_discrete_action(action)
|
||||
|
||||
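A self-contained usage sketch of the epsilon-greedy rule that `select_action` implements. The stub model and env below are hypothetical stand-ins; real callers pass a loaded SB3 PPO model and the wrapped Mario env:

```python
from typing import Any, Optional

import numpy as np


def epsilon_greedy(model: Any, obs: Any, epsilon: float, env: Any,
                   rng: Optional[np.random.Generator] = None) -> int:
    """Compact sketch: with probability epsilon take a uniform-random action, else greedy."""
    draw = float(rng.random()) if rng is not None else float(np.random.random())
    if draw < epsilon:
        return int(env.action_space.sample())
    action, _ = model.predict(obs, deterministic=True)
    return int(np.asarray(action).item())


class _StubModel:
    """Always predicts action 1, mimicking model.predict()'s (action, state) return shape."""

    def predict(self, obs, deterministic=True):
        return np.array(1), None


class _StubSpace:
    def sample(self) -> int:
        return 3


class _StubEnv:
    action_space = _StubSpace()


rng = np.random.default_rng(0)
# epsilon=0: always the greedy action from the policy.
greedy = epsilon_greedy(_StubModel(), obs=None, epsilon=0.0, env=_StubEnv(), rng=rng)
# epsilon=1: always a uniform-random action from env.action_space.
explore = epsilon_greedy(_StubModel(), obs=None, epsilon=1.0, env=_StubEnv(), rng=rng)
```

Passing an explicit `rng` (as `select_action` also allows) keeps exploration reproducible across evaluation runs.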
@@ -5,7 +5,11 @@ from datetime import datetime
from pathlib import Path

import imageio.v2 as imageio
import numpy as np
from src.policy_utils import select_action
from src.runtime import configure_runtime_env

configure_runtime_env()

from stable_baselines3 import PPO

from src.env import get_action_set, make_mario_env, reset_compat, step_compat

@@ -30,11 +34,48 @@ def parse_args() -> argparse.Namespace:
    parser.add_argument("--milestone-bonus", type=float, default=1.0)
    parser.add_argument("--no-progress-terminate-steps", type=int, default=120)
    parser.add_argument("--no-progress-terminate-penalty", type=float, default=20.0)
    parser.add_argument(
        "--time-penalty",
        type=float,
        default=0.0,
        help="Per-step living cost added to reward (<=0.0). 0.0 disables.",
    )
    parser.add_argument(
        "--hard-stuck-steps",
        type=int,
        default=0,
        help="Consecutive near-stationary steps before truncation. 0 disables.",
    )
    parser.add_argument(
        "--hard-stuck-epsilon",
        type=float,
        default=1.0,
        help="Treat |x_pos delta| < epsilon as no movement for hard-stuck detection.",
    )
    parser.add_argument(
        "--hard-stuck-penalty",
        type=float,
        default=5.0,
        help="Extra penalty applied when hard-stuck truncation triggers.",
    )
    parser.add_argument("--clip-reward", action="store_true")
    parser.add_argument("--fps", type=int, default=30)
    parser.add_argument("--duration-sec", type=int, default=10)
    parser.add_argument("--max-steps", type=int, default=5_000)
    parser.add_argument("--stochastic", action="store_true")
    parser.add_argument(
        "--epsilon",
        type=float,
        default=0.0,
        help="Epsilon-greedy probability for deterministic policy. Ignored when --stochastic is set.",
    )
    parser.add_argument(
        "--epsilon-random-mode",
        type=str,
        default="uniform",
        choices=["uniform", "policy"],
        help="Random action source for epsilon-greedy: uniform=action_space.sample, policy=model stochastic sample.",
    )
    parser.add_argument("--output", type=str, default="", help="Output mp4 path.")
    return parser.parse_args()

@@ -106,6 +147,8 @@ def save_frames_fallback(frames, output_path: Path) -> Path:

def main() -> None:
    args = parse_args()
    if args.epsilon < 0.0 or args.epsilon > 1.0:
        raise ValueError(f"--epsilon must be in [0, 1], got {args.epsilon}")
    seed_everything(args.seed)
    reward_mode = "clip" if args.clip_reward and args.reward_mode == "raw" else args.reward_mode

@@ -116,7 +159,12 @@ def main() -> None:
    model = PPO.load(str(model_path))
    print(f"[video] model={model_path}")
    movement = resolve_movement(args.movement, model)
    print(f"[video] movement={movement} reward_mode={reward_mode}")
    print(
        f"[video] movement={movement} reward_mode={reward_mode} "
        f"time_penalty={args.time_penalty} hard_stuck_steps={args.hard_stuck_steps} "
        f"hard_stuck_epsilon={args.hard_stuck_epsilon} hard_stuck_penalty={args.hard_stuck_penalty} "
        f"epsilon={args.epsilon} epsilon_random_mode={args.epsilon_random_mode}"
    )
    env = make_mario_env(
        env_id=args.env_id,
        seed=args.seed,

@@ -135,6 +183,10 @@ def main() -> None:
        milestone_bonus=args.milestone_bonus,
        no_progress_terminate_steps=args.no_progress_terminate_steps,
        no_progress_terminate_penalty=args.no_progress_terminate_penalty,
        time_penalty=args.time_penalty,
        hard_stuck_steps=args.hard_stuck_steps,
        hard_stuck_epsilon=args.hard_stuck_epsilon,
        hard_stuck_penalty=args.hard_stuck_penalty,
    )

    obs, _ = reset_compat(env, seed=args.seed)

@@ -152,9 +204,14 @@ def main() -> None:
    episode_count = 1

    while len(frames) < target_frames and step_count < args.max_steps:
        action, _ = model.predict(obs, deterministic=not args.stochastic)
        if isinstance(action, np.ndarray):
            action = int(action.item())
        action = select_action(
            model=model,
            obs=obs,
            deterministic=not args.stochastic,
            epsilon=args.epsilon,
            epsilon_random_mode=args.epsilon_random_mode,
            env=env,
        )
        obs, reward, terminated, truncated, info = step_compat(env, action)
        reward_sum += float(reward)
        step_count += 1
mario-rl-mvp/src/runtime.py (new file, 14 lines)
@@ -0,0 +1,14 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import os
|
||||
from pathlib import Path
|
||||
|
||||
|
||||
def configure_runtime_env() -> None:
|
||||
"""Set runtime defaults that are safe on restricted Linux/WSL setups."""
|
||||
if "MPLCONFIGDIR" in os.environ:
|
||||
return
|
||||
|
||||
cache_dir = Path.cwd() / ".cache" / "matplotlib"
|
||||
cache_dir.mkdir(parents=True, exist_ok=True)
|
||||
os.environ["MPLCONFIGDIR"] = str(cache_dir)
|
||||
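The guard in `configure_runtime_env` is a general pattern: set a process-level default only when the user has not already chosen one. A minimal generic sketch (`set_default_env` is hypothetical; the real function hard-codes `MPLCONFIGDIR` and a `.cache/matplotlib` directory):

```python
import os
from pathlib import Path


def set_default_env(name: str, default_dir: Path) -> str:
    """Return the effective value of an env var, creating and setting a default if unset."""
    if name in os.environ:
        return os.environ[name]  # respect an explicit user choice
    default_dir.mkdir(parents=True, exist_ok=True)
    os.environ[name] = str(default_dir)
    return os.environ[name]
```

Calling this before importing matplotlib avoids the warning (and failure on read-only home directories) that matplotlib emits when it cannot create its default config dir.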
@@ -5,6 +5,10 @@ import shutil
from datetime import datetime
from pathlib import Path

from src.runtime import configure_runtime_env

configure_runtime_env()

from stable_baselines3 import PPO
from stable_baselines3.common.callbacks import BaseCallback, CallbackList, CheckpointCallback
from stable_baselines3.common.save_util import load_from_zip_file

@@ -12,7 +16,13 @@ from stable_baselines3.common.vec_env import DummyVecEnv, VecMonitor
from torch.utils.tensorboard import SummaryWriter

from src.env import make_env_fn
from src.utils import ensure_artifact_paths, resolve_torch_device, seed_everything, write_latest_pointer
from src.utils import (
    collect_torch_runtime_diagnostics,
    ensure_artifact_paths,
    resolve_torch_device,
    seed_everything,
    write_latest_pointer,
)


def parse_args() -> argparse.Namespace:

@@ -34,6 +44,30 @@ def parse_args() -> argparse.Namespace:
    parser.add_argument("--milestone-bonus", type=float, default=1.0)
    parser.add_argument("--no-progress-terminate-steps", type=int, default=120)
    parser.add_argument("--no-progress-terminate-penalty", type=float, default=20.0)
    parser.add_argument(
        "--time-penalty",
        type=float,
        default=0.0,
        help="Per-step living cost added to reward (<=0.0). 0.0 disables.",
    )
    parser.add_argument(
        "--hard-stuck-steps",
        type=int,
        default=0,
        help="Consecutive near-stationary steps before truncation. 0 disables.",
    )
    parser.add_argument(
        "--hard-stuck-epsilon",
        type=float,
        default=1.0,
        help="Treat |x_pos delta| < epsilon as no movement for hard-stuck detection.",
    )
    parser.add_argument(
        "--hard-stuck-penalty",
        type=float,
        default=5.0,
        help="Extra penalty applied when hard-stuck truncation triggers.",
    )

    parser.add_argument("--learning-rate", type=float, default=2.5e-4)
    parser.add_argument("--n-steps", type=int, default=128)

@@ -51,6 +85,11 @@ def parse_args() -> argparse.Namespace:
        default="auto",
        help="Torch device: auto | cpu | mps | cuda | cuda:0 ...",
    )
    parser.add_argument(
        "--allow-device-fallback",
        action="store_true",
        help="Allow explicit --device requests (cuda/mps) to fall back to CPU when unavailable.",
    )
    parser.add_argument("--clip-reward", action="store_true", help="Enable reward clipping to {-1,0,1}.")
    parser.add_argument("--run-name", type=str, default="", help="Optional custom run name.")
    parser.add_argument(

@@ -98,7 +137,7 @@ def load_partial_policy_weights(model: PPO, init_model_path: Path, device: str)
class EpisodeEndLoggingCallback(BaseCallback):
    """Log per-episode terminal diagnostics to stdout and TensorBoard."""

    REASON_TO_CODE = {"death": 0, "no_progress": 1, "timeout": 2, "clear": 3}
    REASON_TO_CODE = {"death": 0, "no_progress": 1, "timeout": 2, "clear": 3, "hard_stuck": 4}

    def __init__(self, n_envs: int, tb_log_dir: Path):
        super().__init__(verbose=0)

@@ -110,8 +149,13 @@ class EpisodeEndLoggingCallback(BaseCallback):

    @staticmethod
    def _resolve_done_reason(info: dict) -> str:
        done_reason = str(info.get("done_reason", "")).strip().lower()
        if done_reason in EpisodeEndLoggingCallback.REASON_TO_CODE:
            return done_reason
        if bool(info.get("flag_get", False)):
            return "clear"
        if bool(info.get("terminated_by_hard_stuck", False)):
            return "hard_stuck"
        if bool(info.get("terminated_by_stall", False)):
            return "no_progress"
        if bool(info.get("TimeLimit.truncated", False)):

@@ -164,6 +208,9 @@ class EpisodeEndLoggingCallback(BaseCallback):
        )
        self.tb_writer.add_scalar("episode_end/done_reason_timeout", 1.0 if done_reason == "timeout" else 0.0, episode_step)
        self.tb_writer.add_scalar("episode_end/done_reason_clear", 1.0 if done_reason == "clear" else 0.0, episode_step)
        self.tb_writer.add_scalar(
            "episode_end/done_reason_hard_stuck", 1.0 if done_reason == "hard_stuck" else 0.0, episode_step
        )
        self.tb_writer.flush()

        print(

@@ -196,12 +243,25 @@ def main() -> None:
    run_model_dir.mkdir(parents=True, exist_ok=True)
    run_tb_dir.mkdir(parents=True, exist_ok=True)

    requested_device = args.device.lower().strip()
    device, device_msg = resolve_torch_device(args.device)
    print(f"[device] {device} | {device_msg}")
    print(f"[device_diag] {collect_torch_runtime_diagnostics()}")

    explicit_device_requested = requested_device not in {"auto", "cpu"}
    if explicit_device_requested and device == "cpu" and not args.allow_device_fallback:
        raise RuntimeError(
            f"Requested --device {args.device!r}, but it resolved to CPU. {device_msg}\n"
            "Either fix your CUDA/MPS environment, or re-run with --allow-device-fallback "
            "to continue on CPU intentionally."
        )

    print(f"[run] {run_name}")
    print(
        f"[config] env_id={args.env_id}, movement={args.movement}, reward_mode={reward_mode}, "
        f"clip_reward={args.clip_reward}, n_envs={args.n_envs}"
        f"clip_reward={args.clip_reward}, n_envs={args.n_envs}, "
        f"time_penalty={args.time_penalty}, hard_stuck_steps={args.hard_stuck_steps}, "
        f"hard_stuck_epsilon={args.hard_stuck_epsilon}, hard_stuck_penalty={args.hard_stuck_penalty}"
    )

    env_fns = [

@@ -222,6 +282,10 @@ def main() -> None:
            milestone_bonus=args.milestone_bonus,
            no_progress_terminate_steps=args.no_progress_terminate_steps,
            no_progress_terminate_penalty=args.no_progress_terminate_penalty,
            time_penalty=args.time_penalty,
            hard_stuck_steps=args.hard_stuck_steps,
            hard_stuck_epsilon=args.hard_stuck_epsilon,
            hard_stuck_penalty=args.hard_stuck_penalty,
        )
        for i in range(args.n_envs)
    ]
@@ -1,6 +1,7 @@
from __future__ import annotations

import random
import warnings
from dataclasses import dataclass
from pathlib import Path
from typing import Optional, Tuple

@@ -9,6 +10,16 @@ import numpy as np
import torch


def cuda_is_available_quiet() -> bool:
    with warnings.catch_warnings():
        warnings.filterwarnings(
            "ignore",
            message="CUDA initialization: Unexpected error from cudaGetDeviceCount\\(\\).*",
            category=UserWarning,
        )
        return torch.cuda.is_available()


@dataclass(frozen=True)
class ArtifactPaths:
    root: Path

@@ -36,7 +47,7 @@ def seed_everything(seed: int) -> None:
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    if torch.cuda.is_available():
    if cuda_is_available_quiet():
        torch.cuda.manual_seed_all(seed)


@@ -46,7 +57,7 @@ def resolve_torch_device(requested: str = "auto") -> Tuple[str, str]:
        return "cpu", "User requested CPU."

    def _probe_cuda(device_name: str = "cuda") -> Tuple[str, str]:
        if not torch.cuda.is_available():
        if not cuda_is_available_quiet():
            return "cpu", "CUDA unavailable, using CPU."
        try:
            x = torch.ones(8, device=device_name)

@@ -90,6 +101,47 @@ def resolve_torch_device(requested: str = "auto") -> Tuple[str, str]:
    return "cpu", f"Unknown device '{requested}', fallback to CPU."


def collect_torch_runtime_diagnostics() -> str:
    lines: list[str] = []
    lines.append(f"torch={torch.__version__}")
    lines.append(f"torch.version.cuda={torch.version.cuda}")

    try:
        with warnings.catch_warnings():
            warnings.filterwarnings(
                "ignore",
                message="CUDA initialization: Unexpected error from cudaGetDeviceCount\\(\\).*",
                category=UserWarning,
            )
            cuda_available = torch.cuda.is_available()
        lines.append(f"torch.cuda.is_available()={cuda_available}")
    except Exception as exc:  # pragma: no cover - hardware dependent
        cuda_available = False
        lines.append(f"torch.cuda.is_available() raised={exc}")

    try:
        cuda_count = torch.cuda.device_count()
        lines.append(f"torch.cuda.device_count()={cuda_count}")
    except Exception as exc:  # pragma: no cover - hardware dependent
        cuda_count = 0
        lines.append(f"torch.cuda.device_count() raised={exc}")

    if cuda_count > 0:
        try:
            lines.append(f"torch.cuda.get_device_name(0)={torch.cuda.get_device_name(0)}")
        except Exception as exc:  # pragma: no cover - hardware dependent
            lines.append(f"torch.cuda.get_device_name(0) raised={exc}")
        try:
            lines.append(f"torch.cuda.get_device_capability(0)={torch.cuda.get_device_capability(0)}")
        except Exception as exc:  # pragma: no cover - hardware dependent
            lines.append(f"torch.cuda.get_device_capability(0) raised={exc}")

    mps_available = hasattr(torch.backends, "mps") and torch.backends.mps.is_available()
    lines.append(f"torch.backends.mps.is_available()={mps_available}")

    return " | ".join(lines)


def latest_model_path(models_dir: Path) -> Optional[Path]:
    candidates = [p for p in models_dir.rglob("*.zip") if p.is_file()]
    if not candidates: