feat: initial mario rl mvp
.gitignore (vendored, new file, 44 lines)
@@ -0,0 +1,44 @@
# macOS
.DS_Store

# Python cache / build
__pycache__/
*.py[cod]
*.pyo
*.pyd
*.so
*.egg-info/
.eggs/
dist/
build/

# Virtual environments
.venv/
venv/
env/
*/.venv/

# Logs / test / tooling
*.log
.pytest_cache/
.mypy_cache/
.ruff_cache/
.coverage
htmlcov/

# Jupyter
.ipynb_checkpoints/

# RL artifacts
artifacts/
*/artifacts/

# TensorBoard / checkpoints / models
runs/
checkpoints/
models/
videos/

# IDE
.vscode/
.idea/
mario-rl-mvp/README.md (new file, 339 lines)
@@ -0,0 +1,339 @@
# Mario RL MVP (macOS Apple Silicon)

Minimal runnable project: train a `gym-super-mario-bros` / `nes-py` agent from pixel input with a conventional CNN policy (`stable-baselines3` PPO). No large language model is involved.

## 1. Project layout

```text
mario-rl-mvp/
  src/
    env.py
    train_ppo.py
    eval.py
    record_video.py
    utils.py
  artifacts/
    models/
    videos/
    logs/
  requirements.txt
  README.md
```

## 2. Environment setup (macOS / Apple Silicon)

Python 3.9+ is recommended (the system `python3` works).

```bash
cd /Users/roog/Work/FNT/SpMr/mario-rl-mvp
python3 -m venv .venv
source .venv/bin/activate
python -m pip install --upgrade pip setuptools wheel
pip install -r requirements.txt
```

Optional system dependencies (for ffmpeg transcoding and possible SDL compatibility issues):

```bash
brew install ffmpeg sdl2
```

## 3. Start training with one command

Training runs on CPU by default (if a usable, stable MPS device is detected it is enabled automatically; otherwise the script falls back to CPU):

```bash
python -m src.train_ppo
```

Commonly overridden parameters:

```bash
python -m src.train_ppo \
  --total-timesteps 1000000 \
  --n-envs 4 \
  --save-freq 50000 \
  --env-id SuperMarioBros-1-1-v0 \
  --movement right_only \
  --seed 42
```

Resume training from an existing checkpoint (hyperparameters such as `--ent-coef` can be changed at the same time):

```bash
python -m src.train_ppo \
  --init-model-path artifacts/models/latest_model.zip \
  --total-timesteps 500000 \
  --ent-coef 0.02 \
  --learning-rate 1e-4
```

If the action space changed (e.g. the checkpoint used `right_only` and you now want `simple`), use partial loading:

```bash
python -m src.train_ppo \
  --init-model-path artifacts/models/latest_model.zip \
  --allow-partial-init \
  --movement simple \
  --total-timesteps 300000
```

### 3.1 Resume from an existing model (`--init-model-path`)

- Purpose: load existing `.zip` weights and keep training; useful when you want to adjust exploration parameters without abandoning the experiment goal.
- Typical scenario: the current policy is stuck in a local optimum (e.g. `approx_kl` and `policy_gradient_loss` stay near 0 for a long time) and you want to continue exploring from the existing model.
- Note: this is not a hot update; you still have to stop the running training process and restart it with the new command.

```bash
python -m src.train_ppo \
  --init-model-path artifacts/models/latest_model.zip \
  --ent-coef 0.02 \
  --learning-rate 1e-4 \
  --total-timesteps 300000
```

Training output:
- stdout: PPO training logs
- TensorBoard: `artifacts/logs/<run_name>/tb/`
- checkpoints: `artifacts/models/<run_name>/ppo_mario_ckpt_*.zip`
- final model: `artifacts/models/<run_name>/final_model.zip`
- latest pointer: `artifacts/models/latest_model.zip` + `LATEST_MODEL.txt`

Start TensorBoard:

```bash
tensorboard --logdir artifacts/logs --port 6006
```

### 3.2 PPO training-log field reference

During training you will see output like:

```text
| rollout/ep_len_mean | rollout/ep_rew_mean | ... |
| train/approx_kl | train/entropy_loss | ... |
```

The common fields mean:

- `rollout/ep_len_mean`: mean episode length over the recent batch. Longer is not necessarily better; read it together with the reward.
- `rollout/ep_rew_mean`: mean episode return over the recent batch. Higher is usually better.
- `time/fps`: training throughput (environment steps per second). It measures speed only, not policy quality.
- `time/iterations`: index of the current rollout + update cycle.
- `time/time_elapsed`: seconds since training started.
- `time/total_timesteps`: cumulative environment steps (training stops when it reaches your `--total-timesteps`).
- `train/approx_kl`: how far the new policy moved from the old one. Too large means the update was too aggressive; near 0 means almost no update happened. A tiny negative value is usually numerical error and can be treated as 0.
- `train/clip_fraction`: fraction of samples that triggered PPO clipping. Staying at 0 while KL is also near 0 usually means the policy has essentially stopped updating.
- `train/clip_range`: the PPO clipping threshold (default 0.2).
- `train/entropy_loss`: an exploration indicator. The closer its absolute value is to 0, the more deterministic the policy and the less it explores.
- `train/explained_variance`: how well the value network explains the returns; closer to 1 is better, near 0 means the value estimates are still unstable.
- `train/learning_rate`: the optimizer step size (magnitude of parameter updates), not hardware speed.
- `train/loss`: total loss (a sum of several terms); watch the trend rather than the absolute value.
- `train/policy_gradient_loss`: the policy network's update signal. Staying near 0 for a long time can mean the actor is barely updating.
- `train/value_loss`: value network error. Large values usually mean the critic has not fit the returns yet.

#### Quick diagnosis (practical)

- `ep_rew_mean` / `avg_max_x_pos` keep rising: training is generally improving.
- `approx_kl ≈ 0` + `clip_fraction = 0` + `policy_gradient_loss ≈ 0`: most likely stuck (updates have nearly stopped).
- `entropy_loss` absolute value very small and flat: not enough exploration; try raising `--ent-coef`.
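For intuition about the `approx_kl` number: recent SB3 versions log the k3 estimator `mean((r - 1) - log r)` over the probability ratio `r = exp(new_logprob - old_logprob)`. It is non-negative and exactly 0 when the two policies agree, which is why "stuck" shows up as `approx_kl ≈ 0`. A minimal standalone sketch (the function name is illustrative, not from this repo):

```python
import math

def approx_kl(old_logprobs, new_logprobs):
    # k3 estimator: mean((r - 1) - log r), with r = exp(new - old).
    # Each term e^x - 1 - x is >= 0, so the estimate is never negative.
    total = 0.0
    for old_lp, new_lp in zip(old_logprobs, new_logprobs):
        log_ratio = new_lp - old_lp
        total += (math.exp(log_ratio) - 1.0) - log_ratio
    return total / len(old_logprobs)

print(approx_kl([-1.0, -2.0], [-1.0, -2.0]))  # identical policies -> 0.0
print(approx_kl([-1.0, -2.0], [-1.1, -1.9]))  # shifted policy -> small positive value
```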
## 4. Evaluate a model

Load the latest model, run N episodes, and print averaged metrics:

```bash
python -m src.eval --episodes 5
```

Evaluate a specific model:

```bash
python -m src.eval --model-path artifacts/models/latest_model.zip --episodes 10
```

Note: `eval.py` defaults to `--movement auto`, which matches `right_only`/`simple` to the model's action dimension automatically, avoiding the `KeyError` caused by a mismatched action space.

Reported metrics:
- `avg_reward`
- `avg_max_x_pos`
- `clear_rate` (fraction of episodes with `flag_get=True`)

Inspect the step counters stored in a model:

```bash
unzip -p artifacts/models/latest_model.zip data | rg '"num_timesteps"|"_total_timesteps"|"_tensorboard_log"'
```

Example output:

```text
num_timesteps = 151552
_total_timesteps = 150000
```
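The same fields can be read without external tools: an SB3 `.zip` is a regular zip archive whose `data` entry is JSON. A small sketch (the helper name is ours, demonstrated on an in-memory archive with the same layout):

```python
import io
import json
import zipfile

def model_timesteps(path_or_file):
    """Read PPO bookkeeping counters from an SB3 .zip ('data' entry is JSON)."""
    with zipfile.ZipFile(path_or_file) as zf:
        data = json.loads(zf.read("data"))
    return data["num_timesteps"], data["_total_timesteps"]

# Demo on a fake archive; in practice pass "artifacts/models/latest_model.zip".
buf = io.BytesIO()
with zipfile.ZipFile(buf, "w") as zf:
    zf.writestr("data", json.dumps({"num_timesteps": 151552, "_total_timesteps": 150000}))
print(model_timesteps(buf))  # (151552, 150000)
```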
## 5. Record replay videos (headless, no window)

By default records roughly 10 seconds of mp4 into `artifacts/videos/`:

```bash
python -m src.record_video --duration-sec 10 --fps 30
```

Specify the output path:

```bash
python -m src.record_video --output artifacts/videos/demo.mp4 --duration-sec 10
```

Note: `record_video.py` defaults to `--movement auto` and matches the action space to the model automatically.

How it works:
- Renders with `render_mode=rgb_array`, so no window is opened
- Writes mp4 via `imageio + ffmpeg` by default
- If the mp4 write fails, it falls back to saving a PNG frame sequence and prints the ffmpeg transcode command

## 6. Choosing the action space

`RIGHT_ONLY` is the default because:
- Fewer actions mean a smaller exploration space, so the MVP converges faster to a "push right" policy
- It is enough to validate the training loop first
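For reference, the two action sets contain 5 and 7 button combinations respectively (cf. section 8.5 on `size mismatch for action_net`). The sketch below is written from memory of `gym_super_mario_bros.actions`, not taken from this repo, so verify it against the installed package:

```python
# Button combinations as defined (from memory) in gym_super_mario_bros.actions.
RIGHT_ONLY = [["NOOP"], ["right"], ["right", "A"], ["right", "B"], ["right", "A", "B"]]
SIMPLE_MOVEMENT = RIGHT_ONLY + [["A"], ["left"]]

# The action-space sizes are what eval.py's --movement auto matches against.
print(len(RIGHT_ONLY), len(SIMPLE_MOVEMENT))  # 5 7
```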
Switch to `SIMPLE_MOVEMENT` (a richer action set):

```bash
python -m src.train_ppo --movement simple
```

## 7. Preprocessing and rewards

Default preprocessing pipeline:
- Frame skip: `frame_skip=4`
- Grayscale: `RGB -> Gray`
- Resize: `84x84`
- Frame stack: `4`
- Channel layout: `CHW` (compatible with `CnnPolicy`)

Rewards:
- `--reward-mode raw`: raw reward (default)
- `--reward-mode clip`: clipped reward `sign(reward)` (equivalent to the old `--clip-reward` flag)
- `--reward-mode progress`: reward shaping mode, which additionally includes:
  - forward-progress bonus (`--progress-scale`)
  - death penalty (`--death-penalty`)
  - level-clear bonus (`--flag-bonus`)
  - stall penalty (`--stall-penalty` + `--stall-steps`)
  - backward-movement penalty (`--backward-penalty-scale`)
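The per-step part of `--reward-mode progress` can be sketched as follows. This is a simplified, stateless rendition of `ProgressRewardEnv` in `src/env.py` (it omits the milestone, death/flag, and no-progress-termination terms); defaults mirror the CLI flags:

```python
def shape_step_reward(reward, delta_x, stall_count,
                      progress_scale=0.02, backward_penalty_scale=0.01,
                      stall_penalty=0.05, stall_steps=40):
    """One shaping step: reward forward progress, punish retreat and stalling."""
    shaped = float(reward)
    if delta_x > 0:            # moved right: add progress bonus, reset stall counter
        shaped += progress_scale * delta_x
        stall_count = 0
    else:                      # no progress this step
        stall_count += 1
        if delta_x < 0:        # moved left: backward penalty
            shaped -= backward_penalty_scale * abs(delta_x)
    if stall_count >= stall_steps:   # stuck too long: extra penalty each step
        shaped -= stall_penalty
    return shaped, stall_count

print(shape_step_reward(1.0, 2.0, 10))   # progress: bonus added, counter resets
print(shape_step_reward(0.0, 0.0, 39))   # 40th stalled step: penalty kicks in
```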
Use clipped rewards:

```bash
python -m src.train_ppo --reward-mode clip
```

Recommended resume command when the agent is stuck at a fixed position (e.g. x=314, bumping into a mushroom):

```bash
python -m src.train_ppo \
  --init-model-path artifacts/models/latest_model.zip \
  --allow-partial-init \
  --reward-mode progress \
  --movement simple \
  --ent-coef 0.04 \
  --learning-rate 1e-4 \
  --n-steps 512 \
  --gamma 0.995 \
  --death-penalty -120 \
  --stall-penalty 0.2 \
  --stall-steps 20 \
  --backward-penalty-scale 0.03 \
  --milestone-bonus 2.0 \
  --no-progress-terminate-steps 80 \
  --no-progress-terminate-penalty 30 \
  --total-timesteps 150000
```

## 8. Troubleshooting

### 8.1 `pip install` fails (nes-py/gym-super-mario-bros build problems)

Install the toolchain first, then retry:

```bash
xcode-select --install
brew install cmake swig sdl2
pip install --upgrade pip setuptools wheel
pip install -r requirements.txt
```

### 8.2 MPS is unstable or raises errors

Force CPU:

```bash
python -m src.train_ppo --device cpu
```

Note: the script runs an MPS tensor sanity check first and falls back to CPU automatically on failure.

### 8.3 Video writing fails (ffmpeg/codec)

1) Install system ffmpeg:
```bash
brew install ffmpeg
```

2) If it already fell back to a frame sequence, transcode manually:
```bash
ffmpeg -framerate 30 -i artifacts/videos/<name>_frames/frame_%06d.png -c:v libx264 -pix_fmt yuv420p artifacts/videos/<name>.mp4
```

### 8.4 GUI window errors

This project records with `rgb_array` by default and does not depend on a GUI window.
If SDL problems still occur, set explicitly:

```bash
export SDL_VIDEODRIVER=dummy
python -m src.record_video --duration-sec 10
```

### 8.5 `size mismatch for action_net` (when loading an old model)

Typical cause: the old checkpoint's action space differs from the current configuration (e.g. `right_only` = 5 actions, `simple` = 7 actions).

Possible fixes:

1) Keep the same action space as the checkpoint:
```bash
python -m src.train_ppo --init-model-path /Users/roog/Work/FNT/SpMr/mario-rl-mvp/artifacts/models/ppo_SuperMarioBros-1-1-v0_20260212_164717/ppo_mario_ckpt_150000_steps.zip --movement right_only
```

2) If you really do want to switch action spaces, use partial initialization (skips the incompatible action head):
```bash
python -m src.train_ppo --init-model-path artifacts/models/latest_model.zip --movement simple --allow-partial-init
```

3) Or skip loading the old model entirely and train the new action space from scratch.

## 9. Minimal smoke test (run in order)

```bash
cd /Users/roog/Work/FNT/SpMr/mario-rl-mvp
source .venv/bin/activate

# 1) Train 1e4 steps, writing at least one checkpoint + the final model
python -m src.train_ppo \
  --total-timesteps 10000 \
  --save-freq 2000 \
  --n-envs 1 \
  --device cpu \
  --movement right_only

# 2) Quick evaluation
python -m src.eval --episodes 2 --max-steps 2000

# 3) Record a 10-second video
python -m src.record_video --duration-sec 10 --fps 30
```

Acceptance criteria:
- `artifacts/models/` contains a `.zip` model
- `artifacts/logs/` contains TensorBoard event files
- `artifacts/videos/` contains an `.mp4` (or a `_frames/` frame sequence if the write failed)
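These criteria can also be checked mechanically. A stdlib sketch (the helper name is ours; paths follow this README's layout, and the event-file pattern is TensorBoard's standard `events.out.tfevents.*`):

```python
import tempfile
from pathlib import Path

def check_acceptance(root="artifacts"):
    """Return which smoke-test acceptance criteria are met under artifacts/."""
    root = Path(root)
    return {
        "model_zip": any((root / "models").rglob("*.zip")),
        "tb_events": any((root / "logs").rglob("events.out.tfevents.*")),
        "video": any((root / "videos").rglob("*.mp4"))
                 or any((root / "videos").glob("*_frames")),
    }

# Demo on a temporary tree; run check_acceptance() from the repo root in practice.
with tempfile.TemporaryDirectory() as tmp:
    (Path(tmp) / "models").mkdir(parents=True)
    (Path(tmp) / "models" / "final_model.zip").touch()
    result = check_acceptance(tmp)
print(result)  # {'model_zip': True, 'tb_events': False, 'video': False}
```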
mario-rl-mvp/requirements.txt (new file, 12 lines)
@@ -0,0 +1,12 @@
torch==2.5.1
stable-baselines3==2.3.2
gym==0.26.2
gymnasium==0.29.1
shimmy==1.3.0
gym-super-mario-bros==7.4.0
nes-py==8.2.1
opencv-python==4.10.0.84
numpy==1.26.4
tensorboard==2.18.0
imageio==2.36.1
imageio-ffmpeg==0.5.1
mario-rl-mvp/src/__init__.py (new file, 1 line)
@@ -0,0 +1 @@
"""Mario RL MVP package."""
mario-rl-mvp/src/env.py (new file, 337 lines)
@@ -0,0 +1,337 @@
from __future__ import annotations

from collections import deque
from typing import Any, Callable, Deque, Dict, Optional, Tuple

import cv2
import gym
import gym_super_mario_bros
import numpy as np
from gym.spaces import Box
from gym_super_mario_bros.actions import RIGHT_ONLY, SIMPLE_MOVEMENT
from nes_py.wrappers import JoypadSpace


def reset_compat(env: gym.Env, seed: Optional[int] = None) -> Tuple[np.ndarray, Dict[str, Any]]:
    """Reset across old gym (returns obs) and new gym (returns obs, info) APIs."""
    try:
        result = env.reset(seed=seed)
    except TypeError:
        result = env.reset()
        if seed is not None and hasattr(env, "seed"):
            env.seed(seed)

    if isinstance(result, tuple) and len(result) == 2:
        obs, info = result
        return obs, info
    return result, {}


def step_compat(env: gym.Env, action: Any) -> Tuple[np.ndarray, float, bool, bool, Dict[str, Any]]:
    """Step, normalizing 4-tuple (old API) and 5-tuple (new API) returns."""
    result = env.step(action)
    if isinstance(result, tuple) and len(result) == 5:
        obs, reward, terminated, truncated, info = result
        return obs, float(reward), bool(terminated), bool(truncated), info
    if isinstance(result, tuple) and len(result) == 4:
        obs, reward, done, info = result
        return obs, float(reward), bool(done), False, info
    raise RuntimeError(f"Unexpected step return format: {type(result)} / {result}")


class SkipFrame(gym.Wrapper):
    def __init__(self, env: gym.Env, skip: int = 4):
        super().__init__(env)
        self._skip = skip

    def step(self, action: Any):
        total_reward = 0.0
        terminated = False
        truncated = False
        info: Dict[str, Any] = {}
        obs = None
        for _ in range(self._skip):
            obs, reward, terminated, truncated, info = step_compat(self.env, action)
            total_reward += reward
            if terminated or truncated:
                break
        return obs, total_reward, terminated, truncated, info


class PreprocessFrame(gym.ObservationWrapper):
    """Convert RGB frame to grayscale 84x84 uint8."""

    def __init__(self, env: gym.Env, width: int = 84, height: int = 84):
        super().__init__(env)
        self.width = width
        self.height = height
        self.observation_space = Box(low=0, high=255, shape=(height, width, 1), dtype=np.uint8)

    def observation(self, observation: np.ndarray) -> np.ndarray:
        if observation.ndim == 3 and observation.shape[2] == 3:
            gray = cv2.cvtColor(observation, cv2.COLOR_RGB2GRAY)
        elif observation.ndim == 2:
            gray = observation
        else:
            gray = np.squeeze(observation)
        resized = cv2.resize(gray, (self.width, self.height), interpolation=cv2.INTER_AREA)
        return resized[:, :, None].astype(np.uint8)


class ChannelLastFrameStack(gym.Wrapper):
    """Stack frames on the channel axis: (H, W, C*num_stack)."""

    def __init__(self, env: gym.Env, num_stack: int = 4):
        super().__init__(env)
        self.num_stack = num_stack
        self.frames: Deque[np.ndarray] = deque(maxlen=num_stack)

        obs_space = env.observation_space
        assert isinstance(obs_space, Box), "Frame stack requires Box observation space."
        h, w, c = obs_space.shape
        self.observation_space = Box(
            low=0,
            high=255,
            shape=(h, w, c * num_stack),
            dtype=np.uint8,
        )

    def reset(self, *, seed: Optional[int] = None, options: Optional[dict] = None):
        del options
        obs, info = reset_compat(self.env, seed=seed)
        self.frames.clear()
        for _ in range(self.num_stack):
            self.frames.append(obs)
        return self._get_observation(), info

    def step(self, action: Any):
        obs, reward, terminated, truncated, info = step_compat(self.env, action)
        self.frames.append(obs)
        return self._get_observation(), reward, terminated, truncated, info

    def _get_observation(self) -> np.ndarray:
        assert len(self.frames) == self.num_stack
        return np.concatenate(list(self.frames), axis=2)


class TransposeObservation(gym.ObservationWrapper):
    """Convert observation from HWC to CHW for CNN policy."""

    def __init__(self, env: gym.Env):
        super().__init__(env)
        obs_space = env.observation_space
        assert isinstance(obs_space, Box), "TransposeObservation requires Box observation space."
        h, w, c = obs_space.shape
        self.observation_space = Box(low=0, high=255, shape=(c, h, w), dtype=obs_space.dtype)

    def observation(self, observation: np.ndarray) -> np.ndarray:
        return np.transpose(observation, (2, 0, 1)).astype(np.uint8)


class ClipRewardEnv(gym.RewardWrapper):
    def reward(self, reward):
        return float(np.sign(reward))


class ProgressRewardEnv(gym.Wrapper):
    """Reward shaping focused on moving right and avoiding local traps."""

    def __init__(
        self,
        env: gym.Env,
        progress_scale: float = 0.02,
        death_penalty: float = -50.0,
        flag_bonus: float = 100.0,
        stall_penalty: float = 0.05,
        stall_steps: int = 40,
        backward_penalty_scale: float = 0.01,
        milestone_interval: int = 32,
        milestone_bonus: float = 1.0,
        no_progress_terminate_steps: int = 120,
        no_progress_terminate_penalty: float = 20.0,
    ):
        super().__init__(env)
        self.progress_scale = progress_scale
        self.death_penalty = death_penalty
        self.flag_bonus = flag_bonus
        self.stall_penalty = stall_penalty
        self.stall_steps = stall_steps
        self.backward_penalty_scale = backward_penalty_scale
        self.milestone_interval = milestone_interval
        self.milestone_bonus = milestone_bonus
        self.no_progress_terminate_steps = no_progress_terminate_steps
        self.no_progress_terminate_penalty = no_progress_terminate_penalty
        self._last_x_pos: Optional[float] = None
        self._best_x_pos = 0.0
        self._stall_count = 0
        self._next_milestone_x = float(milestone_interval)

    def reset(self, *, seed: Optional[int] = None, options: Optional[dict] = None):
        del options
        obs, info = reset_compat(self.env, seed=seed)
        self._last_x_pos = float(info.get("x_pos", 0.0))
        self._best_x_pos = self._last_x_pos
        self._stall_count = 0
        if self.milestone_interval > 0:
            k = int(self._best_x_pos // self.milestone_interval) + 1
            self._next_milestone_x = float(k * self.milestone_interval)
        return obs, info

    def step(self, action: Any):
        obs, reward, terminated, truncated, info = step_compat(self.env, action)
        x_pos = float(info.get("x_pos", 0.0))

        if self._last_x_pos is None:
            delta_x = 0.0
        else:
            delta_x = x_pos - self._last_x_pos
        self._last_x_pos = x_pos

        shaped_reward = float(reward)
        if delta_x > 0:
            shaped_reward += self.progress_scale * delta_x
            self._stall_count = 0
            if x_pos > self._best_x_pos:
                self._best_x_pos = x_pos
            if self.milestone_interval > 0 and self.milestone_bonus != 0.0:
                while x_pos >= self._next_milestone_x:
                    shaped_reward += self.milestone_bonus
                    self._next_milestone_x += self.milestone_interval
        else:
            self._stall_count += 1
            if delta_x < 0:
                shaped_reward -= self.backward_penalty_scale * abs(delta_x)

        if self._stall_count >= self.stall_steps:
            shaped_reward -= self.stall_penalty

        if (
            self.no_progress_terminate_steps > 0
            and not terminated
            and not truncated
            and self._stall_count >= self.no_progress_terminate_steps
        ):
            truncated = True
            shaped_reward -= self.no_progress_terminate_penalty
            info["terminated_by_stall"] = True

        if terminated or truncated:
            if bool(info.get("flag_get", False)):
                shaped_reward += self.flag_bonus
            elif terminated:
                shaped_reward += self.death_penalty

        return obs, shaped_reward, terminated, truncated, info


def get_action_set(name: str):
    name = name.lower().strip()
    if name == "simple":
        return SIMPLE_MOVEMENT
    if name == "right_only":
        return RIGHT_ONLY
    raise ValueError(f"Unsupported movement='{name}'. Use one of: right_only, simple")


def make_mario_env(
    env_id: str = "SuperMarioBros-1-1-v0",
    seed: int = 0,
    movement: str = "right_only",
    reward_mode: str = "raw",
    clip_reward: bool = False,
    frame_skip: int = 4,
    render_mode: Optional[str] = None,
    progress_scale: float = 0.02,
    death_penalty: float = -50.0,
    flag_bonus: float = 100.0,
    stall_penalty: float = 0.05,
    stall_steps: int = 40,
    backward_penalty_scale: float = 0.01,
    milestone_interval: int = 32,
    milestone_bonus: float = 1.0,
    no_progress_terminate_steps: int = 120,
    no_progress_terminate_penalty: float = 20.0,
) -> gym.Env:
    kwargs: Dict[str, Any] = {}
    if render_mode is not None:
        kwargs["render_mode"] = render_mode

    try:
        env = gym_super_mario_bros.make(env_id, apply_api_compatibility=True, **kwargs)
    except TypeError:
        env = gym_super_mario_bros.make(env_id, **kwargs)

    env = JoypadSpace(env, get_action_set(movement))
    env = SkipFrame(env, skip=frame_skip)

    mode = reward_mode.lower().strip()
    if clip_reward and mode == "raw":
        mode = "clip"
    if mode == "clip":
        env = ClipRewardEnv(env)
    elif mode == "progress":
        env = ProgressRewardEnv(
            env=env,
            progress_scale=progress_scale,
            death_penalty=death_penalty,
            flag_bonus=flag_bonus,
            stall_penalty=stall_penalty,
            stall_steps=stall_steps,
            backward_penalty_scale=backward_penalty_scale,
            milestone_interval=milestone_interval,
            milestone_bonus=milestone_bonus,
            no_progress_terminate_steps=no_progress_terminate_steps,
            no_progress_terminate_penalty=no_progress_terminate_penalty,
        )
    elif mode != "raw":
        raise ValueError(f"Unsupported reward_mode='{reward_mode}'. Use one of: raw, clip, progress")

    env = PreprocessFrame(env, width=84, height=84)
    env = ChannelLastFrameStack(env, num_stack=4)
    env = TransposeObservation(env)

    # Seed once so each env subprocess/dummy env has deterministic startup.
    reset_compat(env, seed=seed)
    if hasattr(env.action_space, "seed"):
        env.action_space.seed(seed)
    return env


def make_env_fn(
    env_id: str,
    seed: int,
    movement: str,
    reward_mode: str,
    clip_reward: bool,
    frame_skip: int,
    progress_scale: float,
    death_penalty: float,
    flag_bonus: float,
    stall_penalty: float,
    stall_steps: int,
    backward_penalty_scale: float,
    milestone_interval: int,
    milestone_bonus: float,
    no_progress_terminate_steps: int,
    no_progress_terminate_penalty: float,
) -> Callable[[], gym.Env]:
    def _thunk() -> gym.Env:
        return make_mario_env(
            env_id=env_id,
            seed=seed,
            movement=movement,
            reward_mode=reward_mode,
            clip_reward=clip_reward,
            frame_skip=frame_skip,
            render_mode=None,
            progress_scale=progress_scale,
            death_penalty=death_penalty,
            flag_bonus=flag_bonus,
            stall_penalty=stall_penalty,
            stall_steps=stall_steps,
            backward_penalty_scale=backward_penalty_scale,
            milestone_interval=milestone_interval,
            milestone_bonus=milestone_bonus,
            no_progress_terminate_steps=no_progress_terminate_steps,
            no_progress_terminate_penalty=no_progress_terminate_penalty,
        )

    return _thunk
mario-rl-mvp/src/eval.py (new file, 162 lines)
@@ -0,0 +1,162 @@
|
|||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import argparse
|
||||||
|
import json
|
||||||
|
from pathlib import Path
|
||||||
|
from statistics import mean
|
||||||
|
|
||||||
|
import numpy as np
|
||||||
|
from stable_baselines3 import PPO
|
||||||
|
|
||||||
|
from src.env import get_action_set, make_mario_env, reset_compat, step_compat
|
||||||
|
from src.utils import ensure_artifact_paths, latest_model_path, seed_everything
|
||||||
|
|
||||||
|
|
||||||
|
def parse_args() -> argparse.Namespace:
|
||||||
|
parser = argparse.ArgumentParser(description="Evaluate trained Mario PPO agent.")
|
||||||
|
parser.add_argument("--model-path", type=str, default="", help="Path to .zip model. If empty, use latest model.")
|
||||||
|
parser.add_argument("--env-id", type=str, default="SuperMarioBros-1-1-v0")
|
||||||
|
parser.add_argument("--movement", type=str, default="auto", choices=["auto", "right_only", "simple"])
|
||||||
|
parser.add_argument("--episodes", type=int, default=5)
|
||||||
|
parser.add_argument("--max-steps", type=int, default=5_000)
|
||||||
|
parser.add_argument("--seed", type=int, default=42)
|
||||||
|
parser.add_argument("--frame-skip", type=int, default=4)
|
||||||
|
parser.add_argument("--reward-mode", type=str, default="raw", choices=["raw", "clip", "progress"])
|
||||||
|
parser.add_argument("--progress-scale", type=float, default=0.02)
|
||||||
|
parser.add_argument("--death-penalty", type=float, default=-50.0)
|
||||||
|
parser.add_argument("--flag-bonus", type=float, default=100.0)
|
||||||
|
parser.add_argument("--stall-penalty", type=float, default=0.05)
|
||||||
|
parser.add_argument("--stall-steps", type=int, default=40)
|
||||||
|
parser.add_argument("--backward-penalty-scale", type=float, default=0.01)
|
||||||
|
parser.add_argument("--milestone-interval", type=int, default=32)
|
||||||
|
parser.add_argument("--milestone-bonus", type=float, default=1.0)
|
||||||
|
parser.add_argument("--no-progress-terminate-steps", type=int, default=120)
|
||||||
|
parser.add_argument("--no-progress-terminate-penalty", type=float, default=20.0)
|
||||||
|
parser.add_argument("--clip-reward", action="store_true")
|
||||||
|
parser.add_argument("--stochastic", action="store_true", help="Use stochastic policy (deterministic=False).")
|
||||||
|
return parser.parse_args()
|
||||||
|
|
||||||
|
|
||||||
|
def _action_count_for_movement(movement: str) -> int:
|
||||||
|
return len(get_action_set(movement))
|
||||||
|
|
||||||
|
|
||||||
|
def _model_action_count(model: PPO):
|
||||||
|
if hasattr(model, "action_space") and hasattr(model.action_space, "n"):
|
||||||
|
return int(model.action_space.n)
|
||||||
|
action_net = getattr(getattr(model, "policy", None), "action_net", None)
|
||||||
|
out_features = getattr(action_net, "out_features", None)
|
||||||
|
if out_features is not None:
|
||||||
|
return int(out_features)
|
||||||
|
return None
|
||||||
|
|
||||||
|
|
||||||
|
def resolve_movement(movement_arg: str, model: PPO) -> str:
|
||||||
|
model_action_n = _model_action_count(model)
|
||||||
|
|
||||||
|
if movement_arg == "auto":
|
||||||
|
if model_action_n is not None:
|
||||||
|
for candidate in ("right_only", "simple"):
|
||||||
|
if _action_count_for_movement(candidate) == model_action_n:
|
||||||
|
return candidate
|
||||||
|
return "right_only"
|
||||||
|
|
||||||
|
if model_action_n is not None:
|
||||||
|
expected_n = _action_count_for_movement(movement_arg)
|
||||||
|
if expected_n != model_action_n:
|
||||||
|
raise ValueError(
|
||||||
|
f"movement='{movement_arg}' has {expected_n} actions, but model expects {model_action_n}. "
|
||||||
|
"Use --movement auto or pass the matching movement."
|
            )
    return movement_arg


def resolve_model_path(user_path: str) -> Path:
    if user_path:
        p = Path(user_path).expanduser().resolve()
        if not p.exists():
            raise FileNotFoundError(f"Model not found: {p}")
        return p

    paths = ensure_artifact_paths()
    latest = latest_model_path(paths.models)
    if latest is None:
        raise FileNotFoundError("No model found under artifacts/models. Please run training first.")
    return latest


def main() -> None:
    args = parse_args()
    seed_everything(args.seed)
    reward_mode = "clip" if args.clip_reward and args.reward_mode == "raw" else args.reward_mode

    model_path = resolve_model_path(args.model_path)
    print(f"[eval] model={model_path}")
    model = PPO.load(str(model_path))
    movement = resolve_movement(args.movement, model)
    print(f"[eval] movement={movement} reward_mode={reward_mode}")

    env = make_mario_env(
        env_id=args.env_id,
        seed=args.seed,
        movement=movement,
        reward_mode=reward_mode,
        clip_reward=args.clip_reward,
        frame_skip=args.frame_skip,
        render_mode=None,
        progress_scale=args.progress_scale,
        death_penalty=args.death_penalty,
        flag_bonus=args.flag_bonus,
        stall_penalty=args.stall_penalty,
        stall_steps=args.stall_steps,
        backward_penalty_scale=args.backward_penalty_scale,
        milestone_interval=args.milestone_interval,
        milestone_bonus=args.milestone_bonus,
        no_progress_terminate_steps=args.no_progress_terminate_steps,
        no_progress_terminate_penalty=args.no_progress_terminate_penalty,
    )

    rewards = []
    max_x_positions = []
    clear_flags = []

    for ep in range(1, args.episodes + 1):
        obs, info = reset_compat(env, seed=args.seed + ep)
        done = False
        ep_reward = 0.0
        ep_max_x = float(info.get("x_pos", 0.0))
        flag_get = False
        step_count = 0

        while not done and step_count < args.max_steps:
            action, _ = model.predict(obs, deterministic=not args.stochastic)
            if isinstance(action, np.ndarray):
                action = int(action.item())
            obs, reward, terminated, truncated, info = step_compat(env, action)
            ep_reward += float(reward)
            ep_max_x = max(ep_max_x, float(info.get("x_pos", 0.0)))
            flag_get = flag_get or bool(info.get("flag_get", False))
            done = terminated or truncated
            step_count += 1

        rewards.append(ep_reward)
        max_x_positions.append(ep_max_x)
        clear_flags.append(1.0 if flag_get else 0.0)
        print(
            f"[episode {ep}] reward={ep_reward:.2f} max_x={ep_max_x:.1f} "
            f"clear={flag_get} steps={step_count}"
        )

    summary = {
        "episodes": args.episodes,
        "avg_reward": mean(rewards) if rewards else 0.0,
        "avg_max_x_pos": mean(max_x_positions) if max_x_positions else 0.0,
        "clear_rate": mean(clear_flags) if clear_flags else 0.0,
    }
    print("[summary]", json.dumps(summary, ensure_ascii=False, indent=2))

    env.close()


if __name__ == "__main__":
    main()
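The evaluation summary above reduces the per-episode lists with `statistics.mean`. A minimal standalone sketch of that aggregation, using hypothetical sample values in place of real rollouts:

```python
import json
from statistics import mean

# Hypothetical per-episode results standing in for real emulator rollouts.
rewards = [120.0, 250.5, 180.0]
max_x_positions = [512.0, 1024.0, 768.0]
clear_flags = [0.0, 1.0, 0.0]  # 1.0 when the flag was reached

summary = {
    "episodes": len(rewards),
    "avg_reward": mean(rewards) if rewards else 0.0,
    "avg_max_x_pos": mean(max_x_positions) if max_x_positions else 0.0,
    "clear_rate": mean(clear_flags) if clear_flags else 0.0,
}
print(json.dumps(summary, ensure_ascii=False))
```

The `if rewards else 0.0` guards keep a zero-episode run from raising `StatisticsError`.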
200
mario-rl-mvp/src/record_video.py
Normal file
@@ -0,0 +1,200 @@
from __future__ import annotations

import argparse
from datetime import datetime
from pathlib import Path

import imageio.v2 as imageio
import numpy as np
from stable_baselines3 import PPO

from src.env import get_action_set, make_mario_env, reset_compat, step_compat
from src.utils import ensure_artifact_paths, latest_model_path, seed_everything


def parse_args() -> argparse.Namespace:
    parser = argparse.ArgumentParser(description="Record Mario agent rollout to mp4 (headless).")
    parser.add_argument("--model-path", type=str, default="", help="Path to .zip model. Empty means latest model.")
    parser.add_argument("--env-id", type=str, default="SuperMarioBros-1-1-v0")
    parser.add_argument("--movement", type=str, default="auto", choices=["auto", "right_only", "simple"])
    parser.add_argument("--seed", type=int, default=42)
    parser.add_argument("--frame-skip", type=int, default=4)
    parser.add_argument("--reward-mode", type=str, default="raw", choices=["raw", "clip", "progress"])
    parser.add_argument("--progress-scale", type=float, default=0.02)
    parser.add_argument("--death-penalty", type=float, default=-50.0)
    parser.add_argument("--flag-bonus", type=float, default=100.0)
    parser.add_argument("--stall-penalty", type=float, default=0.05)
    parser.add_argument("--stall-steps", type=int, default=40)
    parser.add_argument("--backward-penalty-scale", type=float, default=0.01)
    parser.add_argument("--milestone-interval", type=int, default=32)
    parser.add_argument("--milestone-bonus", type=float, default=1.0)
    parser.add_argument("--no-progress-terminate-steps", type=int, default=120)
    parser.add_argument("--no-progress-terminate-penalty", type=float, default=20.0)
    parser.add_argument("--clip-reward", action="store_true")
    parser.add_argument("--fps", type=int, default=30)
    parser.add_argument("--duration-sec", type=int, default=10)
    parser.add_argument("--max-steps", type=int, default=5_000)
    parser.add_argument("--stochastic", action="store_true")
    parser.add_argument("--output", type=str, default="", help="Output mp4 path.")
    return parser.parse_args()


def _action_count_for_movement(movement: str) -> int:
    return len(get_action_set(movement))


def _model_action_count(model: PPO) -> int | None:
    if hasattr(model, "action_space") and hasattr(model.action_space, "n"):
        return int(model.action_space.n)
    action_net = getattr(getattr(model, "policy", None), "action_net", None)
    out_features = getattr(action_net, "out_features", None)
    if out_features is not None:
        return int(out_features)
    return None


def resolve_movement(movement_arg: str, model: PPO) -> str:
    model_action_n = _model_action_count(model)

    if movement_arg == "auto":
        if model_action_n is not None:
            for candidate in ("right_only", "simple"):
                if _action_count_for_movement(candidate) == model_action_n:
                    return candidate
        return "right_only"

    if model_action_n is not None:
        expected_n = _action_count_for_movement(movement_arg)
        if expected_n != model_action_n:
            raise ValueError(
                f"movement='{movement_arg}' has {expected_n} actions, but model expects {model_action_n}. "
                "Use --movement auto or pass the matching movement."
            )
    return movement_arg


def resolve_model(user_path: str) -> Path:
    if user_path:
        p = Path(user_path).expanduser().resolve()
        if not p.exists():
            raise FileNotFoundError(f"Model not found: {p}")
        return p

    paths = ensure_artifact_paths()
    latest = latest_model_path(paths.models)
    if latest is None:
        raise FileNotFoundError("No model found under artifacts/models. Please run training first.")
    return latest


def resolve_output_path(user_output: str) -> Path:
    if user_output:
        return Path(user_output).expanduser().resolve()

    paths = ensure_artifact_paths()
    ts = datetime.now().strftime("%Y%m%d_%H%M%S")
    return (paths.videos / f"mario_replay_{ts}.mp4").resolve()


def save_frames_fallback(frames, output_path: Path) -> Path:
    frame_dir = output_path.with_suffix("")
    frame_dir.mkdir(parents=True, exist_ok=True)
    for i, frame in enumerate(frames):
        imageio.imwrite(frame_dir / f"frame_{i:06d}.png", frame)
    return frame_dir


def main() -> None:
    args = parse_args()
    seed_everything(args.seed)
    reward_mode = "clip" if args.clip_reward and args.reward_mode == "raw" else args.reward_mode

    model_path = resolve_model(args.model_path)
    output_path = resolve_output_path(args.output)
    output_path.parent.mkdir(parents=True, exist_ok=True)

    model = PPO.load(str(model_path))
    movement = resolve_movement(args.movement, model)
    print(f"[video] movement={movement} reward_mode={reward_mode}")
    env = make_mario_env(
        env_id=args.env_id,
        seed=args.seed,
        movement=movement,
        reward_mode=reward_mode,
        clip_reward=args.clip_reward,
        frame_skip=args.frame_skip,
        render_mode="rgb_array",
        progress_scale=args.progress_scale,
        death_penalty=args.death_penalty,
        flag_bonus=args.flag_bonus,
        stall_penalty=args.stall_penalty,
        stall_steps=args.stall_steps,
        backward_penalty_scale=args.backward_penalty_scale,
        milestone_interval=args.milestone_interval,
        milestone_bonus=args.milestone_bonus,
        no_progress_terminate_steps=args.no_progress_terminate_steps,
        no_progress_terminate_penalty=args.no_progress_terminate_penalty,
    )

    obs, _ = reset_compat(env, seed=args.seed)
    frames = []

    first_frame = env.render()
    if first_frame is not None:
        # nes-py may reuse the same frame buffer; copy to avoid aliasing all frames.
        frames.append(first_frame.copy())

    target_frames = max(1, args.fps * args.duration_sec)
    done = False
    step_count = 0
    reward_sum = 0.0
    episode_count = 1

    while len(frames) < target_frames and step_count < args.max_steps:
        action, _ = model.predict(obs, deterministic=not args.stochastic)
        if isinstance(action, np.ndarray):
            action = int(action.item())
        obs, reward, terminated, truncated, info = step_compat(env, action)
        reward_sum += float(reward)
        step_count += 1
        done = terminated or truncated

        frame = env.render()
        if frame is not None:
            frames.append(frame.copy())

        if done:
            obs, _ = reset_compat(env, seed=args.seed + episode_count)
            episode_count += 1
            frame = env.render()
            if frame is not None and len(frames) < target_frames:
                frames.append(frame.copy())
            done = False

    if not frames:
        raise RuntimeError("No frames captured. Check render_mode support and environment setup.")

    try:
        writer = imageio.get_writer(str(output_path), fps=args.fps, codec="libx264", quality=8)
        for frame in frames:
            writer.append_data(frame)
        writer.close()
        print(f"[video] Saved mp4: {output_path}")
    except Exception as exc:
        frame_dir = save_frames_fallback(frames, output_path)
        print(f"[warn] mp4 write failed: {exc}")
        print(f"[fallback] Saved frame sequence: {frame_dir}")
        print(
            "[fallback] Convert frames to mp4 with: "
            f"ffmpeg -framerate {args.fps} -i {frame_dir}/frame_%06d.png -c:v libx264 -pix_fmt yuv420p {output_path}"
        )

    print(
        f"[stats] frames={len(frames)} approx_sec={len(frames)/max(args.fps,1):.2f} "
        f"steps={step_count} reward_sum={reward_sum:.2f} episodes={episode_count}"
    )
    env.close()


if __name__ == "__main__":
    main()
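`resolve_movement` above auto-detects the action set by matching its size against the model's discrete action count. The selection logic can be sketched without `stable_baselines3`, using a hypothetical `ACTION_SETS` table standing in for `get_action_set` (5- and 7-action lists mirror `RIGHT_ONLY` and `SIMPLE_MOVEMENT` in gym-super-mario-bros):

```python
from typing import Optional

# Hypothetical action tables standing in for the real gym-super-mario-bros lists.
ACTION_SETS = {
    "right_only": [["NOOP"], ["right"], ["right", "A"], ["right", "B"], ["right", "A", "B"]],
    "simple": [["NOOP"], ["right"], ["right", "A"], ["right", "B"], ["right", "A", "B"], ["A"], ["left"]],
}


def resolve_movement(movement_arg: str, model_action_n: Optional[int]) -> str:
    if movement_arg == "auto":
        if model_action_n is not None:
            # Pick the first action set whose size matches the model's head.
            for candidate in ("right_only", "simple"):
                if len(ACTION_SETS[candidate]) == model_action_n:
                    return candidate
        return "right_only"  # conservative default when nothing matches
    if model_action_n is not None and len(ACTION_SETS[movement_arg]) != model_action_n:
        raise ValueError("movement does not match the model's action count")
    return movement_arg


print(resolve_movement("auto", 7))  # → simple
```

Falling back to `right_only` on no match keeps the script usable even when the model exposes no action count at all.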
300
mario-rl-mvp/src/train_ppo.py
Normal file
@@ -0,0 +1,300 @@
from __future__ import annotations

import argparse
import shutil
from datetime import datetime
from pathlib import Path

from stable_baselines3 import PPO
from stable_baselines3.common.callbacks import BaseCallback, CallbackList, CheckpointCallback
from stable_baselines3.common.save_util import load_from_zip_file
from stable_baselines3.common.vec_env import DummyVecEnv, VecMonitor
from torch.utils.tensorboard import SummaryWriter

from src.env import make_env_fn
from src.utils import ensure_artifact_paths, resolve_torch_device, seed_everything, write_latest_pointer


def parse_args() -> argparse.Namespace:
    parser = argparse.ArgumentParser(description="Train PPO agent on NES Super Mario Bros.")
    parser.add_argument("--env-id", type=str, default="SuperMarioBros-1-1-v0")
    parser.add_argument("--movement", type=str, default="right_only", choices=["right_only", "simple"])
    parser.add_argument("--total-timesteps", type=int, default=1_000_000)
    parser.add_argument("--seed", type=int, default=42)
    parser.add_argument("--n-envs", type=int, default=4)
    parser.add_argument("--frame-skip", type=int, default=4)
    parser.add_argument("--reward-mode", type=str, default="raw", choices=["raw", "clip", "progress"])
    parser.add_argument("--progress-scale", type=float, default=0.02)
    parser.add_argument("--death-penalty", type=float, default=-50.0)
    parser.add_argument("--flag-bonus", type=float, default=100.0)
    parser.add_argument("--stall-penalty", type=float, default=0.05)
    parser.add_argument("--stall-steps", type=int, default=40)
    parser.add_argument("--backward-penalty-scale", type=float, default=0.01)
    parser.add_argument("--milestone-interval", type=int, default=32)
    parser.add_argument("--milestone-bonus", type=float, default=1.0)
    parser.add_argument("--no-progress-terminate-steps", type=int, default=120)
    parser.add_argument("--no-progress-terminate-penalty", type=float, default=20.0)

    parser.add_argument("--learning-rate", type=float, default=2.5e-4)
    parser.add_argument("--n-steps", type=int, default=128)
    parser.add_argument("--batch-size", type=int, default=256)
    parser.add_argument("--n-epochs", type=int, default=4)
    parser.add_argument("--gamma", type=float, default=0.99)
    parser.add_argument("--gae-lambda", type=float, default=0.95)
    parser.add_argument("--ent-coef", type=float, default=0.01)
    parser.add_argument("--clip-range", type=float, default=0.2)

    parser.add_argument("--save-freq", type=int, default=50_000, help="Checkpoint frequency in env steps.")
    parser.add_argument("--device", type=str, default="auto", choices=["auto", "cpu", "mps"])
    parser.add_argument("--clip-reward", action="store_true", help="Enable reward clipping to {-1,0,1}.")
    parser.add_argument("--run-name", type=str, default="", help="Optional custom run name.")
    parser.add_argument(
        "--init-model-path",
        type=str,
        default="",
        help="Optional .zip model path to initialize weights before training.",
    )
    parser.add_argument(
        "--allow-partial-init",
        action="store_true",
        help=(
            "Allow partial init when checkpoint is from a different action space "
            "(e.g. right_only -> simple). Compatible policy layers will be loaded; "
            "incompatible layers (like the action head) are skipped."
        ),
    )
    parser.add_argument("--progress-bar", action="store_true", help="Enable SB3 progress bar (requires tqdm/rich).")
    return parser.parse_args()


def load_partial_policy_weights(model: PPO, init_model_path: Path, device: str) -> tuple[int, int]:
    _, params, _ = load_from_zip_file(str(init_model_path), device=device)
    if params is None or "policy" not in params:
        raise RuntimeError(f"No policy parameters found in checkpoint: {init_model_path}")

    src_state = params["policy"]
    dst_state = model.policy.state_dict()

    compatible_state = {}
    skipped_count = 0
    for key, value in src_state.items():
        if key in dst_state and dst_state[key].shape == value.shape:
            compatible_state[key] = value
        else:
            skipped_count += 1

    if not compatible_state:
        raise RuntimeError("No compatible policy tensors found for partial init.")

    model.policy.load_state_dict(compatible_state, strict=False)
    return len(compatible_state), skipped_count


class EpisodeEndLoggingCallback(BaseCallback):
    """Log per-episode terminal diagnostics to stdout and TensorBoard."""

    REASON_TO_CODE = {"death": 0, "no_progress": 1, "timeout": 2, "clear": 3}

    def __init__(self, n_envs: int, tb_log_dir: Path):
        super().__init__(verbose=0)
        self.n_envs = n_envs
        self.tb_log_dir = tb_log_dir
        self.tb_writer: SummaryWriter | None = None
        self.episode_count = 0
        self.episode_max_x_pos = [0.0 for _ in range(max(n_envs, 1))]

    @staticmethod
    def _resolve_done_reason(info: dict) -> str:
        if bool(info.get("flag_get", False)):
            return "clear"
        if bool(info.get("terminated_by_stall", False)):
            return "no_progress"
        if bool(info.get("TimeLimit.truncated", False)):
            return "timeout"
        try:
            if float(info.get("time", 1.0)) <= 0.0:
                return "timeout"
        except (TypeError, ValueError):
            pass
        return "death"

    def _on_training_start(self) -> None:
        self.tb_log_dir.mkdir(parents=True, exist_ok=True)
        self.tb_writer = SummaryWriter(log_dir=str(self.tb_log_dir))

    def _on_step(self) -> bool:
        infos = self.locals.get("infos")
        dones = self.locals.get("dones")
        if infos is None or dones is None:
            return True

        for env_idx, (info, done) in enumerate(zip(infos, dones)):
            if env_idx >= len(self.episode_max_x_pos):
                continue
            x_pos = float(info.get("x_pos", 0.0))
            if x_pos > self.episode_max_x_pos[env_idx]:
                self.episode_max_x_pos[env_idx] = x_pos

            if not bool(done):
                continue

            self.episode_count += 1
            max_x_pos = self.episode_max_x_pos[env_idx]
            flag_get = 1.0 if bool(info.get("flag_get", False)) else 0.0
            done_reason = self._resolve_done_reason(info)
            done_reason_code = float(self.REASON_TO_CODE[done_reason])
            episode_step = self.episode_count

            self.logger.record_mean("rollout/episode_max_x_pos", max_x_pos)
            self.logger.record_mean("rollout/flag_get", flag_get)
            self.logger.record_mean(f"rollout/done_reason_{done_reason}", 1.0)

            if self.tb_writer is not None:
                self.tb_writer.add_scalar("episode_end/episode_max_x_pos", max_x_pos, episode_step)
                self.tb_writer.add_scalar("episode_end/flag_get", flag_get, episode_step)
                self.tb_writer.add_scalar("episode_end/done_reason_code", done_reason_code, episode_step)
                self.tb_writer.add_scalar("episode_end/done_reason_death", 1.0 if done_reason == "death" else 0.0, episode_step)
                self.tb_writer.add_scalar(
                    "episode_end/done_reason_no_progress", 1.0 if done_reason == "no_progress" else 0.0, episode_step
                )
                self.tb_writer.add_scalar("episode_end/done_reason_timeout", 1.0 if done_reason == "timeout" else 0.0, episode_step)
                self.tb_writer.add_scalar("episode_end/done_reason_clear", 1.0 if done_reason == "clear" else 0.0, episode_step)
                self.tb_writer.flush()

            print(
                f"[episode_end] ep={episode_step} env={env_idx} reason={done_reason} "
                f"max_x={max_x_pos:.1f} flag_get={bool(flag_get)}"
            )
            self.episode_max_x_pos[env_idx] = 0.0

        return True

    def _on_training_end(self) -> None:
        if self.tb_writer is not None:
            self.tb_writer.close()
            self.tb_writer = None


def main() -> None:
    args = parse_args()
    seed_everything(args.seed)
    reward_mode = "clip" if args.clip_reward and args.reward_mode == "raw" else args.reward_mode

    paths = ensure_artifact_paths()
    timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
    env_slug = args.env_id.replace("/", "_")
    run_name = args.run_name.strip() or f"ppo_{env_slug}_{timestamp}"

    run_model_dir = paths.models / run_name
    run_log_dir = paths.logs / run_name
    run_tb_dir = run_log_dir / "tb"
    run_model_dir.mkdir(parents=True, exist_ok=True)
    run_tb_dir.mkdir(parents=True, exist_ok=True)

    device, device_msg = resolve_torch_device(args.device)
    print(f"[device] {device} | {device_msg}")
    print(f"[run] {run_name}")
    print(
        f"[config] env_id={args.env_id}, movement={args.movement}, reward_mode={reward_mode}, "
        f"clip_reward={args.clip_reward}, n_envs={args.n_envs}"
    )

    env_fns = [
        make_env_fn(
            env_id=args.env_id,
            seed=args.seed + i,
            movement=args.movement,
            reward_mode=reward_mode,
            clip_reward=args.clip_reward,
            frame_skip=args.frame_skip,
            progress_scale=args.progress_scale,
            death_penalty=args.death_penalty,
            flag_bonus=args.flag_bonus,
            stall_penalty=args.stall_penalty,
            stall_steps=args.stall_steps,
            backward_penalty_scale=args.backward_penalty_scale,
            milestone_interval=args.milestone_interval,
            milestone_bonus=args.milestone_bonus,
            no_progress_terminate_steps=args.no_progress_terminate_steps,
            no_progress_terminate_penalty=args.no_progress_terminate_penalty,
        )
        for i in range(args.n_envs)
    ]
    vec_env = DummyVecEnv(env_fns)
    vec_env = VecMonitor(vec_env, filename=str(run_log_dir / "monitor.csv"))

    model = PPO(
        policy="CnnPolicy",
        env=vec_env,
        learning_rate=args.learning_rate,
        n_steps=args.n_steps,
        batch_size=args.batch_size,
        n_epochs=args.n_epochs,
        gamma=args.gamma,
        gae_lambda=args.gae_lambda,
        ent_coef=args.ent_coef,
        clip_range=args.clip_range,
        tensorboard_log=str(run_tb_dir),
        seed=args.seed,
        verbose=1,
        device=device,
    )

    if args.init_model_path.strip():
        init_model_path = Path(args.init_model_path).expanduser().resolve()
        if not init_model_path.exists():
            raise FileNotFoundError(f"Init model not found: {init_model_path}")
        try:
            model.set_parameters(str(init_model_path), exact_match=False, device=device)
            print(f"[init] Loaded initial weights from: {init_model_path}")
        except RuntimeError as exc:
            if not args.allow_partial_init:
                raise RuntimeError(
                    f"{exc}\n"
                    "Hint: checkpoint and current env may use different action spaces. "
                    "Try one of:\n"
                    "1) keep the same --movement as the checkpoint;\n"
                    "2) remove --init-model-path and train from scratch;\n"
                    "3) add --allow-partial-init to load compatible layers only."
                ) from exc
            loaded_count, skipped_count = load_partial_policy_weights(model, init_model_path, device=device)
            print(
                f"[init] Partial init from {init_model_path}: "
                f"loaded_tensors={loaded_count}, skipped_tensors={skipped_count}"
            )

    checkpoint_callback = CheckpointCallback(
        save_freq=max(args.save_freq // max(args.n_envs, 1), 1),
        save_path=str(run_model_dir),
        name_prefix="ppo_mario_ckpt",
        save_replay_buffer=False,
        save_vecnormalize=False,
    )
    episode_end_logging_callback = EpisodeEndLoggingCallback(
        n_envs=args.n_envs,
        tb_log_dir=run_tb_dir / "episode_end",
    )

    try:
        model.learn(
            total_timesteps=args.total_timesteps,
            callback=CallbackList([checkpoint_callback, episode_end_logging_callback]),
            tb_log_name="ppo",
            progress_bar=args.progress_bar,
        )
    finally:
        vec_env.close()

    final_model = run_model_dir / "final_model"
    model.save(str(final_model))

    latest_model = paths.models / "latest_model.zip"
    shutil.copy2(str(final_model) + ".zip", latest_model)
    pointer = write_latest_pointer(paths.models, latest_model)

    print(f"[done] Final model: {final_model}.zip")
    print(f"[done] Latest model pointer: {pointer}")


if __name__ == "__main__":
    main()
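`load_partial_policy_weights` keeps only tensors whose key exists in the destination `state_dict` and whose shape matches, so a `right_only` (5-action) checkpoint can seed a `simple` (7-action) policy minus the action head. The filtering step can be sketched with plain dicts and tuples standing in for tensors and their `.shape` (all key names and shapes below are hypothetical):

```python
# Hypothetical checkpoint state: shapes stand in for tensor .shape tuples.
src_state = {
    "features_extractor.cnn.0.weight": (32, 4, 8, 8),
    "action_net.weight": (5, 512),  # right_only head: 5 actions
    "action_net.bias": (5,),
}
# Hypothetical destination policy built for a 7-action movement set.
dst_state = {
    "features_extractor.cnn.0.weight": (32, 4, 8, 8),
    "action_net.weight": (7, 512),  # simple head: 7 actions
    "action_net.bias": (7,),
}

# Keep tensors whose key exists in the destination with an identical shape.
compatible = {k: v for k, v in src_state.items() if k in dst_state and dst_state[k] == v}
skipped = len(src_state) - len(compatible)
print(len(compatible), skipped)  # → 1 2
```

Only the shared CNN weight survives; both action-head tensors are skipped, which is exactly why `load_state_dict(..., strict=False)` is needed afterwards in the real function.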
75
mario-rl-mvp/src/utils.py
Normal file
@@ -0,0 +1,75 @@
from __future__ import annotations

import random
from dataclasses import dataclass
from pathlib import Path
from typing import Optional, Tuple

import numpy as np
import torch


@dataclass(frozen=True)
class ArtifactPaths:
    root: Path
    models: Path
    videos: Path
    logs: Path


def project_root() -> Path:
    return Path(__file__).resolve().parents[1]


def ensure_artifact_paths(root: Optional[Path] = None) -> ArtifactPaths:
    root = root or project_root()
    artifacts = root / "artifacts"
    models = artifacts / "models"
    videos = artifacts / "videos"
    logs = artifacts / "logs"
    for p in (artifacts, models, videos, logs):
        p.mkdir(parents=True, exist_ok=True)
    return ArtifactPaths(root=artifacts, models=models, videos=videos, logs=logs)


def seed_everything(seed: int) -> None:
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(seed)


def resolve_torch_device(requested: str = "auto") -> Tuple[str, str]:
    requested = requested.lower().strip()
    if requested == "cpu":
        return "cpu", "User requested CPU."

    if requested not in {"auto", "mps"}:
        return "cpu", f"Unknown device '{requested}', fallback to CPU."

    mps_available = hasattr(torch.backends, "mps") and torch.backends.mps.is_available()
    if not mps_available:
        if requested == "mps":
            return "cpu", "MPS requested but unavailable, fallback to CPU."
        return "cpu", "MPS unavailable, using CPU."

    try:
        x = torch.ones(8, device="mps")
        _ = (x * 2).cpu().numpy()
        return "mps", "MPS is available and passed a quick tensor sanity check."
    except Exception as exc:  # pragma: no cover - hardware dependent
        return "cpu", f"MPS check failed ({exc}), fallback to CPU."


def latest_model_path(models_dir: Path) -> Optional[Path]:
    candidates = [p for p in models_dir.rglob("*.zip") if p.is_file()]
    if not candidates:
        return None
    return max(candidates, key=lambda p: p.stat().st_mtime)


def write_latest_pointer(models_dir: Path, model_path: Path) -> Path:
    pointer = models_dir / "LATEST_MODEL.txt"
    pointer.write_text(str(model_path.resolve()) + "\n", encoding="utf-8")
    return pointer
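`latest_model_path` simply picks the newest `.zip` under the models tree by modification time. A quick self-contained check of that behavior against a temporary directory (file names below are hypothetical):

```python
import os
import tempfile
import time
from pathlib import Path


def latest_model_path(models_dir: Path):
    # Same selection rule as in src/utils.py: newest .zip by mtime, searched recursively.
    candidates = [p for p in models_dir.rglob("*.zip") if p.is_file()]
    if not candidates:
        return None
    return max(candidates, key=lambda p: p.stat().st_mtime)


with tempfile.TemporaryDirectory() as tmp:
    root = Path(tmp)
    (root / "run_a").mkdir()
    old = root / "run_a" / "ckpt_1.zip"
    old.write_bytes(b"")
    new = root / "latest_model.zip"
    new.write_bytes(b"")
    # Force distinct mtimes instead of sleeping between writes.
    past = time.time() - 100
    os.utime(old, (past, past))
    latest = latest_model_path(root)
    print(latest.name)  # → latest_model.zip
```

Because `rglob` recurses into per-run subdirectories, checkpoints and the copied `latest_model.zip` compete on mtime alone, which is what `eval.py` and `record_video.py` rely on when no `--model-path` is given.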