feat: initial mario rl mvp
This commit is contained in:
44
.gitignore
vendored
Normal file
44
.gitignore
vendored
Normal file
@@ -0,0 +1,44 @@
|
||||
# macOS
|
||||
.DS_Store
|
||||
|
||||
# Python cache / build
|
||||
__pycache__/
|
||||
*.py[cod]
|
||||
*.pyo
|
||||
*.pyd
|
||||
*.so
|
||||
*.egg-info/
|
||||
.eggs/
|
||||
dist/
|
||||
build/
|
||||
|
||||
# Virtual environments
|
||||
.venv/
|
||||
venv/
|
||||
env/
|
||||
*/.venv/
|
||||
|
||||
# Logs / test / tooling
|
||||
*.log
|
||||
.pytest_cache/
|
||||
.mypy_cache/
|
||||
.ruff_cache/
|
||||
.coverage
|
||||
htmlcov/
|
||||
|
||||
# Jupyter
|
||||
.ipynb_checkpoints/
|
||||
|
||||
# RL artifacts
|
||||
artifacts/
|
||||
*/artifacts/
|
||||
|
||||
# TensorBoard / checkpoints / models
|
||||
runs/
|
||||
checkpoints/
|
||||
models/
|
||||
videos/
|
||||
|
||||
# IDE
|
||||
.vscode/
|
||||
.idea/
|
||||
339
mario-rl-mvp/README.md
Normal file
339
mario-rl-mvp/README.md
Normal file
@@ -0,0 +1,339 @@
|
||||
# Mario RL MVP (macOS Apple Silicon)
|
||||
|
||||
最小可运行工程:使用像素输入 + 传统 CNN policy(`stable-baselines3` PPO)训练 `gym-super-mario-bros / nes-py` 智能体,不使用大语言模型。
|
||||
|
||||
## 1. 项目结构
|
||||
|
||||
```text
|
||||
mario-rl-mvp/
|
||||
src/
|
||||
env.py
|
||||
train_ppo.py
|
||||
eval.py
|
||||
record_video.py
|
||||
utils.py
|
||||
artifacts/
|
||||
models/
|
||||
videos/
|
||||
logs/
|
||||
requirements.txt
|
||||
README.md
|
||||
```
|
||||
|
||||
## 2. 环境准备(macOS / Apple Silicon)
|
||||
|
||||
建议 Python 3.9+(本机默认 `python3` 即可)。
|
||||
|
||||
```bash
|
||||
cd /Users/roog/Work/FNT/SpMr/mario-rl-mvp
|
||||
python3 -m venv .venv
|
||||
source .venv/bin/activate
|
||||
python -m pip install --upgrade pip setuptools wheel
|
||||
pip install -r requirements.txt
|
||||
```
|
||||
|
||||
可选系统依赖(用于 ffmpeg 转码与潜在 SDL 兼容):
|
||||
|
||||
```bash
|
||||
brew install ffmpeg sdl2
|
||||
```
|
||||
|
||||
## 3. 一条命令开始训练
|
||||
|
||||
默认 CPU 训练(如果检测到可用且稳定的 MPS,会自动尝试启用,否则自动回退 CPU):
|
||||
|
||||
```bash
|
||||
python -m src.train_ppo
|
||||
```
|
||||
|
||||
常用覆盖参数:
|
||||
|
||||
```bash
|
||||
python -m src.train_ppo \
|
||||
--total-timesteps 1000000 \
|
||||
--n-envs 4 \
|
||||
--save-freq 50000 \
|
||||
--env-id SuperMarioBros-1-1-v0 \
|
||||
--movement right_only \
|
||||
--seed 42
|
||||
```
|
||||
|
||||
从已有 checkpoint 初始化后继续训练(可同时改超参数,如 `--ent-coef`):
|
||||
|
||||
```bash
|
||||
python -m src.train_ppo \
|
||||
--init-model-path artifacts/models/latest_model.zip \
|
||||
--total-timesteps 500000 \
|
||||
--ent-coef 0.02 \
|
||||
--learning-rate 1e-4
|
||||
```
|
||||
|
||||
如果切换了动作空间(例如 checkpoint 是 `right_only`,当前想用 `simple`),可用部分加载:
|
||||
|
||||
```bash
|
||||
python -m src.train_ppo \
|
||||
--init-model-path artifacts/models/latest_model.zip \
|
||||
--allow-partial-init \
|
||||
--movement simple \
|
||||
--total-timesteps 300000
|
||||
```
|
||||
|
||||
### 3.1 从已有模型继续训练(`--init-model-path`)
|
||||
|
||||
- 用途:加载已有 `.zip` 权重后继续训练,适合“不中断实验目标但调整探索参数”。
|
||||
- 常见场景:当前策略陷入局部最优(例如 `approx_kl` 和 `policy_gradient_loss` 长期接近 0),希望从已有模型继续探索。
|
||||
- 注意:这不是“热更新”,仍然需要停止当前训练进程后用新命令重启。
|
||||
|
||||
```bash
|
||||
python -m src.train_ppo \
|
||||
--init-model-path artifacts/models/latest_model.zip \
|
||||
--ent-coef 0.02 \
|
||||
--learning-rate 1e-4 \
|
||||
--total-timesteps 300000
|
||||
```
|
||||
|
||||
训练输出:
|
||||
- stdout:PPO 训练日志
|
||||
- TensorBoard:`artifacts/logs/<run_name>/tb/`
|
||||
- checkpoint:`artifacts/models/<run_name>/ppo_mario_ckpt_*.zip`
|
||||
- final model:`artifacts/models/<run_name>/final_model.zip`
|
||||
- latest 指针:`artifacts/models/latest_model.zip` + `LATEST_MODEL.txt`
|
||||
|
||||
启动 TensorBoard:
|
||||
|
||||
```bash
|
||||
tensorboard --logdir artifacts/logs --port 6006
|
||||
```
|
||||
|
||||
## 3.2 训练日志字段速查(PPO)
|
||||
|
||||
训练时你会看到类似:
|
||||
|
||||
```text
|
||||
| rollout/ep_len_mean | rollout/ep_rew_mean | ... |
|
||||
| train/approx_kl | train/entropy_loss | ... |
|
||||
```
|
||||
|
||||
下面是常用字段的含义(对应你贴出来那组):
|
||||
|
||||
- `rollout/ep_len_mean`:最近一批 episode 的平均步数。越大不一定越好,要结合 reward 一起看。
|
||||
- `rollout/ep_rew_mean`:最近一批 episode 的平均回报。通常越高越好。
|
||||
- `time/fps`:训练吞吐(每秒环境步数),只代表速度,不代表策略质量。
|
||||
- `time/iterations`:第几次 rollout + update 循环。
|
||||
- `time/time_elapsed`:训练已运行的秒数。
|
||||
- `time/total_timesteps`:累计环境交互步数(达到你设定的 `--total-timesteps` 会停止)。
|
||||
- `train/approx_kl`:新旧策略差异大小。太大说明更新过猛;接近 0 说明几乎没在更新。极小负数通常是数值误差,可当作 0。
|
||||
- `train/clip_fraction`:有多少样本触发 PPO clipping。长期为 0 且 KL 也接近 0,常见于“策略基本不再更新”。
|
||||
- `train/clip_range`:PPO 的 clipping 阈值(默认 0.2)。
|
||||
- `train/entropy_loss`:探索强度指标。绝对值越接近 0,策略越确定、探索越少。
|
||||
- `train/explained_variance`:价值网络对回报的解释度,越接近 1 越好,接近 0 说明 value 还不稳。
|
||||
- `train/learning_rate`:优化器步长(参数更新幅度),不是硬件速度。
|
||||
- `train/loss`:总损失(由多个部分组成),主要看趋势,不单看绝对值。
|
||||
- `train/policy_gradient_loss`:策略网络的更新信号。长期接近 0 可能表示 actor 更新很弱。
|
||||
- `train/value_loss`:价值网络误差。过大通常代表 critic 拟合还不稳定。
|
||||
|
||||
### 快速判断(实用版)
|
||||
|
||||
- `ep_rew_mean` / `avg_max_x_pos` 持续上升:一般在变好。
|
||||
- `approx_kl≈0` + `clip_fraction=0` + `policy_gradient_loss≈0`:大概率卡住(更新几乎停了)。
|
||||
- `entropy_loss` 绝对值太小且长期不变:探索不足,可尝试提高 `--ent-coef`。
|
||||
|
||||
## 4. 评估模型
|
||||
|
||||
加载最新模型,跑 N 个 episode,输出平均指标:
|
||||
|
||||
```bash
|
||||
python -m src.eval --episodes 5
|
||||
```
|
||||
|
||||
可指定模型:
|
||||
|
||||
```bash
|
||||
python -m src.eval --model-path artifacts/models/latest_model.zip --episodes 10
|
||||
```
|
||||
|
||||
注意:`eval.py` 默认 `--movement auto`,会按模型动作维度自动匹配 `right_only/simple`,避免动作空间不一致导致 `KeyError`。
|
||||
|
||||
输出指标包括:
|
||||
- `avg_reward`
|
||||
- `avg_max_x_pos`
|
||||
- `clear_rate`(`flag_get=True` 的比例)
|
||||
|
||||
|
||||
查看步数
|
||||
|
||||
```php
|
||||
unzip -p artifacts/models/latest_model.zip data | rg '"num_timesteps"|"_total_timesteps"|"_tensorboard_log"'
|
||||
```
|
||||
|
||||
num_timesteps = 151552
|
||||
_total_timesteps = 150000
|
||||
|
||||
## 5. 录制回放视频(无窗口/headless)
|
||||
|
||||
默认录制约 10 秒 mp4 到 `artifacts/videos/`:
|
||||
|
||||
```bash
|
||||
python -m src.record_video --duration-sec 10 --fps 30
|
||||
```
|
||||
|
||||
可指定输出路径:
|
||||
|
||||
```bash
|
||||
python -m src.record_video --output artifacts/videos/demo.mp4 --duration-sec 10
|
||||
```
|
||||
|
||||
注意:`record_video.py` 默认 `--movement auto`,会按模型自动匹配动作空间。
|
||||
|
||||
实现方式:
|
||||
- 使用 `render_mode=rgb_array`,无需打开窗口
|
||||
- 默认通过 `imageio + ffmpeg` 输出 mp4
|
||||
- 若 mp4 写入失败,会自动降级保存帧序列(PNG),并打印 ffmpeg 转码命令
|
||||
|
||||
## 6. 动作空间选择说明
|
||||
|
||||
默认 `RIGHT_ONLY`,原因:
|
||||
- 动作更少,探索空间更小,MVP 更快收敛到“向右推进”策略
|
||||
- 适合先验证训练闭环
|
||||
|
||||
可切到 `SIMPLE_MOVEMENT`(动作更丰富):
|
||||
|
||||
```bash
|
||||
python -m src.train_ppo --movement simple
|
||||
```
|
||||
|
||||
## 7. 预处理与奖励
|
||||
|
||||
默认预处理链路:
|
||||
- 跳帧:`frame_skip=4`
|
||||
- 灰度:`RGB -> Gray`
|
||||
- 缩放:`84x84`
|
||||
- 帧堆叠:`4`
|
||||
- 通道布局:`CHW`(兼容 `CnnPolicy`)
|
||||
|
||||
奖励:
|
||||
- `--reward-mode raw`:原始奖励(默认)
|
||||
- `--reward-mode clip`:裁剪奖励 `sign(reward)`(等价于旧参数 `--clip-reward`)
|
||||
- `--reward-mode progress`:奖励塑形模式,额外包含:
|
||||
- 前进增益奖励(`--progress-scale`)
|
||||
- 死亡惩罚(`--death-penalty`)
|
||||
- 通关奖励(`--flag-bonus`)
|
||||
- 原地卡住惩罚(`--stall-penalty` + `--stall-steps`)
|
||||
- 后退惩罚(`--backward-penalty-scale`)
|
||||
|
||||
使用裁剪奖励:
|
||||
|
||||
```bash
|
||||
python -m src.train_ppo --reward-mode clip
|
||||
```
|
||||
|
||||
针对“卡在固定位置(如 x=314 撞蘑菇)”的推荐续训命令:
|
||||
|
||||
```bash
|
||||
python -m src.train_ppo \
|
||||
--init-model-path artifacts/models/latest_model.zip \
|
||||
--allow-partial-init \
|
||||
--reward-mode progress \
|
||||
--movement simple \
|
||||
--ent-coef 0.04 \
|
||||
--learning-rate 1e-4 \
|
||||
--n-steps 512 \
|
||||
--gamma 0.995 \
|
||||
--death-penalty -120 \
|
||||
--stall-penalty 0.2 \
|
||||
--stall-steps 20 \
|
||||
--backward-penalty-scale 0.03 \
|
||||
--milestone-bonus 2.0 \
|
||||
--no-progress-terminate-steps 80 \
|
||||
--no-progress-terminate-penalty 30 \
|
||||
--total-timesteps 150000
|
||||
```
|
||||
|
||||
## 8. 常见问题排查
|
||||
|
||||
### 8.1 `pip install` 失败(nes-py/gym-super-mario-bros 编译问题)
|
||||
|
||||
先安装工具链并重试:
|
||||
|
||||
```bash
|
||||
xcode-select --install
|
||||
brew install cmake swig sdl2
|
||||
pip install --upgrade pip setuptools wheel
|
||||
pip install -r requirements.txt
|
||||
```
|
||||
|
||||
### 8.2 MPS 不稳定或报错
|
||||
|
||||
强制 CPU:
|
||||
|
||||
```bash
|
||||
python -m src.train_ppo --device cpu
|
||||
```
|
||||
|
||||
说明:脚本会先做一次 MPS 张量 sanity check,失败自动回退 CPU。
|
||||
|
||||
### 8.3 视频写入失败(ffmpeg/codec)
|
||||
|
||||
1) 安装系统 ffmpeg:
|
||||
```bash
|
||||
brew install ffmpeg
|
||||
```
|
||||
|
||||
2) 已降级保存帧序列时,手动转码:
|
||||
```bash
|
||||
ffmpeg -framerate 30 -i artifacts/videos/<name>_frames/frame_%06d.png -c:v libx264 -pix_fmt yuv420p artifacts/videos/<name>.mp4
|
||||
```
|
||||
|
||||
### 8.4 图形窗口相关报错
|
||||
|
||||
本工程默认 `rgb_array` 录制,不依赖 GUI 窗口。
|
||||
若仍遇到 SDL 问题,可显式设置:
|
||||
|
||||
```bash
|
||||
export SDL_VIDEODRIVER=dummy
|
||||
python -m src.record_video --duration-sec 10
|
||||
```
|
||||
|
||||
### 8.5 `size mismatch for action_net`(加载旧模型时报错)
|
||||
|
||||
典型原因:旧 checkpoint 的动作空间与当前配置不同(如 `right_only`=5 动作,`simple`=7 动作)。
|
||||
|
||||
可选修复:
|
||||
|
||||
1) 保持和 checkpoint 一致的动作空间:
|
||||
```bash
|
||||
python -m src.train_ppo --init-model-path /Users/roog/Work/FNT/SpMr/mario-rl-mvp/artifacts/models/ppo_SuperMarioBros-1-1-v0_20260212_164717/ppo_mario_ckpt_150000_steps.zip --movement right_only
|
||||
```
|
||||
|
||||
2) 若你确实要切动作空间,用部分初始化(跳过不兼容动作头):
|
||||
```bash
|
||||
python -m src.train_ppo --init-model-path artifacts/models/latest_model.zip --movement simple --allow-partial-init
|
||||
```
|
||||
|
||||
3) 或者直接不加载旧模型,从头训练新动作空间。
|
||||
|
||||
## 9. 最小 smoke test(按顺序执行)
|
||||
|
||||
```bash
|
||||
cd /Users/roog/Work/FNT/SpMr/mario-rl-mvp
|
||||
source .venv/bin/activate
|
||||
|
||||
# 1) 训练 1e4 steps,并至少写出 checkpoint + final model
|
||||
python -m src.train_ppo \
|
||||
--total-timesteps 10000 \
|
||||
--save-freq 2000 \
|
||||
--n-envs 1 \
|
||||
--device cpu \
|
||||
--movement right_only
|
||||
|
||||
# 2) 快速评估
|
||||
python -m src.eval --episodes 2 --max-steps 2000
|
||||
|
||||
# 3) 录制 10 秒视频
|
||||
python -m src.record_video --duration-sec 10 --fps 30
|
||||
```
|
||||
|
||||
验收标准:
|
||||
- `artifacts/models/` 下有 `.zip` 模型
|
||||
- `artifacts/logs/` 下有 TensorBoard event 文件
|
||||
- `artifacts/videos/` 下有 `.mp4`(或失败时有 `_frames/` 帧序列)
|
||||
12
mario-rl-mvp/requirements.txt
Normal file
12
mario-rl-mvp/requirements.txt
Normal file
@@ -0,0 +1,12 @@
|
||||
torch==2.5.1
|
||||
stable-baselines3==2.3.2
|
||||
gym==0.26.2
|
||||
gymnasium==0.29.1
|
||||
shimmy==1.3.0
|
||||
gym-super-mario-bros==7.4.0
|
||||
nes-py==8.2.1
|
||||
opencv-python==4.10.0.84
|
||||
numpy==1.26.4
|
||||
tensorboard==2.18.0
|
||||
imageio==2.36.1
|
||||
imageio-ffmpeg==0.5.1
|
||||
1
mario-rl-mvp/src/__init__.py
Normal file
1
mario-rl-mvp/src/__init__.py
Normal file
@@ -0,0 +1 @@
|
||||
"""Mario RL MVP package."""
|
||||
337
mario-rl-mvp/src/env.py
Normal file
337
mario-rl-mvp/src/env.py
Normal file
@@ -0,0 +1,337 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from collections import deque
|
||||
from typing import Any, Callable, Deque, Dict, Optional, Tuple
|
||||
|
||||
import cv2
|
||||
import gym
|
||||
import gym_super_mario_bros
|
||||
import numpy as np
|
||||
from gym.spaces import Box
|
||||
from gym_super_mario_bros.actions import RIGHT_ONLY, SIMPLE_MOVEMENT
|
||||
from nes_py.wrappers import JoypadSpace
|
||||
|
||||
|
||||
def reset_compat(env: gym.Env, seed: Optional[int] = None) -> Tuple[np.ndarray, Dict[str, Any]]:
|
||||
try:
|
||||
result = env.reset(seed=seed)
|
||||
except TypeError:
|
||||
result = env.reset()
|
||||
if seed is not None and hasattr(env, "seed"):
|
||||
env.seed(seed)
|
||||
|
||||
if isinstance(result, tuple) and len(result) == 2:
|
||||
obs, info = result
|
||||
return obs, info
|
||||
return result, {}
|
||||
|
||||
|
||||
def step_compat(env: gym.Env, action: Any) -> Tuple[np.ndarray, float, bool, bool, Dict[str, Any]]:
|
||||
result = env.step(action)
|
||||
if isinstance(result, tuple) and len(result) == 5:
|
||||
obs, reward, terminated, truncated, info = result
|
||||
return obs, float(reward), bool(terminated), bool(truncated), info
|
||||
if isinstance(result, tuple) and len(result) == 4:
|
||||
obs, reward, done, info = result
|
||||
return obs, float(reward), bool(done), False, info
|
||||
raise RuntimeError(f"Unexpected step return format: {type(result)} / {result}")
|
||||
|
||||
|
||||
class SkipFrame(gym.Wrapper):
|
||||
def __init__(self, env: gym.Env, skip: int = 4):
|
||||
super().__init__(env)
|
||||
self._skip = skip
|
||||
|
||||
def step(self, action: Any):
|
||||
total_reward = 0.0
|
||||
terminated = False
|
||||
truncated = False
|
||||
info: Dict[str, Any] = {}
|
||||
obs = None
|
||||
for _ in range(self._skip):
|
||||
obs, reward, terminated, truncated, info = step_compat(self.env, action)
|
||||
total_reward += reward
|
||||
if terminated or truncated:
|
||||
break
|
||||
return obs, total_reward, terminated, truncated, info
|
||||
|
||||
|
||||
class PreprocessFrame(gym.ObservationWrapper):
|
||||
"""Convert RGB frame to grayscale 84x84 uint8."""
|
||||
|
||||
def __init__(self, env: gym.Env, width: int = 84, height: int = 84):
|
||||
super().__init__(env)
|
||||
self.width = width
|
||||
self.height = height
|
||||
self.observation_space = Box(low=0, high=255, shape=(height, width, 1), dtype=np.uint8)
|
||||
|
||||
def observation(self, observation: np.ndarray) -> np.ndarray:
|
||||
if observation.ndim == 3 and observation.shape[2] == 3:
|
||||
gray = cv2.cvtColor(observation, cv2.COLOR_RGB2GRAY)
|
||||
elif observation.ndim == 2:
|
||||
gray = observation
|
||||
else:
|
||||
gray = np.squeeze(observation)
|
||||
resized = cv2.resize(gray, (self.width, self.height), interpolation=cv2.INTER_AREA)
|
||||
return resized[:, :, None].astype(np.uint8)
|
||||
|
||||
|
||||
class ChannelLastFrameStack(gym.Wrapper):
|
||||
"""Stack frames on the channel axis: (H, W, C*num_stack)."""
|
||||
|
||||
def __init__(self, env: gym.Env, num_stack: int = 4):
|
||||
super().__init__(env)
|
||||
self.num_stack = num_stack
|
||||
self.frames: Deque[np.ndarray] = deque(maxlen=num_stack)
|
||||
|
||||
obs_space = env.observation_space
|
||||
assert isinstance(obs_space, Box), "Frame stack requires Box observation space."
|
||||
h, w, c = obs_space.shape
|
||||
self.observation_space = Box(
|
||||
low=0,
|
||||
high=255,
|
||||
shape=(h, w, c * num_stack),
|
||||
dtype=np.uint8,
|
||||
)
|
||||
|
||||
def reset(self, *, seed: Optional[int] = None, options: Optional[dict] = None):
|
||||
del options
|
||||
obs, info = reset_compat(self.env, seed=seed)
|
||||
self.frames.clear()
|
||||
for _ in range(self.num_stack):
|
||||
self.frames.append(obs)
|
||||
return self._get_observation(), info
|
||||
|
||||
def step(self, action: Any):
|
||||
obs, reward, terminated, truncated, info = step_compat(self.env, action)
|
||||
self.frames.append(obs)
|
||||
return self._get_observation(), reward, terminated, truncated, info
|
||||
|
||||
def _get_observation(self) -> np.ndarray:
|
||||
assert len(self.frames) == self.num_stack
|
||||
return np.concatenate(list(self.frames), axis=2)
|
||||
|
||||
|
||||
class TransposeObservation(gym.ObservationWrapper):
|
||||
"""Convert observation from HWC to CHW for CNN policy."""
|
||||
|
||||
def __init__(self, env: gym.Env):
|
||||
super().__init__(env)
|
||||
obs_space = env.observation_space
|
||||
assert isinstance(obs_space, Box), "TransposeObservation requires Box observation space."
|
||||
h, w, c = obs_space.shape
|
||||
self.observation_space = Box(low=0, high=255, shape=(c, h, w), dtype=obs_space.dtype)
|
||||
|
||||
def observation(self, observation: np.ndarray) -> np.ndarray:
|
||||
return np.transpose(observation, (2, 0, 1)).astype(np.uint8)
|
||||
|
||||
|
||||
class ClipRewardEnv(gym.RewardWrapper):
|
||||
def reward(self, reward):
|
||||
return float(np.sign(reward))
|
||||
|
||||
|
||||
class ProgressRewardEnv(gym.Wrapper):
|
||||
"""Reward shaping focused on moving right and avoiding local traps."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
env: gym.Env,
|
||||
progress_scale: float = 0.02,
|
||||
death_penalty: float = -50.0,
|
||||
flag_bonus: float = 100.0,
|
||||
stall_penalty: float = 0.05,
|
||||
stall_steps: int = 40,
|
||||
backward_penalty_scale: float = 0.01,
|
||||
milestone_interval: int = 32,
|
||||
milestone_bonus: float = 1.0,
|
||||
no_progress_terminate_steps: int = 120,
|
||||
no_progress_terminate_penalty: float = 20.0,
|
||||
):
|
||||
super().__init__(env)
|
||||
self.progress_scale = progress_scale
|
||||
self.death_penalty = death_penalty
|
||||
self.flag_bonus = flag_bonus
|
||||
self.stall_penalty = stall_penalty
|
||||
self.stall_steps = stall_steps
|
||||
self.backward_penalty_scale = backward_penalty_scale
|
||||
self.milestone_interval = milestone_interval
|
||||
self.milestone_bonus = milestone_bonus
|
||||
self.no_progress_terminate_steps = no_progress_terminate_steps
|
||||
self.no_progress_terminate_penalty = no_progress_terminate_penalty
|
||||
self._last_x_pos: Optional[float] = None
|
||||
self._best_x_pos = 0.0
|
||||
self._stall_count = 0
|
||||
self._next_milestone_x = float(milestone_interval)
|
||||
|
||||
def reset(self, *, seed: Optional[int] = None, options: Optional[dict] = None):
|
||||
del options
|
||||
obs, info = reset_compat(self.env, seed=seed)
|
||||
self._last_x_pos = float(info.get("x_pos", 0.0))
|
||||
self._best_x_pos = self._last_x_pos
|
||||
self._stall_count = 0
|
||||
if self.milestone_interval > 0:
|
||||
k = int(self._best_x_pos // self.milestone_interval) + 1
|
||||
self._next_milestone_x = float(k * self.milestone_interval)
|
||||
return obs, info
|
||||
|
||||
def step(self, action: Any):
|
||||
obs, reward, terminated, truncated, info = step_compat(self.env, action)
|
||||
x_pos = float(info.get("x_pos", 0.0))
|
||||
|
||||
if self._last_x_pos is None:
|
||||
delta_x = 0.0
|
||||
else:
|
||||
delta_x = x_pos - self._last_x_pos
|
||||
self._last_x_pos = x_pos
|
||||
|
||||
shaped_reward = float(reward)
|
||||
if delta_x > 0:
|
||||
shaped_reward += self.progress_scale * delta_x
|
||||
self._stall_count = 0
|
||||
if x_pos > self._best_x_pos:
|
||||
self._best_x_pos = x_pos
|
||||
if self.milestone_interval > 0 and self.milestone_bonus != 0.0:
|
||||
while x_pos >= self._next_milestone_x:
|
||||
shaped_reward += self.milestone_bonus
|
||||
self._next_milestone_x += self.milestone_interval
|
||||
else:
|
||||
self._stall_count += 1
|
||||
if delta_x < 0:
|
||||
shaped_reward -= self.backward_penalty_scale * abs(delta_x)
|
||||
|
||||
if self._stall_count >= self.stall_steps:
|
||||
shaped_reward -= self.stall_penalty
|
||||
|
||||
if (
|
||||
self.no_progress_terminate_steps > 0
|
||||
and not terminated
|
||||
and not truncated
|
||||
and self._stall_count >= self.no_progress_terminate_steps
|
||||
):
|
||||
truncated = True
|
||||
shaped_reward -= self.no_progress_terminate_penalty
|
||||
info["terminated_by_stall"] = True
|
||||
|
||||
if terminated or truncated:
|
||||
if bool(info.get("flag_get", False)):
|
||||
shaped_reward += self.flag_bonus
|
||||
elif terminated:
|
||||
shaped_reward += self.death_penalty
|
||||
|
||||
return obs, shaped_reward, terminated, truncated, info
|
||||
|
||||
|
||||
def get_action_set(name: str):
|
||||
name = name.lower().strip()
|
||||
if name == "simple":
|
||||
return SIMPLE_MOVEMENT
|
||||
if name == "right_only":
|
||||
return RIGHT_ONLY
|
||||
raise ValueError(f"Unsupported movement='{name}'. Use one of: right_only, simple")
|
||||
|
||||
|
||||
def make_mario_env(
|
||||
env_id: str = "SuperMarioBros-1-1-v0",
|
||||
seed: int = 0,
|
||||
movement: str = "right_only",
|
||||
reward_mode: str = "raw",
|
||||
clip_reward: bool = False,
|
||||
frame_skip: int = 4,
|
||||
render_mode: Optional[str] = None,
|
||||
progress_scale: float = 0.02,
|
||||
death_penalty: float = -50.0,
|
||||
flag_bonus: float = 100.0,
|
||||
stall_penalty: float = 0.05,
|
||||
stall_steps: int = 40,
|
||||
backward_penalty_scale: float = 0.01,
|
||||
milestone_interval: int = 32,
|
||||
milestone_bonus: float = 1.0,
|
||||
no_progress_terminate_steps: int = 120,
|
||||
no_progress_terminate_penalty: float = 20.0,
|
||||
) -> gym.Env:
|
||||
kwargs: Dict[str, Any] = {}
|
||||
if render_mode is not None:
|
||||
kwargs["render_mode"] = render_mode
|
||||
|
||||
try:
|
||||
env = gym_super_mario_bros.make(env_id, apply_api_compatibility=True, **kwargs)
|
||||
except TypeError:
|
||||
env = gym_super_mario_bros.make(env_id, **kwargs)
|
||||
|
||||
env = JoypadSpace(env, get_action_set(movement))
|
||||
env = SkipFrame(env, skip=frame_skip)
|
||||
|
||||
mode = reward_mode.lower().strip()
|
||||
if clip_reward and mode == "raw":
|
||||
mode = "clip"
|
||||
if mode == "clip":
|
||||
env = ClipRewardEnv(env)
|
||||
elif mode == "progress":
|
||||
env = ProgressRewardEnv(
|
||||
env=env,
|
||||
progress_scale=progress_scale,
|
||||
death_penalty=death_penalty,
|
||||
flag_bonus=flag_bonus,
|
||||
stall_penalty=stall_penalty,
|
||||
stall_steps=stall_steps,
|
||||
backward_penalty_scale=backward_penalty_scale,
|
||||
milestone_interval=milestone_interval,
|
||||
milestone_bonus=milestone_bonus,
|
||||
no_progress_terminate_steps=no_progress_terminate_steps,
|
||||
no_progress_terminate_penalty=no_progress_terminate_penalty,
|
||||
)
|
||||
elif mode != "raw":
|
||||
raise ValueError(f"Unsupported reward_mode='{reward_mode}'. Use one of: raw, clip, progress")
|
||||
|
||||
env = PreprocessFrame(env, width=84, height=84)
|
||||
env = ChannelLastFrameStack(env, num_stack=4)
|
||||
env = TransposeObservation(env)
|
||||
|
||||
# Seed once so each env subprocess/dummy env has deterministic startup.
|
||||
reset_compat(env, seed=seed)
|
||||
if hasattr(env.action_space, "seed"):
|
||||
env.action_space.seed(seed)
|
||||
return env
|
||||
|
||||
|
||||
def make_env_fn(
|
||||
env_id: str,
|
||||
seed: int,
|
||||
movement: str,
|
||||
reward_mode: str,
|
||||
clip_reward: bool,
|
||||
frame_skip: int,
|
||||
progress_scale: float,
|
||||
death_penalty: float,
|
||||
flag_bonus: float,
|
||||
stall_penalty: float,
|
||||
stall_steps: int,
|
||||
backward_penalty_scale: float,
|
||||
milestone_interval: int,
|
||||
milestone_bonus: float,
|
||||
no_progress_terminate_steps: int,
|
||||
no_progress_terminate_penalty: float,
|
||||
) -> Callable[[], gym.Env]:
|
||||
def _thunk() -> gym.Env:
|
||||
return make_mario_env(
|
||||
env_id=env_id,
|
||||
seed=seed,
|
||||
movement=movement,
|
||||
reward_mode=reward_mode,
|
||||
clip_reward=clip_reward,
|
||||
frame_skip=frame_skip,
|
||||
render_mode=None,
|
||||
progress_scale=progress_scale,
|
||||
death_penalty=death_penalty,
|
||||
flag_bonus=flag_bonus,
|
||||
stall_penalty=stall_penalty,
|
||||
stall_steps=stall_steps,
|
||||
backward_penalty_scale=backward_penalty_scale,
|
||||
milestone_interval=milestone_interval,
|
||||
milestone_bonus=milestone_bonus,
|
||||
no_progress_terminate_steps=no_progress_terminate_steps,
|
||||
no_progress_terminate_penalty=no_progress_terminate_penalty,
|
||||
)
|
||||
|
||||
return _thunk
|
||||
162
mario-rl-mvp/src/eval.py
Normal file
162
mario-rl-mvp/src/eval.py
Normal file
@@ -0,0 +1,162 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import argparse
|
||||
import json
|
||||
from pathlib import Path
|
||||
from statistics import mean
|
||||
|
||||
import numpy as np
|
||||
from stable_baselines3 import PPO
|
||||
|
||||
from src.env import get_action_set, make_mario_env, reset_compat, step_compat
|
||||
from src.utils import ensure_artifact_paths, latest_model_path, seed_everything
|
||||
|
||||
|
||||
def parse_args() -> argparse.Namespace:
|
||||
parser = argparse.ArgumentParser(description="Evaluate trained Mario PPO agent.")
|
||||
parser.add_argument("--model-path", type=str, default="", help="Path to .zip model. If empty, use latest model.")
|
||||
parser.add_argument("--env-id", type=str, default="SuperMarioBros-1-1-v0")
|
||||
parser.add_argument("--movement", type=str, default="auto", choices=["auto", "right_only", "simple"])
|
||||
parser.add_argument("--episodes", type=int, default=5)
|
||||
parser.add_argument("--max-steps", type=int, default=5_000)
|
||||
parser.add_argument("--seed", type=int, default=42)
|
||||
parser.add_argument("--frame-skip", type=int, default=4)
|
||||
parser.add_argument("--reward-mode", type=str, default="raw", choices=["raw", "clip", "progress"])
|
||||
parser.add_argument("--progress-scale", type=float, default=0.02)
|
||||
parser.add_argument("--death-penalty", type=float, default=-50.0)
|
||||
parser.add_argument("--flag-bonus", type=float, default=100.0)
|
||||
parser.add_argument("--stall-penalty", type=float, default=0.05)
|
||||
parser.add_argument("--stall-steps", type=int, default=40)
|
||||
parser.add_argument("--backward-penalty-scale", type=float, default=0.01)
|
||||
parser.add_argument("--milestone-interval", type=int, default=32)
|
||||
parser.add_argument("--milestone-bonus", type=float, default=1.0)
|
||||
parser.add_argument("--no-progress-terminate-steps", type=int, default=120)
|
||||
parser.add_argument("--no-progress-terminate-penalty", type=float, default=20.0)
|
||||
parser.add_argument("--clip-reward", action="store_true")
|
||||
parser.add_argument("--stochastic", action="store_true", help="Use stochastic policy (deterministic=False).")
|
||||
return parser.parse_args()
|
||||
|
||||
|
||||
def _action_count_for_movement(movement: str) -> int:
|
||||
return len(get_action_set(movement))
|
||||
|
||||
|
||||
def _model_action_count(model: PPO):
|
||||
if hasattr(model, "action_space") and hasattr(model.action_space, "n"):
|
||||
return int(model.action_space.n)
|
||||
action_net = getattr(getattr(model, "policy", None), "action_net", None)
|
||||
out_features = getattr(action_net, "out_features", None)
|
||||
if out_features is not None:
|
||||
return int(out_features)
|
||||
return None
|
||||
|
||||
|
||||
def resolve_movement(movement_arg: str, model: PPO) -> str:
|
||||
model_action_n = _model_action_count(model)
|
||||
|
||||
if movement_arg == "auto":
|
||||
if model_action_n is not None:
|
||||
for candidate in ("right_only", "simple"):
|
||||
if _action_count_for_movement(candidate) == model_action_n:
|
||||
return candidate
|
||||
return "right_only"
|
||||
|
||||
if model_action_n is not None:
|
||||
expected_n = _action_count_for_movement(movement_arg)
|
||||
if expected_n != model_action_n:
|
||||
raise ValueError(
|
||||
f"movement='{movement_arg}' has {expected_n} actions, but model expects {model_action_n}. "
|
||||
"Use --movement auto or pass the matching movement."
|
||||
)
|
||||
return movement_arg
|
||||
|
||||
|
||||
def resolve_model_path(user_path: str) -> Path:
|
||||
if user_path:
|
||||
p = Path(user_path).expanduser().resolve()
|
||||
if not p.exists():
|
||||
raise FileNotFoundError(f"Model not found: {p}")
|
||||
return p
|
||||
|
||||
paths = ensure_artifact_paths()
|
||||
latest = latest_model_path(paths.models)
|
||||
if latest is None:
|
||||
raise FileNotFoundError("No model found under artifacts/models. Please run training first.")
|
||||
return latest
|
||||
|
||||
|
||||
def main() -> None:
|
||||
args = parse_args()
|
||||
seed_everything(args.seed)
|
||||
reward_mode = "clip" if args.clip_reward and args.reward_mode == "raw" else args.reward_mode
|
||||
|
||||
model_path = resolve_model_path(args.model_path)
|
||||
print(f"[eval] model={model_path}")
|
||||
model = PPO.load(str(model_path))
|
||||
movement = resolve_movement(args.movement, model)
|
||||
print(f"[eval] movement={movement} reward_mode={reward_mode}")
|
||||
|
||||
env = make_mario_env(
|
||||
env_id=args.env_id,
|
||||
seed=args.seed,
|
||||
movement=movement,
|
||||
reward_mode=reward_mode,
|
||||
clip_reward=args.clip_reward,
|
||||
frame_skip=args.frame_skip,
|
||||
render_mode=None,
|
||||
progress_scale=args.progress_scale,
|
||||
death_penalty=args.death_penalty,
|
||||
flag_bonus=args.flag_bonus,
|
||||
stall_penalty=args.stall_penalty,
|
||||
stall_steps=args.stall_steps,
|
||||
backward_penalty_scale=args.backward_penalty_scale,
|
||||
milestone_interval=args.milestone_interval,
|
||||
milestone_bonus=args.milestone_bonus,
|
||||
no_progress_terminate_steps=args.no_progress_terminate_steps,
|
||||
no_progress_terminate_penalty=args.no_progress_terminate_penalty,
|
||||
)
|
||||
|
||||
rewards = []
|
||||
max_x_positions = []
|
||||
clear_flags = []
|
||||
|
||||
for ep in range(1, args.episodes + 1):
|
||||
obs, info = reset_compat(env, seed=args.seed + ep)
|
||||
done = False
|
||||
ep_reward = 0.0
|
||||
ep_max_x = float(info.get("x_pos", 0.0))
|
||||
flag_get = False
|
||||
step_count = 0
|
||||
|
||||
while not done and step_count < args.max_steps:
|
||||
action, _ = model.predict(obs, deterministic=not args.stochastic)
|
||||
if isinstance(action, np.ndarray):
|
||||
action = int(action.item())
|
||||
obs, reward, terminated, truncated, info = step_compat(env, action)
|
||||
ep_reward += float(reward)
|
||||
ep_max_x = max(ep_max_x, float(info.get("x_pos", 0.0)))
|
||||
flag_get = flag_get or bool(info.get("flag_get", False))
|
||||
done = terminated or truncated
|
||||
step_count += 1
|
||||
|
||||
rewards.append(ep_reward)
|
||||
max_x_positions.append(ep_max_x)
|
||||
clear_flags.append(1.0 if flag_get else 0.0)
|
||||
print(
|
||||
f"[episode {ep}] reward={ep_reward:.2f} max_x={ep_max_x:.1f} "
|
||||
f"clear={flag_get} steps={step_count}"
|
||||
)
|
||||
|
||||
summary = {
|
||||
"episodes": args.episodes,
|
||||
"avg_reward": mean(rewards) if rewards else 0.0,
|
||||
"avg_max_x_pos": mean(max_x_positions) if max_x_positions else 0.0,
|
||||
"clear_rate": mean(clear_flags) if clear_flags else 0.0,
|
||||
}
|
||||
print("[summary]", json.dumps(summary, ensure_ascii=False, indent=2))
|
||||
|
||||
env.close()
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
200
mario-rl-mvp/src/record_video.py
Normal file
200
mario-rl-mvp/src/record_video.py
Normal file
@@ -0,0 +1,200 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import argparse
|
||||
from datetime import datetime
|
||||
from pathlib import Path
|
||||
|
||||
import imageio.v2 as imageio
|
||||
import numpy as np
|
||||
from stable_baselines3 import PPO
|
||||
|
||||
from src.env import get_action_set, make_mario_env, reset_compat, step_compat
|
||||
from src.utils import ensure_artifact_paths, latest_model_path, seed_everything
|
||||
|
||||
|
||||
def parse_args() -> argparse.Namespace:
|
||||
parser = argparse.ArgumentParser(description="Record Mario agent rollout to mp4 (headless).")
|
||||
parser.add_argument("--model-path", type=str, default="", help="Path to .zip model. Empty means latest model.")
|
||||
parser.add_argument("--env-id", type=str, default="SuperMarioBros-1-1-v0")
|
||||
parser.add_argument("--movement", type=str, default="auto", choices=["auto", "right_only", "simple"])
|
||||
parser.add_argument("--seed", type=int, default=42)
|
||||
parser.add_argument("--frame-skip", type=int, default=4)
|
||||
parser.add_argument("--reward-mode", type=str, default="raw", choices=["raw", "clip", "progress"])
|
||||
parser.add_argument("--progress-scale", type=float, default=0.02)
|
||||
parser.add_argument("--death-penalty", type=float, default=-50.0)
|
||||
parser.add_argument("--flag-bonus", type=float, default=100.0)
|
||||
parser.add_argument("--stall-penalty", type=float, default=0.05)
|
||||
parser.add_argument("--stall-steps", type=int, default=40)
|
||||
parser.add_argument("--backward-penalty-scale", type=float, default=0.01)
|
||||
parser.add_argument("--milestone-interval", type=int, default=32)
|
||||
parser.add_argument("--milestone-bonus", type=float, default=1.0)
|
||||
parser.add_argument("--no-progress-terminate-steps", type=int, default=120)
|
||||
parser.add_argument("--no-progress-terminate-penalty", type=float, default=20.0)
|
||||
parser.add_argument("--clip-reward", action="store_true")
|
||||
parser.add_argument("--fps", type=int, default=30)
|
||||
parser.add_argument("--duration-sec", type=int, default=10)
|
||||
parser.add_argument("--max-steps", type=int, default=5_000)
|
||||
parser.add_argument("--stochastic", action="store_true")
|
||||
parser.add_argument("--output", type=str, default="", help="Output mp4 path.")
|
||||
return parser.parse_args()
|
||||
|
||||
|
||||
def _action_count_for_movement(movement: str) -> int:
|
||||
return len(get_action_set(movement))
|
||||
|
||||
|
||||
def _model_action_count(model: PPO):
|
||||
if hasattr(model, "action_space") and hasattr(model.action_space, "n"):
|
||||
return int(model.action_space.n)
|
||||
action_net = getattr(getattr(model, "policy", None), "action_net", None)
|
||||
out_features = getattr(action_net, "out_features", None)
|
||||
if out_features is not None:
|
||||
return int(out_features)
|
||||
return None
|
||||
|
||||
|
||||
def resolve_movement(movement_arg: str, model: PPO) -> str:
|
||||
model_action_n = _model_action_count(model)
|
||||
|
||||
if movement_arg == "auto":
|
||||
if model_action_n is not None:
|
||||
for candidate in ("right_only", "simple"):
|
||||
if _action_count_for_movement(candidate) == model_action_n:
|
||||
return candidate
|
||||
return "right_only"
|
||||
|
||||
if model_action_n is not None:
|
||||
expected_n = _action_count_for_movement(movement_arg)
|
||||
if expected_n != model_action_n:
|
||||
raise ValueError(
|
||||
f"movement='{movement_arg}' has {expected_n} actions, but model expects {model_action_n}. "
|
||||
"Use --movement auto or pass the matching movement."
|
||||
)
|
||||
return movement_arg
|
||||
|
||||
|
||||
def resolve_model(user_path: str) -> Path:
|
||||
if user_path:
|
||||
p = Path(user_path).expanduser().resolve()
|
||||
if not p.exists():
|
||||
raise FileNotFoundError(f"Model not found: {p}")
|
||||
return p
|
||||
|
||||
paths = ensure_artifact_paths()
|
||||
latest = latest_model_path(paths.models)
|
||||
if latest is None:
|
||||
raise FileNotFoundError("No model found under artifacts/models. Please run training first.")
|
||||
return latest
|
||||
|
||||
|
||||
def resolve_output_path(user_output: str) -> Path:
|
||||
if user_output:
|
||||
return Path(user_output).expanduser().resolve()
|
||||
|
||||
paths = ensure_artifact_paths()
|
||||
ts = datetime.now().strftime("%Y%m%d_%H%M%S")
|
||||
return (paths.videos / f"mario_replay_{ts}.mp4").resolve()
|
||||
|
||||
|
||||
def save_frames_fallback(frames, output_path: Path) -> Path:
|
||||
frame_dir = output_path.with_suffix("")
|
||||
frame_dir.mkdir(parents=True, exist_ok=True)
|
||||
for i, frame in enumerate(frames):
|
||||
imageio.imwrite(frame_dir / f"frame_{i:06d}.png", frame)
|
||||
return frame_dir
|
||||
|
||||
|
||||
def main() -> None:
|
||||
args = parse_args()
|
||||
seed_everything(args.seed)
|
||||
reward_mode = "clip" if args.clip_reward and args.reward_mode == "raw" else args.reward_mode
|
||||
|
||||
model_path = resolve_model(args.model_path)
|
||||
output_path = resolve_output_path(args.output)
|
||||
output_path.parent.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
model = PPO.load(str(model_path))
|
||||
movement = resolve_movement(args.movement, model)
|
||||
print(f"[video] movement={movement} reward_mode={reward_mode}")
|
||||
env = make_mario_env(
|
||||
env_id=args.env_id,
|
||||
seed=args.seed,
|
||||
movement=movement,
|
||||
reward_mode=reward_mode,
|
||||
clip_reward=args.clip_reward,
|
||||
frame_skip=args.frame_skip,
|
||||
render_mode="rgb_array",
|
||||
progress_scale=args.progress_scale,
|
||||
death_penalty=args.death_penalty,
|
||||
flag_bonus=args.flag_bonus,
|
||||
stall_penalty=args.stall_penalty,
|
||||
stall_steps=args.stall_steps,
|
||||
backward_penalty_scale=args.backward_penalty_scale,
|
||||
milestone_interval=args.milestone_interval,
|
||||
milestone_bonus=args.milestone_bonus,
|
||||
no_progress_terminate_steps=args.no_progress_terminate_steps,
|
||||
no_progress_terminate_penalty=args.no_progress_terminate_penalty,
|
||||
)
|
||||
|
||||
obs, _ = reset_compat(env, seed=args.seed)
|
||||
frames = []
|
||||
|
||||
first_frame = env.render()
|
||||
if first_frame is not None:
|
||||
# nes-py may reuse the same frame buffer; copy to avoid aliasing all frames.
|
||||
frames.append(first_frame.copy())
|
||||
|
||||
target_frames = max(1, args.fps * args.duration_sec)
|
||||
done = False
|
||||
step_count = 0
|
||||
reward_sum = 0.0
|
||||
episode_count = 1
|
||||
|
||||
while len(frames) < target_frames and step_count < args.max_steps:
|
||||
action, _ = model.predict(obs, deterministic=not args.stochastic)
|
||||
if isinstance(action, np.ndarray):
|
||||
action = int(action.item())
|
||||
obs, reward, terminated, truncated, info = step_compat(env, action)
|
||||
reward_sum += float(reward)
|
||||
step_count += 1
|
||||
done = terminated or truncated
|
||||
|
||||
frame = env.render()
|
||||
if frame is not None:
|
||||
frames.append(frame.copy())
|
||||
|
||||
if done:
|
||||
obs, _ = reset_compat(env, seed=args.seed + episode_count)
|
||||
episode_count += 1
|
||||
frame = env.render()
|
||||
if frame is not None and len(frames) < target_frames:
|
||||
frames.append(frame.copy())
|
||||
done = False
|
||||
|
||||
if not frames:
|
||||
raise RuntimeError("No frames captured. Check render_mode support and environment setup.")
|
||||
|
||||
try:
|
||||
writer = imageio.get_writer(str(output_path), fps=args.fps, codec="libx264", quality=8)
|
||||
for frame in frames:
|
||||
writer.append_data(frame)
|
||||
writer.close()
|
||||
print(f"[video] Saved mp4: {output_path}")
|
||||
except Exception as exc:
|
||||
frame_dir = save_frames_fallback(frames, output_path)
|
||||
print(f"[warn] mp4 write failed: {exc}")
|
||||
print(f"[fallback] Saved frame sequence: {frame_dir}")
|
||||
print(
|
||||
"[fallback] Convert frames to mp4 with: "
|
||||
f"ffmpeg -framerate {args.fps} -i {frame_dir}/frame_%06d.png -c:v libx264 -pix_fmt yuv420p {output_path}"
|
||||
)
|
||||
|
||||
print(
|
||||
f"[stats] frames={len(frames)} approx_sec={len(frames)/max(args.fps,1):.2f} "
|
||||
f"steps={step_count} reward_sum={reward_sum:.2f} episodes={episode_count}"
|
||||
)
|
||||
env.close()
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
300
mario-rl-mvp/src/train_ppo.py
Normal file
300
mario-rl-mvp/src/train_ppo.py
Normal file
@@ -0,0 +1,300 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import argparse
|
||||
import shutil
|
||||
from datetime import datetime
|
||||
from pathlib import Path
|
||||
|
||||
from stable_baselines3 import PPO
|
||||
from stable_baselines3.common.callbacks import BaseCallback, CallbackList, CheckpointCallback
|
||||
from stable_baselines3.common.save_util import load_from_zip_file
|
||||
from stable_baselines3.common.vec_env import DummyVecEnv, VecMonitor
|
||||
from torch.utils.tensorboard import SummaryWriter
|
||||
|
||||
from src.env import make_env_fn
|
||||
from src.utils import ensure_artifact_paths, resolve_torch_device, seed_everything, write_latest_pointer
|
||||
|
||||
|
||||
def parse_args() -> argparse.Namespace:
|
||||
parser = argparse.ArgumentParser(description="Train PPO agent on NES Super Mario Bros.")
|
||||
parser.add_argument("--env-id", type=str, default="SuperMarioBros-1-1-v0")
|
||||
parser.add_argument("--movement", type=str, default="right_only", choices=["right_only", "simple"])
|
||||
parser.add_argument("--total-timesteps", type=int, default=1_000_000)
|
||||
parser.add_argument("--seed", type=int, default=42)
|
||||
parser.add_argument("--n-envs", type=int, default=4)
|
||||
parser.add_argument("--frame-skip", type=int, default=4)
|
||||
parser.add_argument("--reward-mode", type=str, default="raw", choices=["raw", "clip", "progress"])
|
||||
parser.add_argument("--progress-scale", type=float, default=0.02)
|
||||
parser.add_argument("--death-penalty", type=float, default=-50.0)
|
||||
parser.add_argument("--flag-bonus", type=float, default=100.0)
|
||||
parser.add_argument("--stall-penalty", type=float, default=0.05)
|
||||
parser.add_argument("--stall-steps", type=int, default=40)
|
||||
parser.add_argument("--backward-penalty-scale", type=float, default=0.01)
|
||||
parser.add_argument("--milestone-interval", type=int, default=32)
|
||||
parser.add_argument("--milestone-bonus", type=float, default=1.0)
|
||||
parser.add_argument("--no-progress-terminate-steps", type=int, default=120)
|
||||
parser.add_argument("--no-progress-terminate-penalty", type=float, default=20.0)
|
||||
|
||||
parser.add_argument("--learning-rate", type=float, default=2.5e-4)
|
||||
parser.add_argument("--n-steps", type=int, default=128)
|
||||
parser.add_argument("--batch-size", type=int, default=256)
|
||||
parser.add_argument("--n-epochs", type=int, default=4)
|
||||
parser.add_argument("--gamma", type=float, default=0.99)
|
||||
parser.add_argument("--gae-lambda", type=float, default=0.95)
|
||||
parser.add_argument("--ent-coef", type=float, default=0.01)
|
||||
parser.add_argument("--clip-range", type=float, default=0.2)
|
||||
|
||||
parser.add_argument("--save-freq", type=int, default=50_000, help="Checkpoint frequency in env steps.")
|
||||
parser.add_argument("--device", type=str, default="auto", choices=["auto", "cpu", "mps"])
|
||||
parser.add_argument("--clip-reward", action="store_true", help="Enable reward clipping to {-1,0,1}.")
|
||||
parser.add_argument("--run-name", type=str, default="", help="Optional custom run name.")
|
||||
parser.add_argument(
|
||||
"--init-model-path",
|
||||
type=str,
|
||||
default="",
|
||||
help="Optional .zip model path to initialize weights before training.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--allow-partial-init",
|
||||
action="store_true",
|
||||
help=(
|
||||
"Allow partial init when checkpoint is from a different action space "
|
||||
"(e.g. right_only -> simple). Compatible policy layers will be loaded, "
|
||||
"incompatible layers (like action head) are skipped."
|
||||
),
|
||||
)
|
||||
parser.add_argument("--progress-bar", action="store_true", help="Enable SB3 progress bar (requires tqdm/rich).")
|
||||
return parser.parse_args()
|
||||
|
||||
|
||||
def load_partial_policy_weights(model: PPO, init_model_path: Path, device: str) -> tuple[int, int]:
|
||||
_, params, _ = load_from_zip_file(str(init_model_path), device=device)
|
||||
if params is None or "policy" not in params:
|
||||
raise RuntimeError(f"No policy parameters found in checkpoint: {init_model_path}")
|
||||
|
||||
src_state = params["policy"]
|
||||
dst_state = model.policy.state_dict()
|
||||
|
||||
compatible_state = {}
|
||||
skipped_count = 0
|
||||
for key, value in src_state.items():
|
||||
if key in dst_state and dst_state[key].shape == value.shape:
|
||||
compatible_state[key] = value
|
||||
else:
|
||||
skipped_count += 1
|
||||
|
||||
if not compatible_state:
|
||||
raise RuntimeError("No compatible policy tensors found for partial init.")
|
||||
|
||||
model.policy.load_state_dict(compatible_state, strict=False)
|
||||
return len(compatible_state), skipped_count
|
||||
|
||||
|
||||
class EpisodeEndLoggingCallback(BaseCallback):
|
||||
"""Log per-episode terminal diagnostics to stdout and TensorBoard."""
|
||||
|
||||
REASON_TO_CODE = {"death": 0, "no_progress": 1, "timeout": 2, "clear": 3}
|
||||
|
||||
def __init__(self, n_envs: int, tb_log_dir: Path):
|
||||
super().__init__(verbose=0)
|
||||
self.n_envs = n_envs
|
||||
self.tb_log_dir = tb_log_dir
|
||||
self.tb_writer: SummaryWriter | None = None
|
||||
self.episode_count = 0
|
||||
self.episode_max_x_pos = [0.0 for _ in range(max(n_envs, 1))]
|
||||
|
||||
@staticmethod
|
||||
def _resolve_done_reason(info: dict) -> str:
|
||||
if bool(info.get("flag_get", False)):
|
||||
return "clear"
|
||||
if bool(info.get("terminated_by_stall", False)):
|
||||
return "no_progress"
|
||||
if bool(info.get("TimeLimit.truncated", False)):
|
||||
return "timeout"
|
||||
try:
|
||||
if float(info.get("time", 1.0)) <= 0.0:
|
||||
return "timeout"
|
||||
except (TypeError, ValueError):
|
||||
pass
|
||||
return "death"
|
||||
|
||||
def _on_training_start(self) -> None:
|
||||
self.tb_log_dir.mkdir(parents=True, exist_ok=True)
|
||||
self.tb_writer = SummaryWriter(log_dir=str(self.tb_log_dir))
|
||||
|
||||
def _on_step(self) -> bool:
|
||||
infos = self.locals.get("infos")
|
||||
dones = self.locals.get("dones")
|
||||
if infos is None or dones is None:
|
||||
return True
|
||||
|
||||
for env_idx, (info, done) in enumerate(zip(infos, dones)):
|
||||
if env_idx >= len(self.episode_max_x_pos):
|
||||
continue
|
||||
x_pos = float(info.get("x_pos", 0.0))
|
||||
if x_pos > self.episode_max_x_pos[env_idx]:
|
||||
self.episode_max_x_pos[env_idx] = x_pos
|
||||
|
||||
if not bool(done):
|
||||
continue
|
||||
|
||||
self.episode_count += 1
|
||||
max_x_pos = self.episode_max_x_pos[env_idx]
|
||||
flag_get = 1.0 if bool(info.get("flag_get", False)) else 0.0
|
||||
done_reason = self._resolve_done_reason(info)
|
||||
done_reason_code = float(self.REASON_TO_CODE[done_reason])
|
||||
episode_step = self.episode_count
|
||||
|
||||
self.logger.record_mean("rollout/episode_max_x_pos", max_x_pos)
|
||||
self.logger.record_mean("rollout/flag_get", flag_get)
|
||||
self.logger.record_mean(f"rollout/done_reason_{done_reason}", 1.0)
|
||||
|
||||
if self.tb_writer is not None:
|
||||
self.tb_writer.add_scalar("episode_end/episode_max_x_pos", max_x_pos, episode_step)
|
||||
self.tb_writer.add_scalar("episode_end/flag_get", flag_get, episode_step)
|
||||
self.tb_writer.add_scalar("episode_end/done_reason_code", done_reason_code, episode_step)
|
||||
self.tb_writer.add_scalar("episode_end/done_reason_death", 1.0 if done_reason == "death" else 0.0, episode_step)
|
||||
self.tb_writer.add_scalar(
|
||||
"episode_end/done_reason_no_progress", 1.0 if done_reason == "no_progress" else 0.0, episode_step
|
||||
)
|
||||
self.tb_writer.add_scalar("episode_end/done_reason_timeout", 1.0 if done_reason == "timeout" else 0.0, episode_step)
|
||||
self.tb_writer.add_scalar("episode_end/done_reason_clear", 1.0 if done_reason == "clear" else 0.0, episode_step)
|
||||
self.tb_writer.flush()
|
||||
|
||||
print(
|
||||
f"[episode_end] ep={episode_step} env={env_idx} reason={done_reason} "
|
||||
f"max_x={max_x_pos:.1f} flag_get={bool(flag_get)}"
|
||||
)
|
||||
self.episode_max_x_pos[env_idx] = 0.0
|
||||
|
||||
return True
|
||||
|
||||
def _on_training_end(self) -> None:
|
||||
if self.tb_writer is not None:
|
||||
self.tb_writer.close()
|
||||
self.tb_writer = None
|
||||
|
||||
|
||||
def main() -> None:
|
||||
args = parse_args()
|
||||
seed_everything(args.seed)
|
||||
reward_mode = "clip" if args.clip_reward and args.reward_mode == "raw" else args.reward_mode
|
||||
|
||||
paths = ensure_artifact_paths()
|
||||
timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
|
||||
env_slug = args.env_id.replace("/", "_")
|
||||
run_name = args.run_name.strip() or f"ppo_{env_slug}_{timestamp}"
|
||||
|
||||
run_model_dir = paths.models / run_name
|
||||
run_log_dir = paths.logs / run_name
|
||||
run_tb_dir = run_log_dir / "tb"
|
||||
run_model_dir.mkdir(parents=True, exist_ok=True)
|
||||
run_tb_dir.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
device, device_msg = resolve_torch_device(args.device)
|
||||
print(f"[device] {device} | {device_msg}")
|
||||
print(f"[run] {run_name}")
|
||||
print(
|
||||
f"[config] env_id={args.env_id}, movement={args.movement}, reward_mode={reward_mode}, "
|
||||
f"clip_reward={args.clip_reward}, n_envs={args.n_envs}"
|
||||
)
|
||||
|
||||
env_fns = [
|
||||
make_env_fn(
|
||||
env_id=args.env_id,
|
||||
seed=args.seed + i,
|
||||
movement=args.movement,
|
||||
reward_mode=reward_mode,
|
||||
clip_reward=args.clip_reward,
|
||||
frame_skip=args.frame_skip,
|
||||
progress_scale=args.progress_scale,
|
||||
death_penalty=args.death_penalty,
|
||||
flag_bonus=args.flag_bonus,
|
||||
stall_penalty=args.stall_penalty,
|
||||
stall_steps=args.stall_steps,
|
||||
backward_penalty_scale=args.backward_penalty_scale,
|
||||
milestone_interval=args.milestone_interval,
|
||||
milestone_bonus=args.milestone_bonus,
|
||||
no_progress_terminate_steps=args.no_progress_terminate_steps,
|
||||
no_progress_terminate_penalty=args.no_progress_terminate_penalty,
|
||||
)
|
||||
for i in range(args.n_envs)
|
||||
]
|
||||
vec_env = DummyVecEnv(env_fns)
|
||||
vec_env = VecMonitor(vec_env, filename=str(run_log_dir / "monitor.csv"))
|
||||
|
||||
model = PPO(
|
||||
policy="CnnPolicy",
|
||||
env=vec_env,
|
||||
learning_rate=args.learning_rate,
|
||||
n_steps=args.n_steps,
|
||||
batch_size=args.batch_size,
|
||||
n_epochs=args.n_epochs,
|
||||
gamma=args.gamma,
|
||||
gae_lambda=args.gae_lambda,
|
||||
ent_coef=args.ent_coef,
|
||||
clip_range=args.clip_range,
|
||||
tensorboard_log=str(run_tb_dir),
|
||||
seed=args.seed,
|
||||
verbose=1,
|
||||
device=device,
|
||||
)
|
||||
|
||||
if args.init_model_path.strip():
|
||||
init_model_path = Path(args.init_model_path).expanduser().resolve()
|
||||
if not init_model_path.exists():
|
||||
raise FileNotFoundError(f"Init model not found: {init_model_path}")
|
||||
try:
|
||||
model.set_parameters(str(init_model_path), exact_match=False, device=device)
|
||||
print(f"[init] Loaded initial weights from: {init_model_path}")
|
||||
except RuntimeError as exc:
|
||||
if not args.allow_partial_init:
|
||||
raise RuntimeError(
|
||||
f"{exc}\n"
|
||||
"Hint: checkpoint and current env may use different action spaces. "
|
||||
"Try one of:\n"
|
||||
"1) keep the same --movement as the checkpoint;\n"
|
||||
"2) remove --init-model-path and train from scratch;\n"
|
||||
"3) add --allow-partial-init to load compatible layers only."
|
||||
) from exc
|
||||
loaded_count, skipped_count = load_partial_policy_weights(model, init_model_path, device=device)
|
||||
print(
|
||||
f"[init] Partial init from {init_model_path}: "
|
||||
f"loaded_tensors={loaded_count}, skipped_tensors={skipped_count}"
|
||||
)
|
||||
|
||||
callback = CheckpointCallback(
|
||||
save_freq=max(args.save_freq // max(args.n_envs, 1), 1),
|
||||
save_path=str(run_model_dir),
|
||||
name_prefix="ppo_mario_ckpt",
|
||||
save_replay_buffer=False,
|
||||
save_vecnormalize=False,
|
||||
)
|
||||
episode_end_logging_callback = EpisodeEndLoggingCallback(
|
||||
n_envs=args.n_envs,
|
||||
tb_log_dir=run_tb_dir / "episode_end",
|
||||
)
|
||||
|
||||
try:
|
||||
model.learn(
|
||||
total_timesteps=args.total_timesteps,
|
||||
callback=CallbackList([callback, episode_end_logging_callback]),
|
||||
tb_log_name="ppo",
|
||||
progress_bar=args.progress_bar,
|
||||
)
|
||||
finally:
|
||||
vec_env.close()
|
||||
|
||||
final_model = run_model_dir / "final_model"
|
||||
model.save(str(final_model))
|
||||
|
||||
latest_model = paths.models / "latest_model.zip"
|
||||
shutil.copy2(str(final_model) + ".zip", latest_model)
|
||||
pointer = write_latest_pointer(paths.models, latest_model)
|
||||
|
||||
print(f"[done] Final model: {final_model}.zip")
|
||||
print(f"[done] Latest model pointer: {pointer}")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
75
mario-rl-mvp/src/utils.py
Normal file
75
mario-rl-mvp/src/utils.py
Normal file
@@ -0,0 +1,75 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import random
|
||||
from dataclasses import dataclass
|
||||
from pathlib import Path
|
||||
from typing import Optional, Tuple
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class ArtifactPaths:
|
||||
root: Path
|
||||
models: Path
|
||||
videos: Path
|
||||
logs: Path
|
||||
|
||||
|
||||
def project_root() -> Path:
|
||||
return Path(__file__).resolve().parents[1]
|
||||
|
||||
|
||||
def ensure_artifact_paths(root: Optional[Path] = None) -> ArtifactPaths:
|
||||
root = root or project_root()
|
||||
artifacts = root / "artifacts"
|
||||
models = artifacts / "models"
|
||||
videos = artifacts / "videos"
|
||||
logs = artifacts / "logs"
|
||||
for p in (artifacts, models, videos, logs):
|
||||
p.mkdir(parents=True, exist_ok=True)
|
||||
return ArtifactPaths(root=artifacts, models=models, videos=videos, logs=logs)
|
||||
|
||||
|
||||
def seed_everything(seed: int) -> None:
|
||||
random.seed(seed)
|
||||
np.random.seed(seed)
|
||||
torch.manual_seed(seed)
|
||||
if torch.cuda.is_available():
|
||||
torch.cuda.manual_seed_all(seed)
|
||||
|
||||
|
||||
def resolve_torch_device(requested: str = "auto") -> Tuple[str, str]:
|
||||
requested = requested.lower().strip()
|
||||
if requested == "cpu":
|
||||
return "cpu", "User requested CPU."
|
||||
|
||||
if requested not in {"auto", "mps"}:
|
||||
return "cpu", f"Unknown device '{requested}', fallback to CPU."
|
||||
|
||||
mps_available = hasattr(torch.backends, "mps") and torch.backends.mps.is_available()
|
||||
if not mps_available:
|
||||
if requested == "mps":
|
||||
return "cpu", "MPS requested but unavailable, fallback to CPU."
|
||||
return "cpu", "MPS unavailable, using CPU."
|
||||
|
||||
try:
|
||||
x = torch.ones(8, device="mps")
|
||||
_ = (x * 2).cpu().numpy()
|
||||
return "mps", "MPS is available and passed a quick tensor sanity check."
|
||||
except Exception as exc: # pragma: no cover - hardware dependent
|
||||
return "cpu", f"MPS check failed ({exc}), fallback to CPU."
|
||||
|
||||
|
||||
def latest_model_path(models_dir: Path) -> Optional[Path]:
|
||||
candidates = [p for p in models_dir.rglob("*.zip") if p.is_file()]
|
||||
if not candidates:
|
||||
return None
|
||||
return max(candidates, key=lambda p: p.stat().st_mtime)
|
||||
|
||||
|
||||
def write_latest_pointer(models_dir: Path, model_path: Path) -> Path:
|
||||
pointer = models_dir / "LATEST_MODEL.txt"
|
||||
pointer.write_text(str(model_path.resolve()) + "\n", encoding="utf-8")
|
||||
return pointer
|
||||
Reference in New Issue
Block a user