From d23de69b9ad0271b648d816bb90b1e6446f72841 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Roog=20=28=E9=A1=BE=E6=96=B0=E5=9F=B9=29?= Date: Thu, 12 Feb 2026 18:54:06 +0800 Subject: [PATCH] feat: initial mario rl mvp --- .gitignore | 44 ++++ mario-rl-mvp/README.md | 339 +++++++++++++++++++++++++++++++ mario-rl-mvp/requirements.txt | 12 ++ mario-rl-mvp/src/__init__.py | 1 + mario-rl-mvp/src/env.py | 337 ++++++++++++++++++++++++++++++ mario-rl-mvp/src/eval.py | 162 +++++++++++++++ mario-rl-mvp/src/record_video.py | 200 ++++++++++++++++++ mario-rl-mvp/src/train_ppo.py | 300 +++++++++++++++++++++++++++ mario-rl-mvp/src/utils.py | 75 +++++++ 9 files changed, 1470 insertions(+) create mode 100644 .gitignore create mode 100644 mario-rl-mvp/README.md create mode 100644 mario-rl-mvp/requirements.txt create mode 100644 mario-rl-mvp/src/__init__.py create mode 100644 mario-rl-mvp/src/env.py create mode 100644 mario-rl-mvp/src/eval.py create mode 100644 mario-rl-mvp/src/record_video.py create mode 100644 mario-rl-mvp/src/train_ppo.py create mode 100644 mario-rl-mvp/src/utils.py diff --git a/.gitignore b/.gitignore new file mode 100644 index 0000000..38c13d1 --- /dev/null +++ b/.gitignore @@ -0,0 +1,44 @@ +# macOS +.DS_Store + +# Python cache / build +__pycache__/ +*.py[cod] +*.pyo +*.pyd +*.so +*.egg-info/ +.eggs/ +dist/ +build/ + +# Virtual environments +.venv/ +venv/ +env/ +*/.venv/ + +# Logs / test / tooling +*.log +.pytest_cache/ +.mypy_cache/ +.ruff_cache/ +.coverage +htmlcov/ + +# Jupyter +.ipynb_checkpoints/ + +# RL artifacts +artifacts/ +*/artifacts/ + +# TensorBoard / checkpoints / models +runs/ +checkpoints/ +models/ +videos/ + +# IDE +.vscode/ +.idea/ diff --git a/mario-rl-mvp/README.md b/mario-rl-mvp/README.md new file mode 100644 index 0000000..e5d2bc3 --- /dev/null +++ b/mario-rl-mvp/README.md @@ -0,0 +1,339 @@ +# Mario RL MVP (macOS Apple Silicon) + +最小可运行工程:使用像素输入 + 传统 CNN policy(`stable-baselines3` PPO)训练 `gym-super-mario-bros / nes-py` 智能体,不使用大语言模型。 + +## 1. 
项目结构 + +```text +mario-rl-mvp/ + src/ + env.py + train_ppo.py + eval.py + record_video.py + utils.py + artifacts/ + models/ + videos/ + logs/ + requirements.txt + README.md +``` + +## 2. 环境准备(macOS / Apple Silicon) + +建议 Python 3.9+(本机默认 `python3` 即可)。 + +```bash +cd /Users/roog/Work/FNT/SpMr/mario-rl-mvp +python3 -m venv .venv +source .venv/bin/activate +python -m pip install --upgrade pip setuptools wheel +pip install -r requirements.txt +``` + +可选系统依赖(用于 ffmpeg 转码与潜在 SDL 兼容): + +```bash +brew install ffmpeg sdl2 +``` + +## 3. 一条命令开始训练 + +默认 CPU 训练(如果检测到可用且稳定的 MPS,会自动尝试启用,否则自动回退 CPU): + +```bash +python -m src.train_ppo +``` + +常用覆盖参数: + +```bash +python -m src.train_ppo \ + --total-timesteps 1000000 \ + --n-envs 4 \ + --save-freq 50000 \ + --env-id SuperMarioBros-1-1-v0 \ + --movement right_only \ + --seed 42 +``` + +从已有 checkpoint 初始化后继续训练(可同时改超参数,如 `--ent-coef`): + +```bash +python -m src.train_ppo \ + --init-model-path artifacts/models/latest_model.zip \ + --total-timesteps 500000 \ + --ent-coef 0.02 \ + --learning-rate 1e-4 +``` + +如果切换了动作空间(例如 checkpoint 是 `right_only`,当前想用 `simple`),可用部分加载: + +```bash +python -m src.train_ppo \ + --init-model-path artifacts/models/latest_model.zip \ + --allow-partial-init \ + --movement simple \ + --total-timesteps 300000 +``` + +### 3.1 从已有模型继续训练(`--init-model-path`) + +- 用途:加载已有 `.zip` 权重后继续训练,适合“不中断实验目标但调整探索参数”。 +- 常见场景:当前策略陷入局部最优(例如 `approx_kl` 和 `policy_gradient_loss` 长期接近 0),希望从已有模型继续探索。 +- 注意:这不是“热更新”,仍然需要停止当前训练进程后用新命令重启。 + +```bash +python -m src.train_ppo \ + --init-model-path artifacts/models/latest_model.zip \ + --ent-coef 0.02 \ + --learning-rate 1e-4 \ + --total-timesteps 300000 +``` + +训练输出: +- stdout:PPO 训练日志 +- TensorBoard:`artifacts/logs//tb/` +- checkpoint:`artifacts/models//ppo_mario_ckpt_*.zip` +- final model:`artifacts/models//final_model.zip` +- latest 指针:`artifacts/models/latest_model.zip` + `LATEST_MODEL.txt` + +启动 TensorBoard: + +```bash +tensorboard --logdir artifacts/logs --port 6006 +``` + +## 
3.2 训练日志字段速查(PPO)

训练时你会看到类似:

```text
| rollout/ep_len_mean | rollout/ep_rew_mean | ... |
| train/approx_kl | train/entropy_loss | ... |
```

下面是常用字段的含义(对应你贴出来那组):

- `rollout/ep_len_mean`:最近一批 episode 的平均步数。越大不一定越好,要结合 reward 一起看。
- `rollout/ep_rew_mean`:最近一批 episode 的平均回报。通常越高越好。
- `time/fps`:训练吞吐(每秒环境步数),只代表速度,不代表策略质量。
- `time/iterations`:第几次 rollout + update 循环。
- `time/time_elapsed`:训练已运行的秒数。
- `time/total_timesteps`:累计环境交互步数(达到你设定的 `--total-timesteps` 会停止)。
- `train/approx_kl`:新旧策略差异大小。太大说明更新过猛;接近 0 说明几乎没在更新。极小负数通常是数值误差,可当作 0。
- `train/clip_fraction`:有多少样本触发 PPO clipping。长期为 0 且 KL 也接近 0,常见于“策略基本不再更新”。
- `train/clip_range`:PPO 的 clipping 阈值(默认 0.2)。
- `train/entropy_loss`:探索强度指标。绝对值越接近 0,策略越确定、探索越少。
- `train/explained_variance`:价值网络对回报的解释度,越接近 1 越好,接近 0 说明 value 还不稳。
- `train/learning_rate`:优化器步长(参数更新幅度),不是硬件速度。
- `train/loss`:总损失(由多个部分组成),主要看趋势,不单看绝对值。
- `train/policy_gradient_loss`:策略网络的更新信号。长期接近 0 可能表示 actor 更新很弱。
- `train/value_loss`:价值网络误差。过大通常代表 critic 拟合还不稳定。

### 快速判断(实用版)

- `ep_rew_mean` / `avg_max_x_pos` 持续上升:一般在变好。
- `approx_kl≈0` + `clip_fraction=0` + `policy_gradient_loss≈0`:大概率卡住(更新几乎停了)。
- `entropy_loss` 绝对值太小且长期不变:探索不足,可尝试提高 `--ent-coef`。

## 4. 评估模型

加载最新模型,跑 N 个 episode,输出平均指标:

```bash
python -m src.eval --episodes 5
```

可指定模型:

```bash
python -m src.eval --model-path artifacts/models/latest_model.zip --episodes 10
```

注意:`eval.py` 默认 `--movement auto`,会按模型动作维度自动匹配 `right_only/simple`,避免动作空间不一致导致 `KeyError`。

输出指标包括:
- `avg_reward`
- `avg_max_x_pos`
- `clear_rate`(`flag_get=True` 的比例)


查看步数

```bash
unzip -p artifacts/models/latest_model.zip data | rg '"num_timesteps"|"_total_timesteps"|"_tensorboard_log"'
```

num_timesteps = 151552
_total_timesteps = 150000

## 5. 
录制回放视频(无窗口/headless) + +默认录制约 10 秒 mp4 到 `artifacts/videos/`: + +```bash +python -m src.record_video --duration-sec 10 --fps 30 +``` + +可指定输出路径: + +```bash +python -m src.record_video --output artifacts/videos/demo.mp4 --duration-sec 10 +``` + +注意:`record_video.py` 默认 `--movement auto`,会按模型自动匹配动作空间。 + +实现方式: +- 使用 `render_mode=rgb_array`,无需打开窗口 +- 默认通过 `imageio + ffmpeg` 输出 mp4 +- 若 mp4 写入失败,会自动降级保存帧序列(PNG),并打印 ffmpeg 转码命令 + +## 6. 动作空间选择说明 + +默认 `RIGHT_ONLY`,原因: +- 动作更少,探索空间更小,MVP 更快收敛到“向右推进”策略 +- 适合先验证训练闭环 + +可切到 `SIMPLE_MOVEMENT`(动作更丰富): + +```bash +python -m src.train_ppo --movement simple +``` + +## 7. 预处理与奖励 + +默认预处理链路: +- 跳帧:`frame_skip=4` +- 灰度:`RGB -> Gray` +- 缩放:`84x84` +- 帧堆叠:`4` +- 通道布局:`CHW`(兼容 `CnnPolicy`) + +奖励: +- `--reward-mode raw`:原始奖励(默认) +- `--reward-mode clip`:裁剪奖励 `sign(reward)`(等价于旧参数 `--clip-reward`) +- `--reward-mode progress`:奖励塑形模式,额外包含: + - 前进增益奖励(`--progress-scale`) + - 死亡惩罚(`--death-penalty`) + - 通关奖励(`--flag-bonus`) + - 原地卡住惩罚(`--stall-penalty` + `--stall-steps`) + - 后退惩罚(`--backward-penalty-scale`) + +使用裁剪奖励: + +```bash +python -m src.train_ppo --reward-mode clip +``` + +针对“卡在固定位置(如 x=314 撞蘑菇)”的推荐续训命令: + +```bash +python -m src.train_ppo \ + --init-model-path artifacts/models/latest_model.zip \ + --allow-partial-init \ + --reward-mode progress \ + --movement simple \ + --ent-coef 0.04 \ + --learning-rate 1e-4 \ + --n-steps 512 \ + --gamma 0.995 \ + --death-penalty -120 \ + --stall-penalty 0.2 \ + --stall-steps 20 \ + --backward-penalty-scale 0.03 \ + --milestone-bonus 2.0 \ + --no-progress-terminate-steps 80 \ + --no-progress-terminate-penalty 30 \ + --total-timesteps 150000 +``` + +## 8. 
常见问题排查 + +### 8.1 `pip install` 失败(nes-py/gym-super-mario-bros 编译问题) + +先安装工具链并重试: + +```bash +xcode-select --install +brew install cmake swig sdl2 +pip install --upgrade pip setuptools wheel +pip install -r requirements.txt +``` + +### 8.2 MPS 不稳定或报错 + +强制 CPU: + +```bash +python -m src.train_ppo --device cpu +``` + +说明:脚本会先做一次 MPS 张量 sanity check,失败自动回退 CPU。 + +### 8.3 视频写入失败(ffmpeg/codec) + +1) 安装系统 ffmpeg: +```bash +brew install ffmpeg +``` + +2) 已降级保存帧序列时,手动转码: +```bash +ffmpeg -framerate 30 -i artifacts/videos/_frames/frame_%06d.png -c:v libx264 -pix_fmt yuv420p artifacts/videos/.mp4 +``` + +### 8.4 图形窗口相关报错 + +本工程默认 `rgb_array` 录制,不依赖 GUI 窗口。 +若仍遇到 SDL 问题,可显式设置: + +```bash +export SDL_VIDEODRIVER=dummy +python -m src.record_video --duration-sec 10 +``` + +### 8.5 `size mismatch for action_net`(加载旧模型时报错) + +典型原因:旧 checkpoint 的动作空间与当前配置不同(如 `right_only`=5 动作,`simple`=7 动作)。 + +可选修复: + +1) 保持和 checkpoint 一致的动作空间: +```bash +python -m src.train_ppo --init-model-path /Users/roog/Work/FNT/SpMr/mario-rl-mvp/artifacts/models/ppo_SuperMarioBros-1-1-v0_20260212_164717/ppo_mario_ckpt_150000_steps.zip --movement right_only +``` + +2) 若你确实要切动作空间,用部分初始化(跳过不兼容动作头): +```bash +python -m src.train_ppo --init-model-path artifacts/models/latest_model.zip --movement simple --allow-partial-init +``` + +3) 或者直接不加载旧模型,从头训练新动作空间。 + +## 9. 
"""Mario environment construction for src/env.py.

Gym/Gymnasium API shims (4-tuple vs 5-tuple step, seeded reset), pixel
preprocessing wrappers (skip / grayscale / resize / stack / transpose),
optional reward shaping, and factory helpers for vectorized training.
"""

from __future__ import annotations

from collections import deque
from typing import Any, Callable, Deque, Dict, Optional, Tuple

import cv2
import gym
import gym_super_mario_bros
import numpy as np
from gym.spaces import Box
from gym_super_mario_bros.actions import RIGHT_ONLY, SIMPLE_MOVEMENT
from nes_py.wrappers import JoypadSpace


def reset_compat(env: gym.Env, seed: Optional[int] = None) -> Tuple[np.ndarray, Dict[str, Any]]:
    """Reset *env* and always return the Gymnasium-style ``(obs, info)`` pair.

    Works across API generations: new-style ``reset(seed=...)`` is tried first;
    if the env's reset() does not accept ``seed`` (old gym), fall back to a
    plain reset and seed via ``env.seed()`` when available.
    """
    try:
        result = env.reset(seed=seed)
    except TypeError:
        # Old gym API: reset() takes no seed argument.
        # NOTE(review): reconstructed from a mangled patch — the seed fallback
        # is assumed to live inside this except branch; confirm against the repo.
        result = env.reset()
        if seed is not None and hasattr(env, "seed"):
            env.seed(seed)

    if isinstance(result, tuple) and len(result) == 2:
        obs, info = result
        return obs, info
    # Old gym returned obs only; synthesize an empty info dict.
    return result, {}


def step_compat(env: gym.Env, action: Any) -> Tuple[np.ndarray, float, bool, bool, Dict[str, Any]]:
    """Step *env* and normalize the result to the 5-tuple
    ``(obs, reward, terminated, truncated, info)``.

    Old gym 4-tuples map ``done`` -> ``terminated`` with ``truncated=False``.
    Raises ``RuntimeError`` on any other return shape.
    """
    result = env.step(action)
    if isinstance(result, tuple) and len(result) == 5:
        obs, reward, terminated, truncated, info = result
        return obs, float(reward), bool(terminated), bool(truncated), info
    if isinstance(result, tuple) and len(result) == 4:
        obs, reward, done, info = result
        return obs, float(reward), bool(done), False, info
    raise RuntimeError(f"Unexpected step return format: {type(result)} / {result}")


class SkipFrame(gym.Wrapper):
    """Repeat each action *skip* times, summing rewards over the repeated frames."""

    def __init__(self, env: gym.Env, skip: int = 4):
        super().__init__(env)
        self._skip = skip

    def step(self, action: Any):
        # Accumulate reward across the skipped frames; stop early on episode end.
        total_reward = 0.0
        terminated = False
        truncated = False
        info: Dict[str, Any] = {}
        obs = None  # NOTE(review): stays None if skip < 1 — callers rely on skip >= 1.
        for _ in range(self._skip):
            obs, reward, terminated, truncated, info = step_compat(self.env, action)
            total_reward += reward
            if terminated or truncated:
                break
        return obs, total_reward, terminated, truncated, info


class PreprocessFrame(gym.ObservationWrapper):
    """Convert RGB frame to grayscale 84x84 uint8."""

    def __init__(self, env: gym.Env, width: int = 84, height: int = 84):
        super().__init__(env)
        self.width = width
        self.height = height
        # Output is (H, W, 1) uint8 so downstream stacking can concatenate on axis 2.
        self.observation_space = Box(low=0, high=255, shape=(height, width, 1), dtype=np.uint8)

    def observation(self, observation: np.ndarray) -> np.ndarray:
        if observation.ndim == 3 and observation.shape[2] == 3:
            gray = cv2.cvtColor(observation, cv2.COLOR_RGB2GRAY)
        elif observation.ndim == 2:
            gray = observation
        else:
            # Fallback for e.g. (H, W, 1) inputs; squeeze the channel axis.
            gray = np.squeeze(observation)
        resized = cv2.resize(gray, (self.width, self.height), interpolation=cv2.INTER_AREA)
        return resized[:, :, None].astype(np.uint8)


class ChannelLastFrameStack(gym.Wrapper):
    """Stack frames on the channel axis: (H, W, C*num_stack)."""

    def __init__(self, env: gym.Env, num_stack: int = 4):
        super().__init__(env)
        self.num_stack = num_stack
        # Ring buffer of the most recent frames; deque drops the oldest automatically.
        self.frames: Deque[np.ndarray] = deque(maxlen=num_stack)

        obs_space = env.observation_space
        assert isinstance(obs_space, Box), "Frame stack requires Box observation space."
        h, w, c = obs_space.shape
        self.observation_space = Box(
            low=0,
            high=255,
            shape=(h, w, c * num_stack),
            dtype=np.uint8,
        )

    def reset(self, *, seed: Optional[int] = None, options: Optional[dict] = None):
        del options
        obs, info = reset_compat(self.env, seed=seed)
        # Prime the stack by repeating the first frame num_stack times.
        self.frames.clear()
        for _ in range(self.num_stack):
            self.frames.append(obs)
        return self._get_observation(), info

    def step(self, action: Any):
        obs, reward, terminated, truncated, info = step_compat(self.env, action)
        self.frames.append(obs)
        return self._get_observation(), reward, terminated, truncated, info

    def _get_observation(self) -> np.ndarray:
        assert len(self.frames) == self.num_stack
        return np.concatenate(list(self.frames), axis=2)


class TransposeObservation(gym.ObservationWrapper):
    """Convert observation from HWC to CHW for CNN policy."""

    def __init__(self, env: gym.Env):
        super().__init__(env)
        obs_space = env.observation_space
        assert isinstance(obs_space, Box), "TransposeObservation requires Box observation space."
        h, w, c = obs_space.shape
        self.observation_space = Box(low=0, high=255, shape=(c, h, w), dtype=obs_space.dtype)

    def observation(self, observation: np.ndarray) -> np.ndarray:
        return np.transpose(observation, (2, 0, 1)).astype(np.uint8)


class ClipRewardEnv(gym.RewardWrapper):
    """Clip reward to its sign: -1, 0, or +1."""

    def reward(self, reward):
        return float(np.sign(reward))


class ProgressRewardEnv(gym.Wrapper):
    """Reward shaping focused on moving right and avoiding local traps."""

    def __init__(
        self,
        env: gym.Env,
        progress_scale: float = 0.02,
        death_penalty: float = -50.0,
        flag_bonus: float = 100.0,
        stall_penalty: float = 0.05,
        stall_steps: int = 40,
        backward_penalty_scale: float = 0.01,
        milestone_interval: int = 32,
        milestone_bonus: float = 1.0,
        no_progress_terminate_steps: int = 120,
        no_progress_terminate_penalty: float = 20.0,
    ):
        super().__init__(env)
        self.progress_scale = progress_scale
        self.death_penalty = death_penalty
        self.flag_bonus = flag_bonus
        self.stall_penalty = stall_penalty
        self.stall_steps = stall_steps
        self.backward_penalty_scale = backward_penalty_scale
        self.milestone_interval = milestone_interval
        self.milestone_bonus = milestone_bonus
        self.no_progress_terminate_steps = no_progress_terminate_steps
        self.no_progress_terminate_penalty = no_progress_terminate_penalty
        # Shaping state, reset per episode: last/best x positions, stall counter,
        # and the next x coordinate that pays a milestone bonus.
        self._last_x_pos: Optional[float] = None
        self._best_x_pos = 0.0
        self._stall_count = 0
        self._next_milestone_x = float(milestone_interval)

    def reset(self, *, seed: Optional[int] = None, options: Optional[dict] = None):
        del options
        obs, info = reset_compat(self.env, seed=seed)
        # x_pos comes from the gym-super-mario-bros info dict.
        self._last_x_pos = float(info.get("x_pos", 0.0))
        self._best_x_pos = self._last_x_pos
        self._stall_count = 0
        if self.milestone_interval > 0:
            # Next milestone strictly ahead of the starting position.
            k = int(self._best_x_pos // self.milestone_interval) + 1
            self._next_milestone_x = float(k * self.milestone_interval)
        return obs, info

    def step(self, action: Any):
        obs, reward, terminated, truncated, info = step_compat(self.env, action)
        x_pos = float(info.get("x_pos", 0.0))

        if self._last_x_pos is None:
            delta_x = 0.0
        else:
            delta_x = x_pos - self._last_x_pos
        self._last_x_pos = x_pos

        shaped_reward = float(reward)
        if delta_x > 0:
            # Forward progress: scaled bonus, reset stall counter, pay milestones.
            shaped_reward += self.progress_scale * delta_x
            self._stall_count = 0
            if x_pos > self._best_x_pos:
                self._best_x_pos = x_pos
            if self.milestone_interval > 0 and self.milestone_bonus != 0.0:
                # While-loop so a big jump can collect several milestones at once.
                while x_pos >= self._next_milestone_x:
                    shaped_reward += self.milestone_bonus
                    self._next_milestone_x += self.milestone_interval
        else:
            self._stall_count += 1
            if delta_x < 0:
                shaped_reward -= self.backward_penalty_scale * abs(delta_x)

        if self._stall_count >= self.stall_steps:
            shaped_reward -= self.stall_penalty

        if (
            self.no_progress_terminate_steps > 0
            and not terminated
            and not truncated
            and self._stall_count >= self.no_progress_terminate_steps
        ):
            # Cut the episode short when the agent is stuck in place for too long.
            truncated = True
            shaped_reward -= self.no_progress_terminate_penalty
            info["terminated_by_stall"] = True

        if terminated or truncated:
            if bool(info.get("flag_get", False)):
                shaped_reward += self.flag_bonus
            elif terminated:
                # death_penalty is negative, hence "+=".
                shaped_reward += self.death_penalty

        return obs, shaped_reward, terminated, truncated, info


def get_action_set(name: str):
    """Map a movement name ('simple' / 'right_only') to its nes-py action list."""
    name = name.lower().strip()
    if name == "simple":
        return SIMPLE_MOVEMENT
    if name == "right_only":
        return RIGHT_ONLY
    raise ValueError(f"Unsupported movement='{name}'. Use one of: right_only, simple")


def make_mario_env(
    env_id: str = "SuperMarioBros-1-1-v0",
    seed: int = 0,
    movement: str = "right_only",
    reward_mode: str = "raw",
    clip_reward: bool = False,
    frame_skip: int = 4,
    render_mode: Optional[str] = None,
    progress_scale: float = 0.02,
    death_penalty: float = -50.0,
    flag_bonus: float = 100.0,
    stall_penalty: float = 0.05,
    stall_steps: int = 40,
    backward_penalty_scale: float = 0.01,
    milestone_interval: int = 32,
    milestone_bonus: float = 1.0,
    no_progress_terminate_steps: int = 120,
    no_progress_terminate_penalty: float = 20.0,
) -> gym.Env:
    """Build a fully wrapped Mario env: joypad actions, frame skip, reward mode
    (raw / clip / progress), 84x84 grayscale, 4-frame stack, CHW layout.

    ``clip_reward=True`` is a legacy alias for ``reward_mode='clip'`` and only
    applies when reward_mode is 'raw'.
    """
    kwargs: Dict[str, Any] = {}
    if render_mode is not None:
        kwargs["render_mode"] = render_mode

    try:
        # Newer gym-super-mario-bros accepts the Gymnasium compatibility flag.
        env = gym_super_mario_bros.make(env_id, apply_api_compatibility=True, **kwargs)
    except TypeError:
        env = gym_super_mario_bros.make(env_id, **kwargs)

    env = JoypadSpace(env, get_action_set(movement))
    env = SkipFrame(env, skip=frame_skip)

    mode = reward_mode.lower().strip()
    if clip_reward and mode == "raw":
        mode = "clip"
    if mode == "clip":
        env = ClipRewardEnv(env)
    elif mode == "progress":
        env = ProgressRewardEnv(
            env=env,
            progress_scale=progress_scale,
            death_penalty=death_penalty,
            flag_bonus=flag_bonus,
            stall_penalty=stall_penalty,
            stall_steps=stall_steps,
            backward_penalty_scale=backward_penalty_scale,
            milestone_interval=milestone_interval,
            milestone_bonus=milestone_bonus,
            no_progress_terminate_steps=no_progress_terminate_steps,
            no_progress_terminate_penalty=no_progress_terminate_penalty,
        )
    elif mode != "raw":
        raise ValueError(f"Unsupported reward_mode='{reward_mode}'. Use one of: raw, clip, progress")

    env = PreprocessFrame(env, width=84, height=84)
    env = ChannelLastFrameStack(env, num_stack=4)
    env = TransposeObservation(env)

    # Seed once so each env subprocess/dummy env has deterministic startup.
    reset_compat(env, seed=seed)
    if hasattr(env.action_space, "seed"):
        env.action_space.seed(seed)
    return env


def make_env_fn(
    env_id: str,
    seed: int,
    movement: str,
    reward_mode: str,
    clip_reward: bool,
    frame_skip: int,
    progress_scale: float,
    death_penalty: float,
    flag_bonus: float,
    stall_penalty: float,
    stall_steps: int,
    backward_penalty_scale: float,
    milestone_interval: int,
    milestone_bonus: float,
    no_progress_terminate_steps: int,
    no_progress_terminate_penalty: float,
) -> Callable[[], gym.Env]:
    """Return a zero-argument thunk building the env — the shape SB3 vec-envs expect."""

    def _thunk() -> gym.Env:
        return make_mario_env(
            env_id=env_id,
            seed=seed,
            movement=movement,
            reward_mode=reward_mode,
            clip_reward=clip_reward,
            frame_skip=frame_skip,
            render_mode=None,
            progress_scale=progress_scale,
            death_penalty=death_penalty,
            flag_bonus=flag_bonus,
            stall_penalty=stall_penalty,
            stall_steps=stall_steps,
            backward_penalty_scale=backward_penalty_scale,
            milestone_interval=milestone_interval,
            milestone_bonus=milestone_bonus,
            no_progress_terminate_steps=no_progress_terminate_steps,
            no_progress_terminate_penalty=no_progress_terminate_penalty,
        )

    return _thunk
"""Evaluation entry point for src/eval.py.

Loads a trained PPO model, runs N episodes, and prints per-episode and
aggregate metrics (avg_reward, avg_max_x_pos, clear_rate).
"""

from __future__ import annotations

import argparse
import json
from pathlib import Path
from statistics import mean

import numpy as np
from stable_baselines3 import PPO

from src.env import get_action_set, make_mario_env, reset_compat, step_compat
from src.utils import ensure_artifact_paths, latest_model_path, seed_everything


def parse_args() -> argparse.Namespace:
    """CLI arguments; reward-shaping knobs mirror src.env.make_mario_env."""
    parser = argparse.ArgumentParser(description="Evaluate trained Mario PPO agent.")
    parser.add_argument("--model-path", type=str, default="", help="Path to .zip model. If empty, use latest model.")
    parser.add_argument("--env-id", type=str, default="SuperMarioBros-1-1-v0")
    parser.add_argument("--movement", type=str, default="auto", choices=["auto", "right_only", "simple"])
    parser.add_argument("--episodes", type=int, default=5)
    parser.add_argument("--max-steps", type=int, default=5_000)
    parser.add_argument("--seed", type=int, default=42)
    parser.add_argument("--frame-skip", type=int, default=4)
    parser.add_argument("--reward-mode", type=str, default="raw", choices=["raw", "clip", "progress"])
    parser.add_argument("--progress-scale", type=float, default=0.02)
    parser.add_argument("--death-penalty", type=float, default=-50.0)
    parser.add_argument("--flag-bonus", type=float, default=100.0)
    parser.add_argument("--stall-penalty", type=float, default=0.05)
    parser.add_argument("--stall-steps", type=int, default=40)
    parser.add_argument("--backward-penalty-scale", type=float, default=0.01)
    parser.add_argument("--milestone-interval", type=int, default=32)
    parser.add_argument("--milestone-bonus", type=float, default=1.0)
    parser.add_argument("--no-progress-terminate-steps", type=int, default=120)
    parser.add_argument("--no-progress-terminate-penalty", type=float, default=20.0)
    parser.add_argument("--clip-reward", action="store_true")
    parser.add_argument("--stochastic", action="store_true", help="Use stochastic policy (deterministic=False).")
    return parser.parse_args()


def _action_count_for_movement(movement: str) -> int:
    """Number of discrete actions for a movement preset name."""
    return len(get_action_set(movement))


def _model_action_count(model: PPO) -> Optional[int]:
    """Infer the model's discrete action count, or None if undeterminable.

    Prefers the saved action_space; falls back to the policy action head width.
    """
    if hasattr(model, "action_space") and hasattr(model.action_space, "n"):
        return int(model.action_space.n)
    action_net = getattr(getattr(model, "policy", None), "action_net", None)
    out_features = getattr(action_net, "out_features", None)
    if out_features is not None:
        return int(out_features)
    return None


def resolve_movement(movement_arg: str, model: PPO) -> str:
    """Pick a movement preset consistent with the loaded model.

    'auto' matches by action count (falling back to 'right_only'); an explicit
    preset is validated against the model and raises ValueError on mismatch.
    """
    model_action_n = _model_action_count(model)

    if movement_arg == "auto":
        if model_action_n is not None:
            for candidate in ("right_only", "simple"):
                if _action_count_for_movement(candidate) == model_action_n:
                    return candidate
        return "right_only"

    if model_action_n is not None:
        expected_n = _action_count_for_movement(movement_arg)
        if expected_n != model_action_n:
            raise ValueError(
                f"movement='{movement_arg}' has {expected_n} actions, but model expects {model_action_n}. "
                "Use --movement auto or pass the matching movement."
            )
    return movement_arg


def resolve_model_path(user_path: str) -> Path:
    """Resolve the model to evaluate: explicit path, else newest under artifacts/models."""
    if user_path:
        p = Path(user_path).expanduser().resolve()
        if not p.exists():
            raise FileNotFoundError(f"Model not found: {p}")
        return p

    paths = ensure_artifact_paths()
    latest = latest_model_path(paths.models)
    if latest is None:
        raise FileNotFoundError("No model found under artifacts/models. Please run training first.")
    return latest


def main() -> None:
    """Run the evaluation loop and print per-episode lines plus a JSON summary."""
    args = parse_args()
    seed_everything(args.seed)
    # --clip-reward is a legacy alias, honored only when reward-mode is 'raw'.
    reward_mode = "clip" if args.clip_reward and args.reward_mode == "raw" else args.reward_mode

    model_path = resolve_model_path(args.model_path)
    print(f"[eval] model={model_path}")
    model = PPO.load(str(model_path))
    movement = resolve_movement(args.movement, model)
    print(f"[eval] movement={movement} reward_mode={reward_mode}")

    env = make_mario_env(
        env_id=args.env_id,
        seed=args.seed,
        movement=movement,
        reward_mode=reward_mode,
        clip_reward=args.clip_reward,
        frame_skip=args.frame_skip,
        render_mode=None,
        progress_scale=args.progress_scale,
        death_penalty=args.death_penalty,
        flag_bonus=args.flag_bonus,
        stall_penalty=args.stall_penalty,
        stall_steps=args.stall_steps,
        backward_penalty_scale=args.backward_penalty_scale,
        milestone_interval=args.milestone_interval,
        milestone_bonus=args.milestone_bonus,
        no_progress_terminate_steps=args.no_progress_terminate_steps,
        no_progress_terminate_penalty=args.no_progress_terminate_penalty,
    )

    rewards = []
    max_x_positions = []
    clear_flags = []

    for ep in range(1, args.episodes + 1):
        # Different seed per episode so episodes are not identical rollouts.
        obs, info = reset_compat(env, seed=args.seed + ep)
        done = False
        ep_reward = 0.0
        ep_max_x = float(info.get("x_pos", 0.0))
        flag_get = False
        step_count = 0

        while not done and step_count < args.max_steps:
            action, _ = model.predict(obs, deterministic=not args.stochastic)
            if isinstance(action, np.ndarray):
                action = int(action.item())
            obs, reward, terminated, truncated, info = step_compat(env, action)
            ep_reward += float(reward)
            ep_max_x = max(ep_max_x, float(info.get("x_pos", 0.0)))
            flag_get = flag_get or bool(info.get("flag_get", False))
            done = terminated or truncated
            step_count += 1

        rewards.append(ep_reward)
        max_x_positions.append(ep_max_x)
        clear_flags.append(1.0 if flag_get else 0.0)
        print(
            f"[episode {ep}] reward={ep_reward:.2f} max_x={ep_max_x:.1f} "
            f"clear={flag_get} steps={step_count}"
        )

    summary = {
        "episodes": args.episodes,
        "avg_reward": mean(rewards) if rewards else 0.0,
        "avg_max_x_pos": mean(max_x_positions) if max_x_positions else 0.0,
        "clear_rate": mean(clear_flags) if clear_flags else 0.0,
    }
    print("[summary]", json.dumps(summary, ensure_ascii=False, indent=2))

    env.close()


if __name__ == "__main__":
    main()
"""Headless rollout recorder for src/record_video.py.

Loads a trained PPO model, plays with render_mode='rgb_array' (no window),
and writes an mp4 via imageio/ffmpeg, falling back to a PNG frame sequence
when mp4 writing fails.
"""

from __future__ import annotations

import argparse
from datetime import datetime
from pathlib import Path

import imageio.v2 as imageio
import numpy as np
from stable_baselines3 import PPO

from src.env import get_action_set, make_mario_env, reset_compat, step_compat
from src.utils import ensure_artifact_paths, latest_model_path, seed_everything


def parse_args() -> argparse.Namespace:
    """CLI arguments; reward-shaping knobs mirror src.env.make_mario_env."""
    parser = argparse.ArgumentParser(description="Record Mario agent rollout to mp4 (headless).")
    parser.add_argument("--model-path", type=str, default="", help="Path to .zip model. Empty means latest model.")
    parser.add_argument("--env-id", type=str, default="SuperMarioBros-1-1-v0")
    parser.add_argument("--movement", type=str, default="auto", choices=["auto", "right_only", "simple"])
    parser.add_argument("--seed", type=int, default=42)
    parser.add_argument("--frame-skip", type=int, default=4)
    parser.add_argument("--reward-mode", type=str, default="raw", choices=["raw", "clip", "progress"])
    parser.add_argument("--progress-scale", type=float, default=0.02)
    parser.add_argument("--death-penalty", type=float, default=-50.0)
    parser.add_argument("--flag-bonus", type=float, default=100.0)
    parser.add_argument("--stall-penalty", type=float, default=0.05)
    parser.add_argument("--stall-steps", type=int, default=40)
    parser.add_argument("--backward-penalty-scale", type=float, default=0.01)
    parser.add_argument("--milestone-interval", type=int, default=32)
    parser.add_argument("--milestone-bonus", type=float, default=1.0)
    parser.add_argument("--no-progress-terminate-steps", type=int, default=120)
    parser.add_argument("--no-progress-terminate-penalty", type=float, default=20.0)
    parser.add_argument("--clip-reward", action="store_true")
    parser.add_argument("--fps", type=int, default=30)
    parser.add_argument("--duration-sec", type=int, default=10)
    parser.add_argument("--max-steps", type=int, default=5_000)
    parser.add_argument("--stochastic", action="store_true")
    parser.add_argument("--output", type=str, default="", help="Output mp4 path.")
    return parser.parse_args()


def _action_count_for_movement(movement: str) -> int:
    """Number of discrete actions for a movement preset name."""
    return len(get_action_set(movement))


def _model_action_count(model: PPO) -> Optional[int]:
    """Infer the model's discrete action count, or None if undeterminable.

    NOTE(review): duplicated from src/eval.py — consider extracting to a shared module.
    """
    if hasattr(model, "action_space") and hasattr(model.action_space, "n"):
        return int(model.action_space.n)
    action_net = getattr(getattr(model, "policy", None), "action_net", None)
    out_features = getattr(action_net, "out_features", None)
    if out_features is not None:
        return int(out_features)
    return None


def resolve_movement(movement_arg: str, model: PPO) -> str:
    """Pick a movement preset consistent with the loaded model (see src/eval.py twin)."""
    model_action_n = _model_action_count(model)

    if movement_arg == "auto":
        if model_action_n is not None:
            for candidate in ("right_only", "simple"):
                if _action_count_for_movement(candidate) == model_action_n:
                    return candidate
        return "right_only"

    if model_action_n is not None:
        expected_n = _action_count_for_movement(movement_arg)
        if expected_n != model_action_n:
            raise ValueError(
                f"movement='{movement_arg}' has {expected_n} actions, but model expects {model_action_n}. "
                "Use --movement auto or pass the matching movement."
            )
    return movement_arg


def resolve_model(user_path: str) -> Path:
    """Resolve the model to record: explicit path, else newest under artifacts/models."""
    if user_path:
        p = Path(user_path).expanduser().resolve()
        if not p.exists():
            raise FileNotFoundError(f"Model not found: {p}")
        return p

    paths = ensure_artifact_paths()
    latest = latest_model_path(paths.models)
    if latest is None:
        raise FileNotFoundError("No model found under artifacts/models. Please run training first.")
    return latest


def resolve_output_path(user_output: str) -> Path:
    """Output mp4 path: explicit, or a timestamped file under artifacts/videos/."""
    if user_output:
        return Path(user_output).expanduser().resolve()

    paths = ensure_artifact_paths()
    ts = datetime.now().strftime("%Y%m%d_%H%M%S")
    return (paths.videos / f"mario_replay_{ts}.mp4").resolve()


def save_frames_fallback(frames, output_path: Path) -> Path:
    """Dump frames as PNGs into a sibling directory (mp4 path minus suffix); return that dir."""
    frame_dir = output_path.with_suffix("")
    frame_dir.mkdir(parents=True, exist_ok=True)
    for i, frame in enumerate(frames):
        imageio.imwrite(frame_dir / f"frame_{i:06d}.png", frame)
    return frame_dir


def main() -> None:
    """Roll out the policy headlessly, collect rgb frames, and write the video."""
    args = parse_args()
    seed_everything(args.seed)
    # --clip-reward is a legacy alias, honored only when reward-mode is 'raw'.
    reward_mode = "clip" if args.clip_reward and args.reward_mode == "raw" else args.reward_mode

    model_path = resolve_model(args.model_path)
    output_path = resolve_output_path(args.output)
    output_path.parent.mkdir(parents=True, exist_ok=True)

    model = PPO.load(str(model_path))
    movement = resolve_movement(args.movement, model)
    print(f"[video] movement={movement} reward_mode={reward_mode}")
    env = make_mario_env(
        env_id=args.env_id,
        seed=args.seed,
        movement=movement,
        reward_mode=reward_mode,
        clip_reward=args.clip_reward,
        frame_skip=args.frame_skip,
        render_mode="rgb_array",
        progress_scale=args.progress_scale,
        death_penalty=args.death_penalty,
        flag_bonus=args.flag_bonus,
        stall_penalty=args.stall_penalty,
        stall_steps=args.stall_steps,
        backward_penalty_scale=args.backward_penalty_scale,
        milestone_interval=args.milestone_interval,
        milestone_bonus=args.milestone_bonus,
        no_progress_terminate_steps=args.no_progress_terminate_steps,
        no_progress_terminate_penalty=args.no_progress_terminate_penalty,
    )

    obs, _ = reset_compat(env, seed=args.seed)
    frames = []

    first_frame = env.render()
    if first_frame is not None:
        # nes-py may reuse the same frame buffer; copy to avoid aliasing all frames.
        frames.append(first_frame.copy())

    target_frames = max(1, args.fps * args.duration_sec)
    done = False
    step_count = 0
    reward_sum = 0.0
    episode_count = 1

    while len(frames) < target_frames and step_count < args.max_steps:
        action, _ = model.predict(obs, deterministic=not args.stochastic)
        if isinstance(action, np.ndarray):
            action = int(action.item())
        obs, reward, terminated, truncated, info = step_compat(env, action)
        reward_sum += float(reward)
        step_count += 1
        done = terminated or truncated

        frame = env.render()
        if frame is not None:
            frames.append(frame.copy())

        if done:
            # Episode ended before the clip is full: restart and keep recording.
            obs, _ = reset_compat(env, seed=args.seed + episode_count)
            episode_count += 1
            frame = env.render()
            if frame is not None and len(frames) < target_frames:
                frames.append(frame.copy())
            done = False

    if not frames:
        raise RuntimeError("No frames captured. Check render_mode support and environment setup.")

    try:
        writer = imageio.get_writer(str(output_path), fps=args.fps, codec="libx264", quality=8)
        for frame in frames:
            writer.append_data(frame)
        writer.close()
        print(f"[video] Saved mp4: {output_path}")
    except Exception as exc:
        # Best-effort fallback: keep the frames and tell the user how to transcode.
        frame_dir = save_frames_fallback(frames, output_path)
        print(f"[warn] mp4 write failed: {exc}")
        print(f"[fallback] Saved frame sequence: {frame_dir}")
        print(
            "[fallback] Convert frames to mp4 with: "
            f"ffmpeg -framerate {args.fps} -i {frame_dir}/frame_%06d.png -c:v libx264 -pix_fmt yuv420p {output_path}"
        )

    print(
        f"[stats] frames={len(frames)} approx_sec={len(frames)/max(args.fps,1):.2f} "
        f"steps={step_count} reward_sum={reward_sum:.2f} episodes={episode_count}"
    )
    env.close()


if __name__ == "__main__":
    main()
datetime
from pathlib import Path

from stable_baselines3 import PPO
from stable_baselines3.common.callbacks import BaseCallback, CallbackList, CheckpointCallback
from stable_baselines3.common.save_util import load_from_zip_file
from stable_baselines3.common.vec_env import DummyVecEnv, VecMonitor
from torch.utils.tensorboard import SummaryWriter

from src.env import make_env_fn
from src.utils import ensure_artifact_paths, resolve_torch_device, seed_everything, write_latest_pointer


def parse_args() -> argparse.Namespace:
    """Parse CLI options for PPO training on NES Super Mario Bros."""
    parser = argparse.ArgumentParser(description="Train PPO agent on NES Super Mario Bros.")
    # Environment / wrapper configuration.
    parser.add_argument("--env-id", type=str, default="SuperMarioBros-1-1-v0")
    parser.add_argument("--movement", type=str, default="right_only", choices=["right_only", "simple"])
    parser.add_argument("--total-timesteps", type=int, default=1_000_000)
    parser.add_argument("--seed", type=int, default=42)
    parser.add_argument("--n-envs", type=int, default=4)
    parser.add_argument("--frame-skip", type=int, default=4)
    parser.add_argument("--reward-mode", type=str, default="raw", choices=["raw", "clip", "progress"])
    # Reward-shaping knobs, consumed by the wrappers built in src.env.
    parser.add_argument("--progress-scale", type=float, default=0.02)
    parser.add_argument("--death-penalty", type=float, default=-50.0)
    parser.add_argument("--flag-bonus", type=float, default=100.0)
    parser.add_argument("--stall-penalty", type=float, default=0.05)
    parser.add_argument("--stall-steps", type=int, default=40)
    parser.add_argument("--backward-penalty-scale", type=float, default=0.01)
    parser.add_argument("--milestone-interval", type=int, default=32)
    parser.add_argument("--milestone-bonus", type=float, default=1.0)
    parser.add_argument("--no-progress-terminate-steps", type=int, default=120)
    # NOTE(review): positive default while --death-penalty is negative —
    # presumably src.env applies this as a subtraction; confirm the sign there.
    parser.add_argument("--no-progress-terminate-penalty", type=float, default=20.0)

    # PPO hyper-parameters.
    parser.add_argument("--learning-rate", type=float, default=2.5e-4)
    parser.add_argument("--n-steps", type=int, default=128)
    parser.add_argument("--batch-size", type=int, default=256)
    parser.add_argument("--n-epochs", type=int, default=4)
    parser.add_argument("--gamma", type=float, default=0.99)
    parser.add_argument("--gae-lambda", type=float, default=0.95)
    parser.add_argument("--ent-coef", type=float, default=0.01)
    parser.add_argument("--clip-range", type=float, default=0.2)

    # Run / checkpoint / device plumbing.
    parser.add_argument("--save-freq", type=int, default=50_000, help="Checkpoint frequency in env steps.")
    parser.add_argument("--device", type=str, default="auto", choices=["auto", "cpu", "mps"])
    parser.add_argument("--clip-reward", action="store_true", help="Enable reward clipping to {-1,0,1}.")
    parser.add_argument("--run-name", type=str, default="", help="Optional custom run name.")
    parser.add_argument(
        "--init-model-path",
        type=str,
        default="",
        help="Optional .zip model path to initialize weights before training.",
    )
    parser.add_argument(
        "--allow-partial-init",
        action="store_true",
        help=(
            "Allow partial init when checkpoint is from a different action space "
            "(e.g. right_only -> simple). Compatible policy layers will be loaded, "
            "incompatible layers (like action head) are skipped."
        ),
    )
    parser.add_argument("--progress-bar", action="store_true", help="Enable SB3 progress bar (requires tqdm/rich).")
    return parser.parse_args()


def load_partial_policy_weights(model: PPO, init_model_path: Path, device: str) -> tuple[int, int]:
    """Copy shape-compatible policy tensors from a checkpoint into *model*.

    Used when the checkpoint's action space differs from the current one: only
    tensors whose name and shape match are loaded (strict=False), so e.g. the
    action head of a different size is skipped.

    Returns:
        (loaded_count, skipped_count).

    Raises:
        RuntimeError: when the checkpoint has no policy params, or nothing
            in it is compatible with the current policy.
    """
    _, params, _ = load_from_zip_file(str(init_model_path), device=device)
    if params is None or "policy" not in params:
        raise RuntimeError(f"No policy parameters found in checkpoint: {init_model_path}")

    src_state = params["policy"]
    dst_state = model.policy.state_dict()

    compatible_state = {}
    skipped_count = 0
    for key, value in src_state.items():
        if key in dst_state and dst_state[key].shape == value.shape:
            compatible_state[key] = value
        else:
            skipped_count += 1

    if not compatible_state:
        raise RuntimeError("No compatible policy tensors found for partial init.")

    model.policy.load_state_dict(compatible_state, strict=False)
    return len(compatible_state), skipped_count


class EpisodeEndLoggingCallback(BaseCallback):
    """Log per-episode terminal diagnostics to stdout and TensorBoard."""

    # Stable code per terminal reason, so TB can plot a single categorical curve.
    REASON_TO_CODE = {"death": 0, "no_progress": 1, "timeout": 2, "clear": 3}

    def __init__(self, n_envs: int, tb_log_dir: Path):
        super().__init__(verbose=0)
        self.n_envs = n_envs
        self.tb_log_dir = tb_log_dir
        self.tb_writer: SummaryWriter | None = None
        self.episode_count = 0
        # Running max of Mario's x position per parallel env, reset at episode end.
        self.episode_max_x_pos = [0.0 for _ in range(max(n_envs, 1))]

    @staticmethod
    def _resolve_done_reason(info: dict) -> str:
        """Classify why an episode ended, from the env's terminal info dict."""
        if bool(info.get("flag_get", False)):
            return "clear"
        if bool(info.get("terminated_by_stall", False)):
            return "no_progress"
        if bool(info.get("TimeLimit.truncated", False)):
            return "timeout"
        try:
            # NES in-game timer reaching zero also counts as a timeout.
            if float(info.get("time", 1.0)) <= 0.0:
                return "timeout"
        except (TypeError, ValueError):
            pass
        return "death"

    def _on_training_start(self) -> None:
        self.tb_log_dir.mkdir(parents=True, exist_ok=True)
        self.tb_writer = SummaryWriter(log_dir=str(self.tb_log_dir))

    def _on_step(self) -> bool:
        infos = self.locals.get("infos")
        dones = self.locals.get("dones")
        if infos is None or dones is None:
            return True

        for env_idx, (info, done) in enumerate(zip(infos, dones)):
            if env_idx >= len(self.episode_max_x_pos):
                continue
            x_pos = float(info.get("x_pos", 0.0))
            if x_pos > self.episode_max_x_pos[env_idx]:
                self.episode_max_x_pos[env_idx] = x_pos

            if not bool(done):
                continue

            self.episode_count += 1
            max_x_pos = self.episode_max_x_pos[env_idx]
            flag_get = 1.0 if bool(info.get("flag_get", False)) else 0.0
            done_reason = self._resolve_done_reason(info)
            done_reason_code = float(self.REASON_TO_CODE[done_reason])
            episode_step = self.episode_count

            self.logger.record_mean("rollout/episode_max_x_pos", max_x_pos)
            self.logger.record_mean("rollout/flag_get", flag_get)
            # FIX: record 0/1 for every reason each episode so each
            # rollout/done_reason_* mean is the fraction of episodes ending that
            # way. The old single record_mean(f"...{done_reason}", 1.0) only ever
            # fed 1.0 into each key, so every mean was a constant 1.0.
            for reason_name in self.REASON_TO_CODE:
                self.logger.record_mean(
                    f"rollout/done_reason_{reason_name}",
                    1.0 if reason_name == done_reason else 0.0,
                )

            if self.tb_writer is not None:
                self.tb_writer.add_scalar("episode_end/episode_max_x_pos", max_x_pos, episode_step)
                self.tb_writer.add_scalar("episode_end/flag_get", flag_get, episode_step)
                self.tb_writer.add_scalar("episode_end/done_reason_code", done_reason_code, episode_step)
                self.tb_writer.add_scalar("episode_end/done_reason_death", 1.0 if done_reason == "death" else 0.0, episode_step)
                self.tb_writer.add_scalar(
                    "episode_end/done_reason_no_progress", 1.0 if done_reason == "no_progress" else 0.0, episode_step
                )
                self.tb_writer.add_scalar("episode_end/done_reason_timeout", 1.0 if done_reason == "timeout" else 0.0, episode_step)
                self.tb_writer.add_scalar("episode_end/done_reason_clear", 1.0 if done_reason == "clear" else 0.0, episode_step)
                self.tb_writer.flush()

            print(
                f"[episode_end] ep={episode_step} env={env_idx} reason={done_reason} "
                f"max_x={max_x_pos:.1f} flag_get={bool(flag_get)}"
            )
            self.episode_max_x_pos[env_idx] = 0.0

        return True

    def _on_training_end(self) -> None:
        if self.tb_writer is not None:
            self.tb_writer.close()
            self.tb_writer = None


def main() -> None:
    """Build the vectorized Mario envs, train PPO, and save/point to the final model."""
    args = parse_args()
    seed_everything(args.seed)
    # --clip-reward only upgrades the default "raw" mode; explicit --reward-mode wins.
    reward_mode = "clip" if args.clip_reward and args.reward_mode == "raw" else args.reward_mode

    paths = ensure_artifact_paths()
    timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
    env_slug = args.env_id.replace("/", "_")
    run_name = args.run_name.strip() or f"ppo_{env_slug}_{timestamp}"

    run_model_dir = paths.models / run_name
    run_log_dir = paths.logs / run_name
    run_tb_dir = run_log_dir / "tb"
    run_model_dir.mkdir(parents=True, exist_ok=True)
    run_tb_dir.mkdir(parents=True, exist_ok=True)

    device, device_msg = resolve_torch_device(args.device)
    print(f"[device] {device} | {device_msg}")
    print(f"[run] {run_name}")
    print(
        f"[config] env_id={args.env_id}, movement={args.movement}, reward_mode={reward_mode}, "
        f"clip_reward={args.clip_reward}, n_envs={args.n_envs}"
    )

    # One env factory per worker; each gets a distinct seed offset.
    env_fns = [
        make_env_fn(
            env_id=args.env_id,
            seed=args.seed + i,
            movement=args.movement,
            reward_mode=reward_mode,
            clip_reward=args.clip_reward,
            frame_skip=args.frame_skip,
            progress_scale=args.progress_scale,
            death_penalty=args.death_penalty,
            flag_bonus=args.flag_bonus,
            stall_penalty=args.stall_penalty,
            stall_steps=args.stall_steps,
            backward_penalty_scale=args.backward_penalty_scale,
            milestone_interval=args.milestone_interval,
            milestone_bonus=args.milestone_bonus,
            no_progress_terminate_steps=args.no_progress_terminate_steps,
            no_progress_terminate_penalty=args.no_progress_terminate_penalty,
        )
        for i in range(args.n_envs)
    ]
    vec_env = DummyVecEnv(env_fns)
    vec_env = VecMonitor(vec_env, filename=str(run_log_dir / "monitor.csv"))

    model = PPO(
        policy="CnnPolicy",
        env=vec_env,
        learning_rate=args.learning_rate,
        n_steps=args.n_steps,
        batch_size=args.batch_size,
        n_epochs=args.n_epochs,
        gamma=args.gamma,
        gae_lambda=args.gae_lambda,
        ent_coef=args.ent_coef,
        clip_range=args.clip_range,
        tensorboard_log=str(run_tb_dir),
        seed=args.seed,
        verbose=1,
        device=device,
    )

    if args.init_model_path.strip():
        init_model_path = Path(args.init_model_path).expanduser().resolve()
        if not init_model_path.exists():
            raise FileNotFoundError(f"Init model not found: {init_model_path}")
        try:
            # Full weight load; exact_match=False tolerates missing optimizer state.
            model.set_parameters(str(init_model_path), exact_match=False, device=device)
            print(f"[init] Loaded initial weights from: {init_model_path}")
        except RuntimeError as exc:
            if not args.allow_partial_init:
                raise RuntimeError(
                    f"{exc}\n"
                    "Hint: checkpoint and current env may use different action spaces. "
                    "Try one of:\n"
                    "1) keep the same --movement as the checkpoint;\n"
                    "2) remove --init-model-path and train from scratch;\n"
                    "3) add --allow-partial-init to load compatible layers only."
                ) from exc
            loaded_count, skipped_count = load_partial_policy_weights(model, init_model_path, device=device)
            print(
                f"[init] Partial init from {init_model_path}: "
                f"loaded_tensors={loaded_count}, skipped_tensors={skipped_count}"
            )

    callback = CheckpointCallback(
        # CheckpointCallback counts callback invocations, i.e. vec-env steps;
        # divide so --save-freq stays in single-env steps.
        save_freq=max(args.save_freq // max(args.n_envs, 1), 1),
        save_path=str(run_model_dir),
        name_prefix="ppo_mario_ckpt",
        save_replay_buffer=False,
        save_vecnormalize=False,
    )
    episode_end_logging_callback = EpisodeEndLoggingCallback(
        n_envs=args.n_envs,
        tb_log_dir=run_tb_dir / "episode_end",
    )

    try:
        model.learn(
            total_timesteps=args.total_timesteps,
            callback=CallbackList([callback, episode_end_logging_callback]),
            tb_log_name="ppo",
            progress_bar=args.progress_bar,
        )
    finally:
        vec_env.close()

    final_model = run_model_dir / "final_model"
    model.save(str(final_model))

    # Keep a stable "latest" copy + pointer file so eval/video scripts can find it.
    latest_model = paths.models / "latest_model.zip"
    shutil.copy2(str(final_model) + ".zip", latest_model)
    pointer = write_latest_pointer(paths.models, latest_model)

    print(f"[done] Final model: {final_model}.zip")
    print(f"[done] Latest model pointer: {pointer}")


if __name__ == "__main__":
    main()
import random
from dataclasses import dataclass
from pathlib import Path
from typing import Optional, Tuple

import numpy as np
import torch


@dataclass(frozen=True)
class ArtifactPaths:
    """Resolved locations of the run's artifact directories."""

    # root is the artifacts/ directory itself (not the project root).
    root: Path
    models: Path
    videos: Path
    logs: Path


def project_root() -> Path:
    """Return the project directory (one level above src/)."""
    return Path(__file__).resolve().parents[1]


def ensure_artifact_paths(root: Optional[Path] = None) -> ArtifactPaths:
    """Create artifacts/{models,videos,logs} under *root* (default: project root).

    Idempotent: existing directories are left untouched. Returns the resolved
    layout as an ArtifactPaths.
    """
    base = project_root() if root is None else root
    artifacts_dir = base / "artifacts"
    layout = ArtifactPaths(
        root=artifacts_dir,
        models=artifacts_dir / "models",
        videos=artifacts_dir / "videos",
        logs=artifacts_dir / "logs",
    )
    for directory in (layout.root, layout.models, layout.videos, layout.logs):
        directory.mkdir(parents=True, exist_ok=True)
    return layout


def seed_everything(seed: int) -> None:
    """Seed the python, numpy and torch RNGs (plus CUDA when present)."""
    for seeder in (random.seed, np.random.seed, torch.manual_seed):
        seeder(seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(seed)


def resolve_torch_device(requested: str = "auto") -> Tuple[str, str]:
    """Map a requested device name onto a usable torch device.

    Returns a ``(device, reason)`` pair. MPS is only selected after a tiny
    tensor round-trip succeeds; anything unknown or unavailable falls back
    to CPU with an explanatory message.
    """
    choice = requested.strip().lower()
    if choice == "cpu":
        return "cpu", "User requested CPU."
    if choice not in ("auto", "mps"):
        return "cpu", f"Unknown device '{choice}', fallback to CPU."

    backend = getattr(torch.backends, "mps", None)
    if backend is None or not backend.is_available():
        if choice == "mps":
            return "cpu", "MPS requested but unavailable, fallback to CPU."
        return "cpu", "MPS unavailable, using CPU."

    try:
        probe = torch.ones(8, device="mps")
        _ = (probe * 2).cpu().numpy()
    except Exception as exc:  # pragma: no cover - hardware dependent
        return "cpu", f"MPS check failed ({exc}), fallback to CPU."
    return "mps", "MPS is available and passed a quick tensor sanity check."


def latest_model_path(models_dir: Path) -> Optional[Path]:
    """Return the most recently modified ``*.zip`` under *models_dir* (recursive).

    Returns None when no zip file exists.
    """
    newest: Optional[Path] = None
    newest_mtime = float("-inf")
    for candidate in models_dir.rglob("*.zip"):
        if not candidate.is_file():
            continue
        mtime = candidate.stat().st_mtime
        if mtime > newest_mtime:
            newest, newest_mtime = candidate, mtime
    return newest


def write_latest_pointer(models_dir: Path, model_path: Path) -> Path:
    """Record the absolute path of *model_path* in LATEST_MODEL.txt; return the pointer file."""
    pointer_file = models_dir / "LATEST_MODEL.txt"
    pointer_file.write_text(f"{model_path.resolve()}\n", encoding="utf-8")
    return pointer_file