feat: initial mario rl mvp

This commit is contained in:
2026-02-12 18:54:06 +08:00
commit d23de69b9a
9 changed files with 1470 additions and 0 deletions

View File

@@ -0,0 +1 @@
"""Mario RL MVP package."""

337
mario-rl-mvp/src/env.py Normal file
View File

@@ -0,0 +1,337 @@
from __future__ import annotations
from collections import deque
from typing import Any, Callable, Deque, Dict, Optional, Tuple
import cv2
import gym
import gym_super_mario_bros
import numpy as np
from gym.spaces import Box
from gym_super_mario_bros.actions import RIGHT_ONLY, SIMPLE_MOVEMENT
from nes_py.wrappers import JoypadSpace
def reset_compat(env: gym.Env, seed: Optional[int] = None) -> Tuple[np.ndarray, Dict[str, Any]]:
try:
result = env.reset(seed=seed)
except TypeError:
result = env.reset()
if seed is not None and hasattr(env, "seed"):
env.seed(seed)
if isinstance(result, tuple) and len(result) == 2:
obs, info = result
return obs, info
return result, {}
def step_compat(env: gym.Env, action: Any) -> Tuple[np.ndarray, float, bool, bool, Dict[str, Any]]:
result = env.step(action)
if isinstance(result, tuple) and len(result) == 5:
obs, reward, terminated, truncated, info = result
return obs, float(reward), bool(terminated), bool(truncated), info
if isinstance(result, tuple) and len(result) == 4:
obs, reward, done, info = result
return obs, float(reward), bool(done), False, info
raise RuntimeError(f"Unexpected step return format: {type(result)} / {result}")
class SkipFrame(gym.Wrapper):
def __init__(self, env: gym.Env, skip: int = 4):
super().__init__(env)
self._skip = skip
def step(self, action: Any):
total_reward = 0.0
terminated = False
truncated = False
info: Dict[str, Any] = {}
obs = None
for _ in range(self._skip):
obs, reward, terminated, truncated, info = step_compat(self.env, action)
total_reward += reward
if terminated or truncated:
break
return obs, total_reward, terminated, truncated, info
class PreprocessFrame(gym.ObservationWrapper):
"""Convert RGB frame to grayscale 84x84 uint8."""
def __init__(self, env: gym.Env, width: int = 84, height: int = 84):
super().__init__(env)
self.width = width
self.height = height
self.observation_space = Box(low=0, high=255, shape=(height, width, 1), dtype=np.uint8)
def observation(self, observation: np.ndarray) -> np.ndarray:
if observation.ndim == 3 and observation.shape[2] == 3:
gray = cv2.cvtColor(observation, cv2.COLOR_RGB2GRAY)
elif observation.ndim == 2:
gray = observation
else:
gray = np.squeeze(observation)
resized = cv2.resize(gray, (self.width, self.height), interpolation=cv2.INTER_AREA)
return resized[:, :, None].astype(np.uint8)
class ChannelLastFrameStack(gym.Wrapper):
"""Stack frames on the channel axis: (H, W, C*num_stack)."""
def __init__(self, env: gym.Env, num_stack: int = 4):
super().__init__(env)
self.num_stack = num_stack
self.frames: Deque[np.ndarray] = deque(maxlen=num_stack)
obs_space = env.observation_space
assert isinstance(obs_space, Box), "Frame stack requires Box observation space."
h, w, c = obs_space.shape
self.observation_space = Box(
low=0,
high=255,
shape=(h, w, c * num_stack),
dtype=np.uint8,
)
def reset(self, *, seed: Optional[int] = None, options: Optional[dict] = None):
del options
obs, info = reset_compat(self.env, seed=seed)
self.frames.clear()
for _ in range(self.num_stack):
self.frames.append(obs)
return self._get_observation(), info
def step(self, action: Any):
obs, reward, terminated, truncated, info = step_compat(self.env, action)
self.frames.append(obs)
return self._get_observation(), reward, terminated, truncated, info
def _get_observation(self) -> np.ndarray:
assert len(self.frames) == self.num_stack
return np.concatenate(list(self.frames), axis=2)
class TransposeObservation(gym.ObservationWrapper):
"""Convert observation from HWC to CHW for CNN policy."""
def __init__(self, env: gym.Env):
super().__init__(env)
obs_space = env.observation_space
assert isinstance(obs_space, Box), "TransposeObservation requires Box observation space."
h, w, c = obs_space.shape
self.observation_space = Box(low=0, high=255, shape=(c, h, w), dtype=obs_space.dtype)
def observation(self, observation: np.ndarray) -> np.ndarray:
return np.transpose(observation, (2, 0, 1)).astype(np.uint8)
class ClipRewardEnv(gym.RewardWrapper):
def reward(self, reward):
return float(np.sign(reward))
class ProgressRewardEnv(gym.Wrapper):
"""Reward shaping focused on moving right and avoiding local traps."""
def __init__(
self,
env: gym.Env,
progress_scale: float = 0.02,
death_penalty: float = -50.0,
flag_bonus: float = 100.0,
stall_penalty: float = 0.05,
stall_steps: int = 40,
backward_penalty_scale: float = 0.01,
milestone_interval: int = 32,
milestone_bonus: float = 1.0,
no_progress_terminate_steps: int = 120,
no_progress_terminate_penalty: float = 20.0,
):
super().__init__(env)
self.progress_scale = progress_scale
self.death_penalty = death_penalty
self.flag_bonus = flag_bonus
self.stall_penalty = stall_penalty
self.stall_steps = stall_steps
self.backward_penalty_scale = backward_penalty_scale
self.milestone_interval = milestone_interval
self.milestone_bonus = milestone_bonus
self.no_progress_terminate_steps = no_progress_terminate_steps
self.no_progress_terminate_penalty = no_progress_terminate_penalty
self._last_x_pos: Optional[float] = None
self._best_x_pos = 0.0
self._stall_count = 0
self._next_milestone_x = float(milestone_interval)
def reset(self, *, seed: Optional[int] = None, options: Optional[dict] = None):
del options
obs, info = reset_compat(self.env, seed=seed)
self._last_x_pos = float(info.get("x_pos", 0.0))
self._best_x_pos = self._last_x_pos
self._stall_count = 0
if self.milestone_interval > 0:
k = int(self._best_x_pos // self.milestone_interval) + 1
self._next_milestone_x = float(k * self.milestone_interval)
return obs, info
def step(self, action: Any):
obs, reward, terminated, truncated, info = step_compat(self.env, action)
x_pos = float(info.get("x_pos", 0.0))
if self._last_x_pos is None:
delta_x = 0.0
else:
delta_x = x_pos - self._last_x_pos
self._last_x_pos = x_pos
shaped_reward = float(reward)
if delta_x > 0:
shaped_reward += self.progress_scale * delta_x
self._stall_count = 0
if x_pos > self._best_x_pos:
self._best_x_pos = x_pos
if self.milestone_interval > 0 and self.milestone_bonus != 0.0:
while x_pos >= self._next_milestone_x:
shaped_reward += self.milestone_bonus
self._next_milestone_x += self.milestone_interval
else:
self._stall_count += 1
if delta_x < 0:
shaped_reward -= self.backward_penalty_scale * abs(delta_x)
if self._stall_count >= self.stall_steps:
shaped_reward -= self.stall_penalty
if (
self.no_progress_terminate_steps > 0
and not terminated
and not truncated
and self._stall_count >= self.no_progress_terminate_steps
):
truncated = True
shaped_reward -= self.no_progress_terminate_penalty
info["terminated_by_stall"] = True
if terminated or truncated:
if bool(info.get("flag_get", False)):
shaped_reward += self.flag_bonus
elif terminated:
shaped_reward += self.death_penalty
return obs, shaped_reward, terminated, truncated, info
def get_action_set(name: str):
name = name.lower().strip()
if name == "simple":
return SIMPLE_MOVEMENT
if name == "right_only":
return RIGHT_ONLY
raise ValueError(f"Unsupported movement='{name}'. Use one of: right_only, simple")
def make_mario_env(
env_id: str = "SuperMarioBros-1-1-v0",
seed: int = 0,
movement: str = "right_only",
reward_mode: str = "raw",
clip_reward: bool = False,
frame_skip: int = 4,
render_mode: Optional[str] = None,
progress_scale: float = 0.02,
death_penalty: float = -50.0,
flag_bonus: float = 100.0,
stall_penalty: float = 0.05,
stall_steps: int = 40,
backward_penalty_scale: float = 0.01,
milestone_interval: int = 32,
milestone_bonus: float = 1.0,
no_progress_terminate_steps: int = 120,
no_progress_terminate_penalty: float = 20.0,
) -> gym.Env:
kwargs: Dict[str, Any] = {}
if render_mode is not None:
kwargs["render_mode"] = render_mode
try:
env = gym_super_mario_bros.make(env_id, apply_api_compatibility=True, **kwargs)
except TypeError:
env = gym_super_mario_bros.make(env_id, **kwargs)
env = JoypadSpace(env, get_action_set(movement))
env = SkipFrame(env, skip=frame_skip)
mode = reward_mode.lower().strip()
if clip_reward and mode == "raw":
mode = "clip"
if mode == "clip":
env = ClipRewardEnv(env)
elif mode == "progress":
env = ProgressRewardEnv(
env=env,
progress_scale=progress_scale,
death_penalty=death_penalty,
flag_bonus=flag_bonus,
stall_penalty=stall_penalty,
stall_steps=stall_steps,
backward_penalty_scale=backward_penalty_scale,
milestone_interval=milestone_interval,
milestone_bonus=milestone_bonus,
no_progress_terminate_steps=no_progress_terminate_steps,
no_progress_terminate_penalty=no_progress_terminate_penalty,
)
elif mode != "raw":
raise ValueError(f"Unsupported reward_mode='{reward_mode}'. Use one of: raw, clip, progress")
env = PreprocessFrame(env, width=84, height=84)
env = ChannelLastFrameStack(env, num_stack=4)
env = TransposeObservation(env)
# Seed once so each env subprocess/dummy env has deterministic startup.
reset_compat(env, seed=seed)
if hasattr(env.action_space, "seed"):
env.action_space.seed(seed)
return env
def make_env_fn(
env_id: str,
seed: int,
movement: str,
reward_mode: str,
clip_reward: bool,
frame_skip: int,
progress_scale: float,
death_penalty: float,
flag_bonus: float,
stall_penalty: float,
stall_steps: int,
backward_penalty_scale: float,
milestone_interval: int,
milestone_bonus: float,
no_progress_terminate_steps: int,
no_progress_terminate_penalty: float,
) -> Callable[[], gym.Env]:
def _thunk() -> gym.Env:
return make_mario_env(
env_id=env_id,
seed=seed,
movement=movement,
reward_mode=reward_mode,
clip_reward=clip_reward,
frame_skip=frame_skip,
render_mode=None,
progress_scale=progress_scale,
death_penalty=death_penalty,
flag_bonus=flag_bonus,
stall_penalty=stall_penalty,
stall_steps=stall_steps,
backward_penalty_scale=backward_penalty_scale,
milestone_interval=milestone_interval,
milestone_bonus=milestone_bonus,
no_progress_terminate_steps=no_progress_terminate_steps,
no_progress_terminate_penalty=no_progress_terminate_penalty,
)
return _thunk

162
mario-rl-mvp/src/eval.py Normal file
View File

@@ -0,0 +1,162 @@
from __future__ import annotations
import argparse
import json
from pathlib import Path
from statistics import mean
import numpy as np
from stable_baselines3 import PPO
from src.env import get_action_set, make_mario_env, reset_compat, step_compat
from src.utils import ensure_artifact_paths, latest_model_path, seed_everything
def parse_args() -> argparse.Namespace:
parser = argparse.ArgumentParser(description="Evaluate trained Mario PPO agent.")
parser.add_argument("--model-path", type=str, default="", help="Path to .zip model. If empty, use latest model.")
parser.add_argument("--env-id", type=str, default="SuperMarioBros-1-1-v0")
parser.add_argument("--movement", type=str, default="auto", choices=["auto", "right_only", "simple"])
parser.add_argument("--episodes", type=int, default=5)
parser.add_argument("--max-steps", type=int, default=5_000)
parser.add_argument("--seed", type=int, default=42)
parser.add_argument("--frame-skip", type=int, default=4)
parser.add_argument("--reward-mode", type=str, default="raw", choices=["raw", "clip", "progress"])
parser.add_argument("--progress-scale", type=float, default=0.02)
parser.add_argument("--death-penalty", type=float, default=-50.0)
parser.add_argument("--flag-bonus", type=float, default=100.0)
parser.add_argument("--stall-penalty", type=float, default=0.05)
parser.add_argument("--stall-steps", type=int, default=40)
parser.add_argument("--backward-penalty-scale", type=float, default=0.01)
parser.add_argument("--milestone-interval", type=int, default=32)
parser.add_argument("--milestone-bonus", type=float, default=1.0)
parser.add_argument("--no-progress-terminate-steps", type=int, default=120)
parser.add_argument("--no-progress-terminate-penalty", type=float, default=20.0)
parser.add_argument("--clip-reward", action="store_true")
parser.add_argument("--stochastic", action="store_true", help="Use stochastic policy (deterministic=False).")
return parser.parse_args()
def _action_count_for_movement(movement: str) -> int:
return len(get_action_set(movement))
def _model_action_count(model: PPO):
if hasattr(model, "action_space") and hasattr(model.action_space, "n"):
return int(model.action_space.n)
action_net = getattr(getattr(model, "policy", None), "action_net", None)
out_features = getattr(action_net, "out_features", None)
if out_features is not None:
return int(out_features)
return None
def resolve_movement(movement_arg: str, model: PPO) -> str:
model_action_n = _model_action_count(model)
if movement_arg == "auto":
if model_action_n is not None:
for candidate in ("right_only", "simple"):
if _action_count_for_movement(candidate) == model_action_n:
return candidate
return "right_only"
if model_action_n is not None:
expected_n = _action_count_for_movement(movement_arg)
if expected_n != model_action_n:
raise ValueError(
f"movement='{movement_arg}' has {expected_n} actions, but model expects {model_action_n}. "
"Use --movement auto or pass the matching movement."
)
return movement_arg
def resolve_model_path(user_path: str) -> Path:
if user_path:
p = Path(user_path).expanduser().resolve()
if not p.exists():
raise FileNotFoundError(f"Model not found: {p}")
return p
paths = ensure_artifact_paths()
latest = latest_model_path(paths.models)
if latest is None:
raise FileNotFoundError("No model found under artifacts/models. Please run training first.")
return latest
def main() -> None:
args = parse_args()
seed_everything(args.seed)
reward_mode = "clip" if args.clip_reward and args.reward_mode == "raw" else args.reward_mode
model_path = resolve_model_path(args.model_path)
print(f"[eval] model={model_path}")
model = PPO.load(str(model_path))
movement = resolve_movement(args.movement, model)
print(f"[eval] movement={movement} reward_mode={reward_mode}")
env = make_mario_env(
env_id=args.env_id,
seed=args.seed,
movement=movement,
reward_mode=reward_mode,
clip_reward=args.clip_reward,
frame_skip=args.frame_skip,
render_mode=None,
progress_scale=args.progress_scale,
death_penalty=args.death_penalty,
flag_bonus=args.flag_bonus,
stall_penalty=args.stall_penalty,
stall_steps=args.stall_steps,
backward_penalty_scale=args.backward_penalty_scale,
milestone_interval=args.milestone_interval,
milestone_bonus=args.milestone_bonus,
no_progress_terminate_steps=args.no_progress_terminate_steps,
no_progress_terminate_penalty=args.no_progress_terminate_penalty,
)
rewards = []
max_x_positions = []
clear_flags = []
for ep in range(1, args.episodes + 1):
obs, info = reset_compat(env, seed=args.seed + ep)
done = False
ep_reward = 0.0
ep_max_x = float(info.get("x_pos", 0.0))
flag_get = False
step_count = 0
while not done and step_count < args.max_steps:
action, _ = model.predict(obs, deterministic=not args.stochastic)
if isinstance(action, np.ndarray):
action = int(action.item())
obs, reward, terminated, truncated, info = step_compat(env, action)
ep_reward += float(reward)
ep_max_x = max(ep_max_x, float(info.get("x_pos", 0.0)))
flag_get = flag_get or bool(info.get("flag_get", False))
done = terminated or truncated
step_count += 1
rewards.append(ep_reward)
max_x_positions.append(ep_max_x)
clear_flags.append(1.0 if flag_get else 0.0)
print(
f"[episode {ep}] reward={ep_reward:.2f} max_x={ep_max_x:.1f} "
f"clear={flag_get} steps={step_count}"
)
summary = {
"episodes": args.episodes,
"avg_reward": mean(rewards) if rewards else 0.0,
"avg_max_x_pos": mean(max_x_positions) if max_x_positions else 0.0,
"clear_rate": mean(clear_flags) if clear_flags else 0.0,
}
print("[summary]", json.dumps(summary, ensure_ascii=False, indent=2))
env.close()
if __name__ == "__main__":
main()

View File

@@ -0,0 +1,200 @@
from __future__ import annotations
import argparse
from datetime import datetime
from pathlib import Path
import imageio.v2 as imageio
import numpy as np
from stable_baselines3 import PPO
from src.env import get_action_set, make_mario_env, reset_compat, step_compat
from src.utils import ensure_artifact_paths, latest_model_path, seed_everything
def parse_args() -> argparse.Namespace:
parser = argparse.ArgumentParser(description="Record Mario agent rollout to mp4 (headless).")
parser.add_argument("--model-path", type=str, default="", help="Path to .zip model. Empty means latest model.")
parser.add_argument("--env-id", type=str, default="SuperMarioBros-1-1-v0")
parser.add_argument("--movement", type=str, default="auto", choices=["auto", "right_only", "simple"])
parser.add_argument("--seed", type=int, default=42)
parser.add_argument("--frame-skip", type=int, default=4)
parser.add_argument("--reward-mode", type=str, default="raw", choices=["raw", "clip", "progress"])
parser.add_argument("--progress-scale", type=float, default=0.02)
parser.add_argument("--death-penalty", type=float, default=-50.0)
parser.add_argument("--flag-bonus", type=float, default=100.0)
parser.add_argument("--stall-penalty", type=float, default=0.05)
parser.add_argument("--stall-steps", type=int, default=40)
parser.add_argument("--backward-penalty-scale", type=float, default=0.01)
parser.add_argument("--milestone-interval", type=int, default=32)
parser.add_argument("--milestone-bonus", type=float, default=1.0)
parser.add_argument("--no-progress-terminate-steps", type=int, default=120)
parser.add_argument("--no-progress-terminate-penalty", type=float, default=20.0)
parser.add_argument("--clip-reward", action="store_true")
parser.add_argument("--fps", type=int, default=30)
parser.add_argument("--duration-sec", type=int, default=10)
parser.add_argument("--max-steps", type=int, default=5_000)
parser.add_argument("--stochastic", action="store_true")
parser.add_argument("--output", type=str, default="", help="Output mp4 path.")
return parser.parse_args()
def _action_count_for_movement(movement: str) -> int:
return len(get_action_set(movement))
def _model_action_count(model: PPO):
if hasattr(model, "action_space") and hasattr(model.action_space, "n"):
return int(model.action_space.n)
action_net = getattr(getattr(model, "policy", None), "action_net", None)
out_features = getattr(action_net, "out_features", None)
if out_features is not None:
return int(out_features)
return None
def resolve_movement(movement_arg: str, model: PPO) -> str:
model_action_n = _model_action_count(model)
if movement_arg == "auto":
if model_action_n is not None:
for candidate in ("right_only", "simple"):
if _action_count_for_movement(candidate) == model_action_n:
return candidate
return "right_only"
if model_action_n is not None:
expected_n = _action_count_for_movement(movement_arg)
if expected_n != model_action_n:
raise ValueError(
f"movement='{movement_arg}' has {expected_n} actions, but model expects {model_action_n}. "
"Use --movement auto or pass the matching movement."
)
return movement_arg
def resolve_model(user_path: str) -> Path:
if user_path:
p = Path(user_path).expanduser().resolve()
if not p.exists():
raise FileNotFoundError(f"Model not found: {p}")
return p
paths = ensure_artifact_paths()
latest = latest_model_path(paths.models)
if latest is None:
raise FileNotFoundError("No model found under artifacts/models. Please run training first.")
return latest
def resolve_output_path(user_output: str) -> Path:
if user_output:
return Path(user_output).expanduser().resolve()
paths = ensure_artifact_paths()
ts = datetime.now().strftime("%Y%m%d_%H%M%S")
return (paths.videos / f"mario_replay_{ts}.mp4").resolve()
def save_frames_fallback(frames, output_path: Path) -> Path:
frame_dir = output_path.with_suffix("")
frame_dir.mkdir(parents=True, exist_ok=True)
for i, frame in enumerate(frames):
imageio.imwrite(frame_dir / f"frame_{i:06d}.png", frame)
return frame_dir
def main() -> None:
args = parse_args()
seed_everything(args.seed)
reward_mode = "clip" if args.clip_reward and args.reward_mode == "raw" else args.reward_mode
model_path = resolve_model(args.model_path)
output_path = resolve_output_path(args.output)
output_path.parent.mkdir(parents=True, exist_ok=True)
model = PPO.load(str(model_path))
movement = resolve_movement(args.movement, model)
print(f"[video] movement={movement} reward_mode={reward_mode}")
env = make_mario_env(
env_id=args.env_id,
seed=args.seed,
movement=movement,
reward_mode=reward_mode,
clip_reward=args.clip_reward,
frame_skip=args.frame_skip,
render_mode="rgb_array",
progress_scale=args.progress_scale,
death_penalty=args.death_penalty,
flag_bonus=args.flag_bonus,
stall_penalty=args.stall_penalty,
stall_steps=args.stall_steps,
backward_penalty_scale=args.backward_penalty_scale,
milestone_interval=args.milestone_interval,
milestone_bonus=args.milestone_bonus,
no_progress_terminate_steps=args.no_progress_terminate_steps,
no_progress_terminate_penalty=args.no_progress_terminate_penalty,
)
obs, _ = reset_compat(env, seed=args.seed)
frames = []
first_frame = env.render()
if first_frame is not None:
# nes-py may reuse the same frame buffer; copy to avoid aliasing all frames.
frames.append(first_frame.copy())
target_frames = max(1, args.fps * args.duration_sec)
done = False
step_count = 0
reward_sum = 0.0
episode_count = 1
while len(frames) < target_frames and step_count < args.max_steps:
action, _ = model.predict(obs, deterministic=not args.stochastic)
if isinstance(action, np.ndarray):
action = int(action.item())
obs, reward, terminated, truncated, info = step_compat(env, action)
reward_sum += float(reward)
step_count += 1
done = terminated or truncated
frame = env.render()
if frame is not None:
frames.append(frame.copy())
if done:
obs, _ = reset_compat(env, seed=args.seed + episode_count)
episode_count += 1
frame = env.render()
if frame is not None and len(frames) < target_frames:
frames.append(frame.copy())
done = False
if not frames:
raise RuntimeError("No frames captured. Check render_mode support and environment setup.")
try:
writer = imageio.get_writer(str(output_path), fps=args.fps, codec="libx264", quality=8)
for frame in frames:
writer.append_data(frame)
writer.close()
print(f"[video] Saved mp4: {output_path}")
except Exception as exc:
frame_dir = save_frames_fallback(frames, output_path)
print(f"[warn] mp4 write failed: {exc}")
print(f"[fallback] Saved frame sequence: {frame_dir}")
print(
"[fallback] Convert frames to mp4 with: "
f"ffmpeg -framerate {args.fps} -i {frame_dir}/frame_%06d.png -c:v libx264 -pix_fmt yuv420p {output_path}"
)
print(
f"[stats] frames={len(frames)} approx_sec={len(frames)/max(args.fps,1):.2f} "
f"steps={step_count} reward_sum={reward_sum:.2f} episodes={episode_count}"
)
env.close()
if __name__ == "__main__":
main()

View File

@@ -0,0 +1,300 @@
from __future__ import annotations
import argparse
import shutil
from datetime import datetime
from pathlib import Path
from stable_baselines3 import PPO
from stable_baselines3.common.callbacks import BaseCallback, CallbackList, CheckpointCallback
from stable_baselines3.common.save_util import load_from_zip_file
from stable_baselines3.common.vec_env import DummyVecEnv, VecMonitor
from torch.utils.tensorboard import SummaryWriter
from src.env import make_env_fn
from src.utils import ensure_artifact_paths, resolve_torch_device, seed_everything, write_latest_pointer
def parse_args() -> argparse.Namespace:
parser = argparse.ArgumentParser(description="Train PPO agent on NES Super Mario Bros.")
parser.add_argument("--env-id", type=str, default="SuperMarioBros-1-1-v0")
parser.add_argument("--movement", type=str, default="right_only", choices=["right_only", "simple"])
parser.add_argument("--total-timesteps", type=int, default=1_000_000)
parser.add_argument("--seed", type=int, default=42)
parser.add_argument("--n-envs", type=int, default=4)
parser.add_argument("--frame-skip", type=int, default=4)
parser.add_argument("--reward-mode", type=str, default="raw", choices=["raw", "clip", "progress"])
parser.add_argument("--progress-scale", type=float, default=0.02)
parser.add_argument("--death-penalty", type=float, default=-50.0)
parser.add_argument("--flag-bonus", type=float, default=100.0)
parser.add_argument("--stall-penalty", type=float, default=0.05)
parser.add_argument("--stall-steps", type=int, default=40)
parser.add_argument("--backward-penalty-scale", type=float, default=0.01)
parser.add_argument("--milestone-interval", type=int, default=32)
parser.add_argument("--milestone-bonus", type=float, default=1.0)
parser.add_argument("--no-progress-terminate-steps", type=int, default=120)
parser.add_argument("--no-progress-terminate-penalty", type=float, default=20.0)
parser.add_argument("--learning-rate", type=float, default=2.5e-4)
parser.add_argument("--n-steps", type=int, default=128)
parser.add_argument("--batch-size", type=int, default=256)
parser.add_argument("--n-epochs", type=int, default=4)
parser.add_argument("--gamma", type=float, default=0.99)
parser.add_argument("--gae-lambda", type=float, default=0.95)
parser.add_argument("--ent-coef", type=float, default=0.01)
parser.add_argument("--clip-range", type=float, default=0.2)
parser.add_argument("--save-freq", type=int, default=50_000, help="Checkpoint frequency in env steps.")
parser.add_argument("--device", type=str, default="auto", choices=["auto", "cpu", "mps"])
parser.add_argument("--clip-reward", action="store_true", help="Enable reward clipping to {-1,0,1}.")
parser.add_argument("--run-name", type=str, default="", help="Optional custom run name.")
parser.add_argument(
"--init-model-path",
type=str,
default="",
help="Optional .zip model path to initialize weights before training.",
)
parser.add_argument(
"--allow-partial-init",
action="store_true",
help=(
"Allow partial init when checkpoint is from a different action space "
"(e.g. right_only -> simple). Compatible policy layers will be loaded, "
"incompatible layers (like action head) are skipped."
),
)
parser.add_argument("--progress-bar", action="store_true", help="Enable SB3 progress bar (requires tqdm/rich).")
return parser.parse_args()
def load_partial_policy_weights(model: PPO, init_model_path: Path, device: str) -> tuple[int, int]:
_, params, _ = load_from_zip_file(str(init_model_path), device=device)
if params is None or "policy" not in params:
raise RuntimeError(f"No policy parameters found in checkpoint: {init_model_path}")
src_state = params["policy"]
dst_state = model.policy.state_dict()
compatible_state = {}
skipped_count = 0
for key, value in src_state.items():
if key in dst_state and dst_state[key].shape == value.shape:
compatible_state[key] = value
else:
skipped_count += 1
if not compatible_state:
raise RuntimeError("No compatible policy tensors found for partial init.")
model.policy.load_state_dict(compatible_state, strict=False)
return len(compatible_state), skipped_count
class EpisodeEndLoggingCallback(BaseCallback):
"""Log per-episode terminal diagnostics to stdout and TensorBoard."""
REASON_TO_CODE = {"death": 0, "no_progress": 1, "timeout": 2, "clear": 3}
def __init__(self, n_envs: int, tb_log_dir: Path):
super().__init__(verbose=0)
self.n_envs = n_envs
self.tb_log_dir = tb_log_dir
self.tb_writer: SummaryWriter | None = None
self.episode_count = 0
self.episode_max_x_pos = [0.0 for _ in range(max(n_envs, 1))]
@staticmethod
def _resolve_done_reason(info: dict) -> str:
if bool(info.get("flag_get", False)):
return "clear"
if bool(info.get("terminated_by_stall", False)):
return "no_progress"
if bool(info.get("TimeLimit.truncated", False)):
return "timeout"
try:
if float(info.get("time", 1.0)) <= 0.0:
return "timeout"
except (TypeError, ValueError):
pass
return "death"
def _on_training_start(self) -> None:
self.tb_log_dir.mkdir(parents=True, exist_ok=True)
self.tb_writer = SummaryWriter(log_dir=str(self.tb_log_dir))
def _on_step(self) -> bool:
infos = self.locals.get("infos")
dones = self.locals.get("dones")
if infos is None or dones is None:
return True
for env_idx, (info, done) in enumerate(zip(infos, dones)):
if env_idx >= len(self.episode_max_x_pos):
continue
x_pos = float(info.get("x_pos", 0.0))
if x_pos > self.episode_max_x_pos[env_idx]:
self.episode_max_x_pos[env_idx] = x_pos
if not bool(done):
continue
self.episode_count += 1
max_x_pos = self.episode_max_x_pos[env_idx]
flag_get = 1.0 if bool(info.get("flag_get", False)) else 0.0
done_reason = self._resolve_done_reason(info)
done_reason_code = float(self.REASON_TO_CODE[done_reason])
episode_step = self.episode_count
self.logger.record_mean("rollout/episode_max_x_pos", max_x_pos)
self.logger.record_mean("rollout/flag_get", flag_get)
self.logger.record_mean(f"rollout/done_reason_{done_reason}", 1.0)
if self.tb_writer is not None:
self.tb_writer.add_scalar("episode_end/episode_max_x_pos", max_x_pos, episode_step)
self.tb_writer.add_scalar("episode_end/flag_get", flag_get, episode_step)
self.tb_writer.add_scalar("episode_end/done_reason_code", done_reason_code, episode_step)
self.tb_writer.add_scalar("episode_end/done_reason_death", 1.0 if done_reason == "death" else 0.0, episode_step)
self.tb_writer.add_scalar(
"episode_end/done_reason_no_progress", 1.0 if done_reason == "no_progress" else 0.0, episode_step
)
self.tb_writer.add_scalar("episode_end/done_reason_timeout", 1.0 if done_reason == "timeout" else 0.0, episode_step)
self.tb_writer.add_scalar("episode_end/done_reason_clear", 1.0 if done_reason == "clear" else 0.0, episode_step)
self.tb_writer.flush()
print(
f"[episode_end] ep={episode_step} env={env_idx} reason={done_reason} "
f"max_x={max_x_pos:.1f} flag_get={bool(flag_get)}"
)
self.episode_max_x_pos[env_idx] = 0.0
return True
def _on_training_end(self) -> None:
if self.tb_writer is not None:
self.tb_writer.close()
self.tb_writer = None
def main() -> None:
args = parse_args()
seed_everything(args.seed)
reward_mode = "clip" if args.clip_reward and args.reward_mode == "raw" else args.reward_mode
paths = ensure_artifact_paths()
timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
env_slug = args.env_id.replace("/", "_")
run_name = args.run_name.strip() or f"ppo_{env_slug}_{timestamp}"
run_model_dir = paths.models / run_name
run_log_dir = paths.logs / run_name
run_tb_dir = run_log_dir / "tb"
run_model_dir.mkdir(parents=True, exist_ok=True)
run_tb_dir.mkdir(parents=True, exist_ok=True)
device, device_msg = resolve_torch_device(args.device)
print(f"[device] {device} | {device_msg}")
print(f"[run] {run_name}")
print(
f"[config] env_id={args.env_id}, movement={args.movement}, reward_mode={reward_mode}, "
f"clip_reward={args.clip_reward}, n_envs={args.n_envs}"
)
env_fns = [
make_env_fn(
env_id=args.env_id,
seed=args.seed + i,
movement=args.movement,
reward_mode=reward_mode,
clip_reward=args.clip_reward,
frame_skip=args.frame_skip,
progress_scale=args.progress_scale,
death_penalty=args.death_penalty,
flag_bonus=args.flag_bonus,
stall_penalty=args.stall_penalty,
stall_steps=args.stall_steps,
backward_penalty_scale=args.backward_penalty_scale,
milestone_interval=args.milestone_interval,
milestone_bonus=args.milestone_bonus,
no_progress_terminate_steps=args.no_progress_terminate_steps,
no_progress_terminate_penalty=args.no_progress_terminate_penalty,
)
for i in range(args.n_envs)
]
vec_env = DummyVecEnv(env_fns)
vec_env = VecMonitor(vec_env, filename=str(run_log_dir / "monitor.csv"))
model = PPO(
policy="CnnPolicy",
env=vec_env,
learning_rate=args.learning_rate,
n_steps=args.n_steps,
batch_size=args.batch_size,
n_epochs=args.n_epochs,
gamma=args.gamma,
gae_lambda=args.gae_lambda,
ent_coef=args.ent_coef,
clip_range=args.clip_range,
tensorboard_log=str(run_tb_dir),
seed=args.seed,
verbose=1,
device=device,
)
if args.init_model_path.strip():
init_model_path = Path(args.init_model_path).expanduser().resolve()
if not init_model_path.exists():
raise FileNotFoundError(f"Init model not found: {init_model_path}")
try:
model.set_parameters(str(init_model_path), exact_match=False, device=device)
print(f"[init] Loaded initial weights from: {init_model_path}")
except RuntimeError as exc:
if not args.allow_partial_init:
raise RuntimeError(
f"{exc}\n"
"Hint: checkpoint and current env may use different action spaces. "
"Try one of:\n"
"1) keep the same --movement as the checkpoint;\n"
"2) remove --init-model-path and train from scratch;\n"
"3) add --allow-partial-init to load compatible layers only."
) from exc
loaded_count, skipped_count = load_partial_policy_weights(model, init_model_path, device=device)
print(
f"[init] Partial init from {init_model_path}: "
f"loaded_tensors={loaded_count}, skipped_tensors={skipped_count}"
)
callback = CheckpointCallback(
save_freq=max(args.save_freq // max(args.n_envs, 1), 1),
save_path=str(run_model_dir),
name_prefix="ppo_mario_ckpt",
save_replay_buffer=False,
save_vecnormalize=False,
)
episode_end_logging_callback = EpisodeEndLoggingCallback(
n_envs=args.n_envs,
tb_log_dir=run_tb_dir / "episode_end",
)
try:
model.learn(
total_timesteps=args.total_timesteps,
callback=CallbackList([callback, episode_end_logging_callback]),
tb_log_name="ppo",
progress_bar=args.progress_bar,
)
finally:
vec_env.close()
final_model = run_model_dir / "final_model"
model.save(str(final_model))
latest_model = paths.models / "latest_model.zip"
shutil.copy2(str(final_model) + ".zip", latest_model)
pointer = write_latest_pointer(paths.models, latest_model)
print(f"[done] Final model: {final_model}.zip")
print(f"[done] Latest model pointer: {pointer}")
if __name__ == "__main__":
main()

75
mario-rl-mvp/src/utils.py Normal file
View File

@@ -0,0 +1,75 @@
from __future__ import annotations
import random
from dataclasses import dataclass
from pathlib import Path
from typing import Optional, Tuple
import numpy as np
import torch
@dataclass(frozen=True)
class ArtifactPaths:
root: Path
models: Path
videos: Path
logs: Path
def project_root() -> Path:
return Path(__file__).resolve().parents[1]
def ensure_artifact_paths(root: Optional[Path] = None) -> ArtifactPaths:
root = root or project_root()
artifacts = root / "artifacts"
models = artifacts / "models"
videos = artifacts / "videos"
logs = artifacts / "logs"
for p in (artifacts, models, videos, logs):
p.mkdir(parents=True, exist_ok=True)
return ArtifactPaths(root=artifacts, models=models, videos=videos, logs=logs)
def seed_everything(seed: int) -> None:
random.seed(seed)
np.random.seed(seed)
torch.manual_seed(seed)
if torch.cuda.is_available():
torch.cuda.manual_seed_all(seed)
def resolve_torch_device(requested: str = "auto") -> Tuple[str, str]:
requested = requested.lower().strip()
if requested == "cpu":
return "cpu", "User requested CPU."
if requested not in {"auto", "mps"}:
return "cpu", f"Unknown device '{requested}', fallback to CPU."
mps_available = hasattr(torch.backends, "mps") and torch.backends.mps.is_available()
if not mps_available:
if requested == "mps":
return "cpu", "MPS requested but unavailable, fallback to CPU."
return "cpu", "MPS unavailable, using CPU."
try:
x = torch.ones(8, device="mps")
_ = (x * 2).cpu().numpy()
return "mps", "MPS is available and passed a quick tensor sanity check."
except Exception as exc: # pragma: no cover - hardware dependent
return "cpu", f"MPS check failed ({exc}), fallback to CPU."
def latest_model_path(models_dir: Path) -> Optional[Path]:
candidates = [p for p in models_dir.rglob("*.zip") if p.is_file()]
if not candidates:
return None
return max(candidates, key=lambda p: p.stat().st_mtime)
def write_latest_pointer(models_dir: Path, model_path: Path) -> Path:
pointer = models_dir / "LATEST_MODEL.txt"
pointer.write_text(str(model_path.resolve()) + "\n", encoding="utf-8")
return pointer