为评估脚本新增视频录制支持,包括单环境和矢量环境的录像功能,并更新 README.md,添加使用示例和效果展示。
This commit is contained in:
BIN
mario-rl-mvp/PixPin_2026-02-14_12-57-02.gif
Normal file
BIN
mario-rl-mvp/PixPin_2026-02-14_12-57-02.gif
Normal file
Binary file not shown.
|
After Width: | Height: | Size: 16 MiB |
@@ -2,6 +2,10 @@
|
||||
|
||||
最小可运行工程:使用像素输入 + 传统 CNN policy(`stable-baselines3` PPO)训练 `gym-super-mario-bros / nes-py` 智能体,不使用大语言模型。
|
||||
|
||||
最新进度
|
||||
|
||||

|
||||
|
||||
## 1. 项目结构
|
||||
|
||||
```text
|
||||
@@ -278,6 +282,7 @@ python -m src.record_video \
|
||||
--time-penalty -0.01 \
|
||||
--epsilon 0.08 \
|
||||
    --duration-sec 30 \
|
||||
--stochastic
|
||||
```
|
||||
|
||||
或者使用稳定版本:
|
||||
@@ -297,6 +302,8 @@ python -m src.record_video \
|
||||
--epsilon-random-mode uniform \
|
||||
--max-steps 6000
|
||||
```
|
||||
|
||||
|
||||
可选:
|
||||
|
||||
```bash
|
||||
@@ -523,3 +530,17 @@ python -m src.record_video --duration-sec 10 --fps 30
|
||||
- `artifacts/models/` 下有 `.zip` 模型
|
||||
- `artifacts/logs/` 下有 TensorBoard event 文件
|
||||
- `artifacts/videos/` 下有 `.mp4`(或失败时有 `_frames/` 帧序列)
|
||||
|
||||
python -m src.eval \
|
||||
--model-path artifacts/models/latest_model.zip \
|
||||
--episodes 50 \
|
||||
--movement simple \
|
||||
--reward-mode progress \
|
||||
--no-progress-terminate-steps 300 \
|
||||
--no-progress-terminate-penalty 10 \
|
||||
--death-penalty -50 \
|
||||
--stall-penalty 0.05 \
|
||||
--stall-steps 40 \
|
||||
--time-penalty -0.01 \
|
||||
--random-noops 30 \
|
||||
--epsilon 0.03
|
||||
@@ -2,9 +2,14 @@ from __future__ import annotations
|
||||
|
||||
import argparse
|
||||
import json
|
||||
import re
|
||||
import shutil
|
||||
from datetime import datetime
|
||||
from pathlib import Path
|
||||
from statistics import mean
|
||||
from typing import Any
|
||||
|
||||
import imageio.v2 as imageio
|
||||
import numpy as np
|
||||
from src.policy_utils import select_action
|
||||
from src.runtime import configure_runtime_env
|
||||
@@ -12,6 +17,7 @@ from src.runtime import configure_runtime_env
|
||||
configure_runtime_env()
|
||||
|
||||
from stable_baselines3 import PPO
|
||||
from stable_baselines3.common.vec_env import VecEnv, VecVideoRecorder
|
||||
|
||||
from src.env import get_action_set, make_mario_env, reset_compat, step_compat
|
||||
from src.utils import ensure_artifact_paths, latest_model_path, seed_everything
|
||||
@@ -82,6 +88,27 @@ def parse_args() -> argparse.Namespace:
|
||||
default=0,
|
||||
help="Random no-op steps after reset (0 disables). Uses action 0 as NOOP.",
|
||||
)
|
||||
parser.add_argument("--record-video", action="store_true", help="Record eval rollouts to mp4.")
|
||||
parser.add_argument("--video-dir", type=str, default="artifacts/videos/eval", help="Output directory for videos.")
|
||||
parser.add_argument("--video-fps", type=int, default=60, help="Video FPS metadata.")
|
||||
parser.add_argument(
|
||||
"--render-mode",
|
||||
type=str,
|
||||
default="rgb_array",
|
||||
help="Render mode for recording (e.g. rgb_array, human). Recording requires frame output.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--video-episode-trigger",
|
||||
type=int,
|
||||
default=1,
|
||||
help="Record every N episodes (single env) or N segments (VecEnv). 1 means all.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--video-length",
|
||||
type=int,
|
||||
default=0,
|
||||
help="Max recorded steps per video. 0 means full episode (single env) / --max-steps (VecEnv).",
|
||||
)
|
||||
return parser.parse_args()
|
||||
|
||||
|
||||
@@ -133,60 +160,191 @@ def resolve_model_path(user_path: str) -> Path:
|
||||
return latest
|
||||
|
||||
|
||||
def main() -> None:
|
||||
args = parse_args()
|
||||
if args.epsilon < 0.0 or args.epsilon > 1.0:
|
||||
raise ValueError(f"--epsilon must be in [0, 1], got {args.epsilon}")
|
||||
seed_everything(args.seed)
|
||||
reward_mode = "clip" if args.clip_reward and args.reward_mode == "raw" else args.reward_mode
|
||||
def _resolve_render_mode(args: argparse.Namespace) -> str | None:
|
||||
if not args.record_video:
|
||||
return None
|
||||
render_mode = args.render_mode.strip() if args.render_mode else "rgb_array"
|
||||
if render_mode != "rgb_array":
|
||||
print(f"[video] render_mode={render_mode} is not frame-safe for mp4; forcing rgb_array.")
|
||||
return "rgb_array"
|
||||
return render_mode
|
||||
|
||||
model_path = resolve_model_path(args.model_path)
|
||||
print(f"[eval] model={model_path}")
|
||||
model = PPO.load(str(model_path))
|
||||
movement = resolve_movement(args.movement, model)
|
||||
|
||||
def _build_video_name_prefix(model_path: Path, seed: int) -> str:
|
||||
timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
|
||||
model_slug = re.sub(r"[^A-Za-z0-9._-]+", "-", model_path.stem).strip("-") or "model"
|
||||
return f"eval-{model_slug}-seed{seed}-{timestamp}"
|
||||
|
||||
|
||||
def _check_video_recording_dependencies() -> None:
    """Ensure an ffmpeg binary is reachable, raising RuntimeError when none is found.

    Checks the system PATH first; falls back to the binary bundled by the
    optional ``imageio-ffmpeg`` package. Any failure while probing the
    fallback is treated as "not available".
    """
    if shutil.which("ffmpeg"):
        return

    try:
        import imageio_ffmpeg

        # Bundled binary shipped with pip-installed imageio-ffmpeg.
        if Path(imageio_ffmpeg.get_ffmpeg_exe()).exists():
            return
    except Exception:
        # Package missing or probe failed — fall through to the error below.
        pass

    raise RuntimeError("需要安装 ffmpeg 才能录制 mp4。请安装系统 ffmpeg,或 `pip install imageio-ffmpeg`。")
|
||||
|
||||
|
||||
class EpisodeVideoRecorder:
    """Manual mp4 recorder for regular gym.Env to avoid wrapper instability."""

    def __init__(self, video_dir: Path, name_prefix: str, fps: int, trigger_every: int, video_length: int):
        # Output location and file naming.
        self.video_dir = video_dir
        self.name_prefix = name_prefix
        # Sanitize numeric settings: fps coerced to int, trigger >= 1, length >= 0.
        self.fps = int(fps)
        self.trigger_every = int(max(1, trigger_every))
        self.video_length = int(max(0, video_length))
        # Mutable recording state; writer is non-None only while recording.
        self.writer: Any | None = None
        self.current_video_path: Path | None = None
        self.current_episode_id = -1
        self.recorded_steps = 0

    @property
    def recording(self) -> bool:
        """True while an mp4 writer is open."""
        return self.writer is not None

    def _episode_enabled(self, episode_id: int) -> bool:
        # Record every ``trigger_every``-th episode (0-based ids).
        return episode_id % self.trigger_every == 0

    def start_episode(self, episode_id: int) -> None:
        """Finalize any previous video, then open a writer if this episode should be recorded."""
        self.end_episode()
        self.current_episode_id = int(episode_id)
        self.recorded_steps = 0
        if not self._episode_enabled(episode_id):
            return

        self.current_video_path = self.video_dir / f"{self.name_prefix}-episode-{episode_id}.mp4"
        self.writer = imageio.get_writer(
            str(self.current_video_path),
            fps=self.fps,
            codec="libx264",
            quality=8,
            # macro_block_size=1 avoids silent resizing of odd frame dimensions.
            macro_block_size=1,
        )

    def capture_frame_from_env(self, env: Any) -> None:
        """Render *env* and append the resulting frame to the open video (no-op if idle)."""
        if not self.recording:
            return
        frame = env.render()
        if frame is None:
            raise RuntimeError("录像失败:当前 render_mode 无法产出帧,请使用 --render-mode rgb_array。")
        if isinstance(frame, list):
            # Some wrappers return a list of frames; keep the most recent one.
            if not frame:
                raise RuntimeError("录像失败:render() 返回空帧列表。")
            frame = frame[-1]

        try:
            assert self.writer is not None
            # Copy so the writer never holds a view into a reused env buffer.
            self.writer.append_data(np.asarray(frame).copy())
        except Exception as exc:
            raise RuntimeError(f"写入视频帧失败,请检查 ffmpeg/imageio-ffmpeg: {exc}") from exc

    def on_env_step(self) -> None:
        """Count one recorded step; stop early once ``video_length`` is reached (0 = unlimited)."""
        if not self.recording:
            return
        self.recorded_steps += 1
        if 0 < self.video_length <= self.recorded_steps:
            self.end_episode()

    def end_episode(self) -> None:
        """Close the writer if one is open and report the saved file path."""
        if self.writer is None:
            return
        self.writer.close()
        if self.current_video_path is not None:
            print(f"[video] saved {self.current_video_path}")
        self.writer = None

    def close(self) -> None:
        self.end_episode()
|
||||
|
||||
|
||||
def _set_render_fps(env: Any, fps: int) -> None:
    """Stamp ``render_fps`` into every metadata dict reachable from *env*.

    Covers the wrapper itself, each sub-env of a VecEnv (best effort), and
    the unwrapped env for plain gym environments.
    """

    def _stamp(candidate: Any) -> None:
        if isinstance(candidate, dict):
            candidate["render_fps"] = fps

    _stamp(getattr(env, "metadata", None))

    if isinstance(env, VecEnv):
        # Remote/vectorized workers may not expose metadata; ignore failures.
        try:
            for sub_metadata in env.get_attr("metadata"):
                _stamp(sub_metadata)
        except Exception:
            pass
        return

    _stamp(getattr(getattr(env, "unwrapped", None), "metadata", None))
|
||||
|
||||
|
||||
def wrap_vec_env_for_video(env: VecEnv, args: argparse.Namespace, video_dir: Path, name_prefix: str) -> VecEnv:
    """Wrap *env* in SB3's VecVideoRecorder and announce the recording setup.

    ``--video-length 0`` means "record up to ``--max-steps`` per clip". The
    per-episode trigger is converted into a global-step interval because
    VecVideoRecorder triggers on step counts, not episodes.

    BUG FIX: the previous version's print referenced ``movement`` and
    ``reward_mode``, which are not defined in this scope (NameError at
    runtime); those lines duplicated main()'s ``[eval]`` banner and were
    removed — only the ``[video]`` banner belongs here.
    """
    video_length = args.video_length if args.video_length > 0 else max(1, args.max_steps)
    trigger_every = max(1, args.video_episode_trigger)
    trigger_interval = max(1, video_length * trigger_every)

    def record_video_trigger(step_id: int) -> bool:
        return step_id % trigger_interval == 0

    wrapped = VecVideoRecorder(
        venv=env,
        video_folder=str(video_dir),
        record_video_trigger=record_video_trigger,
        video_length=video_length,
        name_prefix=name_prefix,
    )
    print(
        f"[video] recording enabled folder={video_dir} fps={args.video_fps} "
        f"kind=VecVideoRecorder trigger_steps={trigger_interval} video_length={video_length}"
    )
    return wrapped
|
||||
|
||||
env = make_mario_env(
|
||||
env_id=args.env_id,
|
||||
seed=args.seed,
|
||||
movement=movement,
|
||||
reward_mode=reward_mode,
|
||||
clip_reward=args.clip_reward,
|
||||
frame_skip=args.frame_skip,
|
||||
render_mode=None,
|
||||
progress_scale=args.progress_scale,
|
||||
death_penalty=args.death_penalty,
|
||||
flag_bonus=args.flag_bonus,
|
||||
stall_penalty=args.stall_penalty,
|
||||
stall_steps=args.stall_steps,
|
||||
backward_penalty_scale=args.backward_penalty_scale,
|
||||
milestone_interval=args.milestone_interval,
|
||||
milestone_bonus=args.milestone_bonus,
|
||||
no_progress_terminate_steps=args.no_progress_terminate_steps,
|
||||
no_progress_terminate_penalty=args.no_progress_terminate_penalty,
|
||||
time_penalty=args.time_penalty,
|
||||
hard_stuck_steps=args.hard_stuck_steps,
|
||||
hard_stuck_epsilon=args.hard_stuck_epsilon,
|
||||
hard_stuck_penalty=args.hard_stuck_penalty,
|
||||
)
|
||||
|
||||
rewards = []
|
||||
max_x_positions = []
|
||||
clear_flags = []
|
||||
def _close_env_safely(env: Any) -> None:
    """Close *env*, working around VecVideoRecorder's double-close behavior."""
    if not isinstance(env, VecVideoRecorder):
        env.close()
        return

    # SB3's VecVideoRecorder may close the wrapped env twice on close():
    # finish the recorder first, then close the inner env only when no
    # recording was in flight.
    recording_in_flight = bool(getattr(env, "recording", False))
    env.close_video_recorder()
    if not recording_in_flight:
        env.venv.close()
|
||||
|
||||
|
||||
def evaluate_single_env(
|
||||
env: Any,
|
||||
model: PPO,
|
||||
args: argparse.Namespace,
|
||||
video_recorder: EpisodeVideoRecorder | None = None,
|
||||
) -> tuple[list[float], list[float], list[float]]:
|
||||
rewards: list[float] = []
|
||||
max_x_positions: list[float] = []
|
||||
clear_flags: list[float] = []
|
||||
|
||||
for ep in range(1, args.episodes + 1):
|
||||
if video_recorder is not None:
|
||||
video_recorder.start_episode(ep - 1)
|
||||
obs, info = reset_compat(env, seed=args.seed + ep)
|
||||
if video_recorder is not None:
|
||||
video_recorder.capture_frame_from_env(env)
|
||||
if args.random_noops > 0:
|
||||
noop_steps = np.random.randint(0, args.random_noops + 1)
|
||||
for _ in range(noop_steps):
|
||||
obs, _, terminated, truncated, info = step_compat(env, 0)
|
||||
if video_recorder is not None:
|
||||
video_recorder.capture_frame_from_env(env)
|
||||
video_recorder.on_env_step()
|
||||
if terminated or truncated:
|
||||
obs, info = reset_compat(env, seed=args.seed + ep + 1000)
|
||||
if video_recorder is not None:
|
||||
video_recorder.capture_frame_from_env(env)
|
||||
|
||||
done = False
|
||||
ep_reward = 0.0
|
||||
ep_max_x = float(info.get("x_pos", 0.0))
|
||||
@@ -203,6 +361,9 @@ def main() -> None:
|
||||
env=env,
|
||||
)
|
||||
obs, reward, terminated, truncated, info = step_compat(env, action)
|
||||
if video_recorder is not None:
|
||||
video_recorder.capture_frame_from_env(env)
|
||||
video_recorder.on_env_step()
|
||||
ep_reward += float(reward)
|
||||
ep_max_x = max(ep_max_x, float(info.get("x_pos", 0.0)))
|
||||
flag_get = flag_get or bool(info.get("flag_get", False))
|
||||
@@ -216,6 +377,168 @@ def main() -> None:
|
||||
f"[episode {ep}] reward={ep_reward:.2f} max_x={ep_max_x:.1f} "
|
||||
f"clear={flag_get} steps={step_count}"
|
||||
)
|
||||
if video_recorder is not None:
|
||||
video_recorder.end_episode()
|
||||
|
||||
return rewards, max_x_positions, clear_flags
|
||||
|
||||
|
||||
def _reset_vec_rank0(env: VecEnv, obs: np.ndarray, seed: int) -> np.ndarray:
|
||||
try:
|
||||
reset_results = env.env_method("reset", seed=seed, indices=0)
|
||||
except TypeError:
|
||||
reset_results = env.env_method("reset", indices=0)
|
||||
|
||||
if not reset_results:
|
||||
return obs
|
||||
|
||||
reset_result = reset_results[0]
|
||||
obs0 = reset_result[0] if isinstance(reset_result, tuple) else reset_result
|
||||
obs[0] = obs0
|
||||
return obs
|
||||
|
||||
|
||||
def evaluate_vec_env(env: VecEnv, model: PPO, args: argparse.Namespace) -> tuple[list[float], list[float], list[float]]:
    """Evaluate *model* on a VecEnv, tracking statistics for sub-env 0 only.

    Returns three parallel per-episode lists: rewards, max x positions, and
    clear flags (1.0 when the flag was reached, else 0.0).
    """
    if args.random_noops > 0:
        print("[warn] random_noops is ignored for VecEnv eval.")

    rewards: list[float] = []
    max_x_positions: list[float] = []
    clear_flags: list[float] = []

    obs = env.reset()
    n_envs = int(getattr(env, "num_envs", 1))
    deterministic = not args.stochastic

    for ep in range(1, args.episodes + 1):
        ep_reward = 0.0
        ep_max_x = 0.0
        flag_get = False
        step_count = 0
        done = False

        while step_count < args.max_steps and not done:
            per_env_actions = [
                select_action(
                    model=model,
                    obs=obs[rank],
                    deterministic=deterministic,
                    epsilon=args.epsilon,
                    epsilon_random_mode=args.epsilon_random_mode,
                    env=env,
                )
                for rank in range(n_envs)
            ]
            obs, reward_vec, done_vec, info_vec = env.step(np.asarray(per_env_actions, dtype=np.int64))

            # Only rank 0 contributes to the reported statistics.
            info0 = info_vec[0] if len(info_vec) > 0 else {}
            ep_reward += float(reward_vec[0])
            ep_max_x = max(ep_max_x, float(info0.get("x_pos", 0.0)))
            flag_get = flag_get or bool(info0.get("flag_get", False))
            done = bool(done_vec[0])
            step_count += 1

        # Episode hit the step cap without terminating: force-reset rank 0 so
        # the next episode starts from a fresh state.
        if step_count >= args.max_steps and not done:
            obs = _reset_vec_rank0(env=env, obs=obs, seed=args.seed + ep + 1000)

        rewards.append(ep_reward)
        max_x_positions.append(ep_max_x)
        clear_flags.append(1.0 if flag_get else 0.0)
        print(
            f"[episode {ep}] reward={ep_reward:.2f} max_x={ep_max_x:.1f} "
            f"clear={flag_get} steps={step_count}"
        )

    return rewards, max_x_positions, clear_flags
|
||||
|
||||
|
||||
def main() -> None:
|
||||
args = parse_args()
|
||||
if args.epsilon < 0.0 or args.epsilon > 1.0:
|
||||
raise ValueError(f"--epsilon must be in [0, 1], got {args.epsilon}")
|
||||
if args.video_fps <= 0:
|
||||
raise ValueError(f"--video-fps must be > 0, got {args.video_fps}")
|
||||
if args.video_episode_trigger <= 0:
|
||||
raise ValueError(f"--video-episode-trigger must be > 0, got {args.video_episode_trigger}")
|
||||
if args.video_length < 0:
|
||||
raise ValueError(f"--video-length must be >= 0, got {args.video_length}")
|
||||
seed_everything(args.seed)
|
||||
reward_mode = "clip" if args.clip_reward and args.reward_mode == "raw" else args.reward_mode
|
||||
|
||||
model_path = resolve_model_path(args.model_path)
|
||||
print(f"[eval] model={model_path}")
|
||||
model = PPO.load(str(model_path))
|
||||
movement = resolve_movement(args.movement, model)
|
||||
print(
|
||||
f"[eval] movement={movement} reward_mode={reward_mode} random_noops={args.random_noops} "
|
||||
f"time_penalty={args.time_penalty} hard_stuck_steps={args.hard_stuck_steps} "
|
||||
f"hard_stuck_epsilon={args.hard_stuck_epsilon} hard_stuck_penalty={args.hard_stuck_penalty} "
|
||||
f"epsilon={args.epsilon} epsilon_random_mode={args.epsilon_random_mode}"
|
||||
)
|
||||
render_mode = _resolve_render_mode(args)
|
||||
|
||||
env = make_mario_env(
|
||||
env_id=args.env_id,
|
||||
seed=args.seed,
|
||||
movement=movement,
|
||||
reward_mode=reward_mode,
|
||||
clip_reward=args.clip_reward,
|
||||
frame_skip=args.frame_skip,
|
||||
render_mode=render_mode,
|
||||
progress_scale=args.progress_scale,
|
||||
death_penalty=args.death_penalty,
|
||||
flag_bonus=args.flag_bonus,
|
||||
stall_penalty=args.stall_penalty,
|
||||
stall_steps=args.stall_steps,
|
||||
backward_penalty_scale=args.backward_penalty_scale,
|
||||
milestone_interval=args.milestone_interval,
|
||||
milestone_bonus=args.milestone_bonus,
|
||||
no_progress_terminate_steps=args.no_progress_terminate_steps,
|
||||
no_progress_terminate_penalty=args.no_progress_terminate_penalty,
|
||||
time_penalty=args.time_penalty,
|
||||
hard_stuck_steps=args.hard_stuck_steps,
|
||||
hard_stuck_epsilon=args.hard_stuck_epsilon,
|
||||
hard_stuck_penalty=args.hard_stuck_penalty,
|
||||
)
|
||||
video_recorder: EpisodeVideoRecorder | None = None
|
||||
if args.record_video:
|
||||
_check_video_recording_dependencies()
|
||||
video_dir = Path(args.video_dir).expanduser().resolve()
|
||||
video_dir.mkdir(parents=True, exist_ok=True)
|
||||
name_prefix = _build_video_name_prefix(model_path=model_path, seed=args.seed)
|
||||
_set_render_fps(env, args.video_fps)
|
||||
if isinstance(env, VecEnv):
|
||||
env = wrap_vec_env_for_video(env=env, args=args, video_dir=video_dir, name_prefix=name_prefix)
|
||||
else:
|
||||
trigger_every = max(1, args.video_episode_trigger)
|
||||
video_recorder = EpisodeVideoRecorder(
|
||||
video_dir=video_dir,
|
||||
name_prefix=name_prefix,
|
||||
fps=args.video_fps,
|
||||
trigger_every=trigger_every,
|
||||
video_length=args.video_length,
|
||||
)
|
||||
print(
|
||||
f"[video] recording enabled folder={video_dir} fps={args.video_fps} "
|
||||
f"kind=imageio_episode trigger_every_episodes={trigger_every} "
|
||||
f"video_length={max(0, args.video_length)}"
|
||||
)
|
||||
|
||||
try:
|
||||
if isinstance(env, VecEnv):
|
||||
rewards, max_x_positions, clear_flags = evaluate_vec_env(env=env, model=model, args=args)
|
||||
else:
|
||||
rewards, max_x_positions, clear_flags = evaluate_single_env(
|
||||
env=env,
|
||||
model=model,
|
||||
args=args,
|
||||
video_recorder=video_recorder,
|
||||
)
|
||||
finally:
|
||||
if video_recorder is not None:
|
||||
video_recorder.close()
|
||||
_close_env_safely(env)
|
||||
|
||||
summary = {
|
||||
"episodes": args.episodes,
|
||||
@@ -225,8 +548,6 @@ def main() -> None:
|
||||
}
|
||||
print("[summary]", json.dumps(summary, ensure_ascii=False, indent=2))
|
||||
|
||||
env.close()
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
|
||||
Reference in New Issue
Block a user