为评估新增视频录制支持,包括单环境和矢量环境录像功能,并更新 README.md 添加使用示例和效果展示。
This commit is contained in:
BIN
mario-rl-mvp/PixPin_2026-02-14_12-57-02.gif
Normal file
BIN
mario-rl-mvp/PixPin_2026-02-14_12-57-02.gif
Normal file
Binary file not shown.
|
After Width: | Height: | Size: 16 MiB |
@@ -2,6 +2,10 @@
|
|||||||
|
|
||||||
最小可运行工程:使用像素输入 + 传统 CNN policy(`stable-baselines3` PPO)训练 `gym-super-mario-bros / nes-py` 智能体,不使用大语言模型。
|
最小可运行工程:使用像素输入 + 传统 CNN policy(`stable-baselines3` PPO)训练 `gym-super-mario-bros / nes-py` 智能体,不使用大语言模型。
|
||||||
|
|
||||||
|
最新进度
|
||||||
|
|
||||||
|

|
||||||
|
|
||||||
## 1. 项目结构
|
## 1. 项目结构
|
||||||
|
|
||||||
```text
|
```text
|
||||||
@@ -278,6 +282,7 @@ python -m src.record_video \
|
|||||||
--time-penalty -0.01 \
|
--time-penalty -0.01 \
|
||||||
--epsilon 0.08 \
|
--epsilon 0.08 \
|
||||||
--duration-sec 30
|
--duration-sec 30
|
||||||
|
--stochastic
|
||||||
```
|
```
|
||||||
|
|
||||||
或者稳定版本
|
或者稳定版本
|
||||||
@@ -297,6 +302,8 @@ python -m src.record_video \
|
|||||||
--epsilon-random-mode uniform \
|
--epsilon-random-mode uniform \
|
||||||
--max-steps 6000
|
--max-steps 6000
|
||||||
```
|
```
|
||||||
|
|
||||||
|
|
||||||
可选:
|
可选:
|
||||||
|
|
||||||
```bash
|
```bash
|
||||||
@@ -523,3 +530,17 @@ python -m src.record_video --duration-sec 10 --fps 30
|
|||||||
- `artifacts/models/` 下有 `.zip` 模型
|
- `artifacts/models/` 下有 `.zip` 模型
|
||||||
- `artifacts/logs/` 下有 TensorBoard event 文件
|
- `artifacts/logs/` 下有 TensorBoard event 文件
|
||||||
- `artifacts/videos/` 下有 `.mp4`(或失败时有 `_frames/` 帧序列)
|
- `artifacts/videos/` 下有 `.mp4`(或失败时有 `_frames/` 帧序列)
|
||||||
|
|
||||||
|
python -m src.eval \
|
||||||
|
--model-path artifacts/models/latest_model.zip \
|
||||||
|
--episodes 50 \
|
||||||
|
--movement simple \
|
||||||
|
--reward-mode progress \
|
||||||
|
--no-progress-terminate-steps 300 \
|
||||||
|
--no-progress-terminate-penalty 10 \
|
||||||
|
--death-penalty -50 \
|
||||||
|
--stall-penalty 0.05 \
|
||||||
|
--stall-steps 40 \
|
||||||
|
--time-penalty -0.01 \
|
||||||
|
--random-noops 30 \
|
||||||
|
--epsilon 0.03
|
||||||
@@ -2,9 +2,14 @@ from __future__ import annotations
|
|||||||
|
|
||||||
import argparse
|
import argparse
|
||||||
import json
|
import json
|
||||||
|
import re
|
||||||
|
import shutil
|
||||||
|
from datetime import datetime
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from statistics import mean
|
from statistics import mean
|
||||||
|
from typing import Any
|
||||||
|
|
||||||
|
import imageio.v2 as imageio
|
||||||
import numpy as np
|
import numpy as np
|
||||||
from src.policy_utils import select_action
|
from src.policy_utils import select_action
|
||||||
from src.runtime import configure_runtime_env
|
from src.runtime import configure_runtime_env
|
||||||
@@ -12,6 +17,7 @@ from src.runtime import configure_runtime_env
|
|||||||
configure_runtime_env()
|
configure_runtime_env()
|
||||||
|
|
||||||
from stable_baselines3 import PPO
|
from stable_baselines3 import PPO
|
||||||
|
from stable_baselines3.common.vec_env import VecEnv, VecVideoRecorder
|
||||||
|
|
||||||
from src.env import get_action_set, make_mario_env, reset_compat, step_compat
|
from src.env import get_action_set, make_mario_env, reset_compat, step_compat
|
||||||
from src.utils import ensure_artifact_paths, latest_model_path, seed_everything
|
from src.utils import ensure_artifact_paths, latest_model_path, seed_everything
|
||||||
@@ -82,6 +88,27 @@ def parse_args() -> argparse.Namespace:
|
|||||||
default=0,
|
default=0,
|
||||||
help="Random no-op steps after reset (0 disables). Uses action 0 as NOOP.",
|
help="Random no-op steps after reset (0 disables). Uses action 0 as NOOP.",
|
||||||
)
|
)
|
||||||
|
parser.add_argument("--record-video", action="store_true", help="Record eval rollouts to mp4.")
|
||||||
|
parser.add_argument("--video-dir", type=str, default="artifacts/videos/eval", help="Output directory for videos.")
|
||||||
|
parser.add_argument("--video-fps", type=int, default=60, help="Video FPS metadata.")
|
||||||
|
parser.add_argument(
|
||||||
|
"--render-mode",
|
||||||
|
type=str,
|
||||||
|
default="rgb_array",
|
||||||
|
help="Render mode for recording (e.g. rgb_array, human). Recording requires frame output.",
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--video-episode-trigger",
|
||||||
|
type=int,
|
||||||
|
default=1,
|
||||||
|
help="Record every N episodes (single env) or N segments (VecEnv). 1 means all.",
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--video-length",
|
||||||
|
type=int,
|
||||||
|
default=0,
|
||||||
|
help="Max recorded steps per video. 0 means full episode (single env) / --max-steps (VecEnv).",
|
||||||
|
)
|
||||||
return parser.parse_args()
|
return parser.parse_args()
|
||||||
|
|
||||||
|
|
||||||
@@ -133,60 +160,191 @@ def resolve_model_path(user_path: str) -> Path:
|
|||||||
return latest
|
return latest
|
||||||
|
|
||||||
|
|
||||||
def main() -> None:
|
def _resolve_render_mode(args: argparse.Namespace) -> str | None:
|
||||||
args = parse_args()
|
if not args.record_video:
|
||||||
if args.epsilon < 0.0 or args.epsilon > 1.0:
|
return None
|
||||||
raise ValueError(f"--epsilon must be in [0, 1], got {args.epsilon}")
|
render_mode = args.render_mode.strip() if args.render_mode else "rgb_array"
|
||||||
seed_everything(args.seed)
|
if render_mode != "rgb_array":
|
||||||
reward_mode = "clip" if args.clip_reward and args.reward_mode == "raw" else args.reward_mode
|
print(f"[video] render_mode={render_mode} is not frame-safe for mp4; forcing rgb_array.")
|
||||||
|
return "rgb_array"
|
||||||
|
return render_mode
|
||||||
|
|
||||||
model_path = resolve_model_path(args.model_path)
|
|
||||||
print(f"[eval] model={model_path}")
|
def _build_video_name_prefix(model_path: Path, seed: int) -> str:
|
||||||
model = PPO.load(str(model_path))
|
timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
|
||||||
movement = resolve_movement(args.movement, model)
|
model_slug = re.sub(r"[^A-Za-z0-9._-]+", "-", model_path.stem).strip("-") or "model"
|
||||||
|
return f"eval-{model_slug}-seed{seed}-{timestamp}"
|
||||||
|
|
||||||
|
|
||||||
|
def _check_video_recording_dependencies() -> None:
|
||||||
|
if shutil.which("ffmpeg"):
|
||||||
|
return
|
||||||
|
|
||||||
|
try:
|
||||||
|
import imageio_ffmpeg
|
||||||
|
|
||||||
|
ffmpeg_exe = Path(imageio_ffmpeg.get_ffmpeg_exe())
|
||||||
|
if ffmpeg_exe.exists():
|
||||||
|
return
|
||||||
|
except Exception:
|
||||||
|
pass
|
||||||
|
|
||||||
|
raise RuntimeError("需要安装 ffmpeg 才能录制 mp4。请安装系统 ffmpeg,或 `pip install imageio-ffmpeg`。")
|
||||||
|
|
||||||
|
|
||||||
|
class EpisodeVideoRecorder:
|
||||||
|
"""Manual mp4 recorder for regular gym.Env to avoid wrapper instability."""
|
||||||
|
|
||||||
|
def __init__(self, video_dir: Path, name_prefix: str, fps: int, trigger_every: int, video_length: int):
|
||||||
|
self.video_dir = video_dir
|
||||||
|
self.name_prefix = name_prefix
|
||||||
|
self.fps = int(fps)
|
||||||
|
self.trigger_every = int(max(1, trigger_every))
|
||||||
|
self.video_length = int(max(0, video_length))
|
||||||
|
self.writer: Any | None = None
|
||||||
|
self.current_video_path: Path | None = None
|
||||||
|
self.current_episode_id = -1
|
||||||
|
self.recorded_steps = 0
|
||||||
|
|
||||||
|
@property
|
||||||
|
def recording(self) -> bool:
|
||||||
|
return self.writer is not None
|
||||||
|
|
||||||
|
def _episode_enabled(self, episode_id: int) -> bool:
|
||||||
|
return episode_id % self.trigger_every == 0
|
||||||
|
|
||||||
|
def start_episode(self, episode_id: int) -> None:
|
||||||
|
self.end_episode()
|
||||||
|
self.current_episode_id = int(episode_id)
|
||||||
|
self.recorded_steps = 0
|
||||||
|
if not self._episode_enabled(episode_id):
|
||||||
|
return
|
||||||
|
|
||||||
|
video_name = f"{self.name_prefix}-episode-{episode_id}.mp4"
|
||||||
|
self.current_video_path = self.video_dir / video_name
|
||||||
|
self.writer = imageio.get_writer(
|
||||||
|
str(self.current_video_path),
|
||||||
|
fps=self.fps,
|
||||||
|
codec="libx264",
|
||||||
|
quality=8,
|
||||||
|
macro_block_size=1,
|
||||||
|
)
|
||||||
|
|
||||||
|
def capture_frame_from_env(self, env: Any) -> None:
|
||||||
|
if not self.recording:
|
||||||
|
return
|
||||||
|
frame = env.render()
|
||||||
|
if frame is None:
|
||||||
|
raise RuntimeError("录像失败:当前 render_mode 无法产出帧,请使用 --render-mode rgb_array。")
|
||||||
|
if isinstance(frame, list):
|
||||||
|
if len(frame) == 0:
|
||||||
|
raise RuntimeError("录像失败:render() 返回空帧列表。")
|
||||||
|
frame = frame[-1]
|
||||||
|
|
||||||
|
try:
|
||||||
|
assert self.writer is not None
|
||||||
|
self.writer.append_data(np.asarray(frame).copy())
|
||||||
|
except Exception as exc:
|
||||||
|
raise RuntimeError(f"写入视频帧失败,请检查 ffmpeg/imageio-ffmpeg: {exc}") from exc
|
||||||
|
|
||||||
|
def on_env_step(self) -> None:
|
||||||
|
if not self.recording:
|
||||||
|
return
|
||||||
|
self.recorded_steps += 1
|
||||||
|
if self.video_length > 0 and self.recorded_steps >= self.video_length:
|
||||||
|
self.end_episode()
|
||||||
|
|
||||||
|
def end_episode(self) -> None:
|
||||||
|
if self.writer is None:
|
||||||
|
return
|
||||||
|
self.writer.close()
|
||||||
|
if self.current_video_path is not None:
|
||||||
|
print(f"[video] saved {self.current_video_path}")
|
||||||
|
self.writer = None
|
||||||
|
|
||||||
|
def close(self) -> None:
|
||||||
|
self.end_episode()
|
||||||
|
|
||||||
|
|
||||||
|
def _set_render_fps(env: Any, fps: int) -> None:
|
||||||
|
metadata = getattr(env, "metadata", None)
|
||||||
|
if isinstance(metadata, dict):
|
||||||
|
metadata["render_fps"] = fps
|
||||||
|
|
||||||
|
if isinstance(env, VecEnv):
|
||||||
|
try:
|
||||||
|
env_metadatas = env.get_attr("metadata")
|
||||||
|
for env_metadata in env_metadatas:
|
||||||
|
if isinstance(env_metadata, dict):
|
||||||
|
env_metadata["render_fps"] = fps
|
||||||
|
except Exception:
|
||||||
|
pass
|
||||||
|
return
|
||||||
|
|
||||||
|
unwrapped_metadata = getattr(getattr(env, "unwrapped", None), "metadata", None)
|
||||||
|
if isinstance(unwrapped_metadata, dict):
|
||||||
|
unwrapped_metadata["render_fps"] = fps
|
||||||
|
|
||||||
|
|
||||||
|
def wrap_vec_env_for_video(env: VecEnv, args: argparse.Namespace, video_dir: Path, name_prefix: str) -> VecEnv:
|
||||||
|
video_length = args.video_length if args.video_length > 0 else max(1, args.max_steps)
|
||||||
|
trigger_every = max(1, args.video_episode_trigger)
|
||||||
|
trigger_interval = max(1, video_length * trigger_every)
|
||||||
|
record_video_trigger = lambda step_id: step_id % trigger_interval == 0
|
||||||
|
wrapped = VecVideoRecorder(
|
||||||
|
venv=env,
|
||||||
|
video_folder=str(video_dir),
|
||||||
|
record_video_trigger=record_video_trigger,
|
||||||
|
video_length=video_length,
|
||||||
|
name_prefix=name_prefix,
|
||||||
|
)
|
||||||
print(
|
print(
|
||||||
f"[eval] movement={movement} reward_mode={reward_mode} random_noops={args.random_noops} "
|
f"[video] recording enabled folder={video_dir} fps={args.video_fps} "
|
||||||
f"time_penalty={args.time_penalty} hard_stuck_steps={args.hard_stuck_steps} "
|
f"kind=VecVideoRecorder trigger_steps={trigger_interval} video_length={video_length}"
|
||||||
f"hard_stuck_epsilon={args.hard_stuck_epsilon} hard_stuck_penalty={args.hard_stuck_penalty} "
|
|
||||||
f"epsilon={args.epsilon} epsilon_random_mode={args.epsilon_random_mode}"
|
|
||||||
)
|
)
|
||||||
|
return wrapped
|
||||||
|
|
||||||
env = make_mario_env(
|
|
||||||
env_id=args.env_id,
|
|
||||||
seed=args.seed,
|
|
||||||
movement=movement,
|
|
||||||
reward_mode=reward_mode,
|
|
||||||
clip_reward=args.clip_reward,
|
|
||||||
frame_skip=args.frame_skip,
|
|
||||||
render_mode=None,
|
|
||||||
progress_scale=args.progress_scale,
|
|
||||||
death_penalty=args.death_penalty,
|
|
||||||
flag_bonus=args.flag_bonus,
|
|
||||||
stall_penalty=args.stall_penalty,
|
|
||||||
stall_steps=args.stall_steps,
|
|
||||||
backward_penalty_scale=args.backward_penalty_scale,
|
|
||||||
milestone_interval=args.milestone_interval,
|
|
||||||
milestone_bonus=args.milestone_bonus,
|
|
||||||
no_progress_terminate_steps=args.no_progress_terminate_steps,
|
|
||||||
no_progress_terminate_penalty=args.no_progress_terminate_penalty,
|
|
||||||
time_penalty=args.time_penalty,
|
|
||||||
hard_stuck_steps=args.hard_stuck_steps,
|
|
||||||
hard_stuck_epsilon=args.hard_stuck_epsilon,
|
|
||||||
hard_stuck_penalty=args.hard_stuck_penalty,
|
|
||||||
)
|
|
||||||
|
|
||||||
rewards = []
|
def _close_env_safely(env: Any) -> None:
|
||||||
max_x_positions = []
|
# SB3 VecVideoRecorder may close wrapped env twice on close().
|
||||||
clear_flags = []
|
if isinstance(env, VecVideoRecorder):
|
||||||
|
was_recording = bool(getattr(env, "recording", False))
|
||||||
|
env.close_video_recorder()
|
||||||
|
if not was_recording:
|
||||||
|
env.venv.close()
|
||||||
|
return
|
||||||
|
|
||||||
|
env.close()
|
||||||
|
|
||||||
|
|
||||||
|
def evaluate_single_env(
|
||||||
|
env: Any,
|
||||||
|
model: PPO,
|
||||||
|
args: argparse.Namespace,
|
||||||
|
video_recorder: EpisodeVideoRecorder | None = None,
|
||||||
|
) -> tuple[list[float], list[float], list[float]]:
|
||||||
|
rewards: list[float] = []
|
||||||
|
max_x_positions: list[float] = []
|
||||||
|
clear_flags: list[float] = []
|
||||||
|
|
||||||
for ep in range(1, args.episodes + 1):
|
for ep in range(1, args.episodes + 1):
|
||||||
|
if video_recorder is not None:
|
||||||
|
video_recorder.start_episode(ep - 1)
|
||||||
obs, info = reset_compat(env, seed=args.seed + ep)
|
obs, info = reset_compat(env, seed=args.seed + ep)
|
||||||
|
if video_recorder is not None:
|
||||||
|
video_recorder.capture_frame_from_env(env)
|
||||||
if args.random_noops > 0:
|
if args.random_noops > 0:
|
||||||
noop_steps = np.random.randint(0, args.random_noops + 1)
|
noop_steps = np.random.randint(0, args.random_noops + 1)
|
||||||
for _ in range(noop_steps):
|
for _ in range(noop_steps):
|
||||||
obs, _, terminated, truncated, info = step_compat(env, 0)
|
obs, _, terminated, truncated, info = step_compat(env, 0)
|
||||||
|
if video_recorder is not None:
|
||||||
|
video_recorder.capture_frame_from_env(env)
|
||||||
|
video_recorder.on_env_step()
|
||||||
if terminated or truncated:
|
if terminated or truncated:
|
||||||
obs, info = reset_compat(env, seed=args.seed + ep + 1000)
|
obs, info = reset_compat(env, seed=args.seed + ep + 1000)
|
||||||
|
if video_recorder is not None:
|
||||||
|
video_recorder.capture_frame_from_env(env)
|
||||||
|
|
||||||
done = False
|
done = False
|
||||||
ep_reward = 0.0
|
ep_reward = 0.0
|
||||||
ep_max_x = float(info.get("x_pos", 0.0))
|
ep_max_x = float(info.get("x_pos", 0.0))
|
||||||
@@ -203,6 +361,9 @@ def main() -> None:
|
|||||||
env=env,
|
env=env,
|
||||||
)
|
)
|
||||||
obs, reward, terminated, truncated, info = step_compat(env, action)
|
obs, reward, terminated, truncated, info = step_compat(env, action)
|
||||||
|
if video_recorder is not None:
|
||||||
|
video_recorder.capture_frame_from_env(env)
|
||||||
|
video_recorder.on_env_step()
|
||||||
ep_reward += float(reward)
|
ep_reward += float(reward)
|
||||||
ep_max_x = max(ep_max_x, float(info.get("x_pos", 0.0)))
|
ep_max_x = max(ep_max_x, float(info.get("x_pos", 0.0)))
|
||||||
flag_get = flag_get or bool(info.get("flag_get", False))
|
flag_get = flag_get or bool(info.get("flag_get", False))
|
||||||
@@ -216,6 +377,168 @@ def main() -> None:
|
|||||||
f"[episode {ep}] reward={ep_reward:.2f} max_x={ep_max_x:.1f} "
|
f"[episode {ep}] reward={ep_reward:.2f} max_x={ep_max_x:.1f} "
|
||||||
f"clear={flag_get} steps={step_count}"
|
f"clear={flag_get} steps={step_count}"
|
||||||
)
|
)
|
||||||
|
if video_recorder is not None:
|
||||||
|
video_recorder.end_episode()
|
||||||
|
|
||||||
|
return rewards, max_x_positions, clear_flags
|
||||||
|
|
||||||
|
|
||||||
|
def _reset_vec_rank0(env: VecEnv, obs: np.ndarray, seed: int) -> np.ndarray:
|
||||||
|
try:
|
||||||
|
reset_results = env.env_method("reset", seed=seed, indices=0)
|
||||||
|
except TypeError:
|
||||||
|
reset_results = env.env_method("reset", indices=0)
|
||||||
|
|
||||||
|
if not reset_results:
|
||||||
|
return obs
|
||||||
|
|
||||||
|
reset_result = reset_results[0]
|
||||||
|
obs0 = reset_result[0] if isinstance(reset_result, tuple) else reset_result
|
||||||
|
obs[0] = obs0
|
||||||
|
return obs
|
||||||
|
|
||||||
|
|
||||||
|
def evaluate_vec_env(env: VecEnv, model: PPO, args: argparse.Namespace) -> tuple[list[float], list[float], list[float]]:
|
||||||
|
if args.random_noops > 0:
|
||||||
|
print("[warn] random_noops is ignored for VecEnv eval.")
|
||||||
|
|
||||||
|
rewards: list[float] = []
|
||||||
|
max_x_positions: list[float] = []
|
||||||
|
clear_flags: list[float] = []
|
||||||
|
|
||||||
|
obs = env.reset()
|
||||||
|
n_envs = int(getattr(env, "num_envs", 1))
|
||||||
|
|
||||||
|
for ep in range(1, args.episodes + 1):
|
||||||
|
done = False
|
||||||
|
ep_reward = 0.0
|
||||||
|
ep_max_x = 0.0
|
||||||
|
flag_get = False
|
||||||
|
step_count = 0
|
||||||
|
|
||||||
|
while not done and step_count < args.max_steps:
|
||||||
|
actions = np.asarray(
|
||||||
|
[
|
||||||
|
select_action(
|
||||||
|
model=model,
|
||||||
|
obs=obs[env_idx],
|
||||||
|
deterministic=not args.stochastic,
|
||||||
|
epsilon=args.epsilon,
|
||||||
|
epsilon_random_mode=args.epsilon_random_mode,
|
||||||
|
env=env,
|
||||||
|
)
|
||||||
|
for env_idx in range(n_envs)
|
||||||
|
],
|
||||||
|
dtype=np.int64,
|
||||||
|
)
|
||||||
|
obs, reward_vec, done_vec, info_vec = env.step(actions)
|
||||||
|
info0 = info_vec[0] if len(info_vec) > 0 else {}
|
||||||
|
ep_reward += float(reward_vec[0])
|
||||||
|
ep_max_x = max(ep_max_x, float(info0.get("x_pos", 0.0)))
|
||||||
|
flag_get = flag_get or bool(info0.get("flag_get", False))
|
||||||
|
done = bool(done_vec[0])
|
||||||
|
step_count += 1
|
||||||
|
|
||||||
|
if not done and step_count >= args.max_steps:
|
||||||
|
obs = _reset_vec_rank0(env=env, obs=obs, seed=args.seed + ep + 1000)
|
||||||
|
|
||||||
|
rewards.append(ep_reward)
|
||||||
|
max_x_positions.append(ep_max_x)
|
||||||
|
clear_flags.append(1.0 if flag_get else 0.0)
|
||||||
|
print(
|
||||||
|
f"[episode {ep}] reward={ep_reward:.2f} max_x={ep_max_x:.1f} "
|
||||||
|
f"clear={flag_get} steps={step_count}"
|
||||||
|
)
|
||||||
|
|
||||||
|
return rewards, max_x_positions, clear_flags
|
||||||
|
|
||||||
|
|
||||||
|
def main() -> None:
|
||||||
|
args = parse_args()
|
||||||
|
if args.epsilon < 0.0 or args.epsilon > 1.0:
|
||||||
|
raise ValueError(f"--epsilon must be in [0, 1], got {args.epsilon}")
|
||||||
|
if args.video_fps <= 0:
|
||||||
|
raise ValueError(f"--video-fps must be > 0, got {args.video_fps}")
|
||||||
|
if args.video_episode_trigger <= 0:
|
||||||
|
raise ValueError(f"--video-episode-trigger must be > 0, got {args.video_episode_trigger}")
|
||||||
|
if args.video_length < 0:
|
||||||
|
raise ValueError(f"--video-length must be >= 0, got {args.video_length}")
|
||||||
|
seed_everything(args.seed)
|
||||||
|
reward_mode = "clip" if args.clip_reward and args.reward_mode == "raw" else args.reward_mode
|
||||||
|
|
||||||
|
model_path = resolve_model_path(args.model_path)
|
||||||
|
print(f"[eval] model={model_path}")
|
||||||
|
model = PPO.load(str(model_path))
|
||||||
|
movement = resolve_movement(args.movement, model)
|
||||||
|
print(
|
||||||
|
f"[eval] movement={movement} reward_mode={reward_mode} random_noops={args.random_noops} "
|
||||||
|
f"time_penalty={args.time_penalty} hard_stuck_steps={args.hard_stuck_steps} "
|
||||||
|
f"hard_stuck_epsilon={args.hard_stuck_epsilon} hard_stuck_penalty={args.hard_stuck_penalty} "
|
||||||
|
f"epsilon={args.epsilon} epsilon_random_mode={args.epsilon_random_mode}"
|
||||||
|
)
|
||||||
|
render_mode = _resolve_render_mode(args)
|
||||||
|
|
||||||
|
env = make_mario_env(
|
||||||
|
env_id=args.env_id,
|
||||||
|
seed=args.seed,
|
||||||
|
movement=movement,
|
||||||
|
reward_mode=reward_mode,
|
||||||
|
clip_reward=args.clip_reward,
|
||||||
|
frame_skip=args.frame_skip,
|
||||||
|
render_mode=render_mode,
|
||||||
|
progress_scale=args.progress_scale,
|
||||||
|
death_penalty=args.death_penalty,
|
||||||
|
flag_bonus=args.flag_bonus,
|
||||||
|
stall_penalty=args.stall_penalty,
|
||||||
|
stall_steps=args.stall_steps,
|
||||||
|
backward_penalty_scale=args.backward_penalty_scale,
|
||||||
|
milestone_interval=args.milestone_interval,
|
||||||
|
milestone_bonus=args.milestone_bonus,
|
||||||
|
no_progress_terminate_steps=args.no_progress_terminate_steps,
|
||||||
|
no_progress_terminate_penalty=args.no_progress_terminate_penalty,
|
||||||
|
time_penalty=args.time_penalty,
|
||||||
|
hard_stuck_steps=args.hard_stuck_steps,
|
||||||
|
hard_stuck_epsilon=args.hard_stuck_epsilon,
|
||||||
|
hard_stuck_penalty=args.hard_stuck_penalty,
|
||||||
|
)
|
||||||
|
video_recorder: EpisodeVideoRecorder | None = None
|
||||||
|
if args.record_video:
|
||||||
|
_check_video_recording_dependencies()
|
||||||
|
video_dir = Path(args.video_dir).expanduser().resolve()
|
||||||
|
video_dir.mkdir(parents=True, exist_ok=True)
|
||||||
|
name_prefix = _build_video_name_prefix(model_path=model_path, seed=args.seed)
|
||||||
|
_set_render_fps(env, args.video_fps)
|
||||||
|
if isinstance(env, VecEnv):
|
||||||
|
env = wrap_vec_env_for_video(env=env, args=args, video_dir=video_dir, name_prefix=name_prefix)
|
||||||
|
else:
|
||||||
|
trigger_every = max(1, args.video_episode_trigger)
|
||||||
|
video_recorder = EpisodeVideoRecorder(
|
||||||
|
video_dir=video_dir,
|
||||||
|
name_prefix=name_prefix,
|
||||||
|
fps=args.video_fps,
|
||||||
|
trigger_every=trigger_every,
|
||||||
|
video_length=args.video_length,
|
||||||
|
)
|
||||||
|
print(
|
||||||
|
f"[video] recording enabled folder={video_dir} fps={args.video_fps} "
|
||||||
|
f"kind=imageio_episode trigger_every_episodes={trigger_every} "
|
||||||
|
f"video_length={max(0, args.video_length)}"
|
||||||
|
)
|
||||||
|
|
||||||
|
try:
|
||||||
|
if isinstance(env, VecEnv):
|
||||||
|
rewards, max_x_positions, clear_flags = evaluate_vec_env(env=env, model=model, args=args)
|
||||||
|
else:
|
||||||
|
rewards, max_x_positions, clear_flags = evaluate_single_env(
|
||||||
|
env=env,
|
||||||
|
model=model,
|
||||||
|
args=args,
|
||||||
|
video_recorder=video_recorder,
|
||||||
|
)
|
||||||
|
finally:
|
||||||
|
if video_recorder is not None:
|
||||||
|
video_recorder.close()
|
||||||
|
_close_env_safely(env)
|
||||||
|
|
||||||
summary = {
|
summary = {
|
||||||
"episodes": args.episodes,
|
"episodes": args.episodes,
|
||||||
@@ -225,8 +548,6 @@ def main() -> None:
|
|||||||
}
|
}
|
||||||
print("[summary]", json.dumps(summary, ensure_ascii=False, indent=2))
|
print("[summary]", json.dumps(summary, ensure_ascii=False, indent=2))
|
||||||
|
|
||||||
env.close()
|
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
main()
|
main()
|
||||||
|
|||||||
Reference in New Issue
Block a user