From 71008dfb72fc24ede6429ca5bddd787165ec8431 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Roog=20=28=E9=A1=BE=E6=96=B0=E5=9F=B9=29?= Date: Thu, 12 Feb 2026 19:13:12 +0800 Subject: [PATCH] feat: improve device handling and add stochastic option --- mario-rl-mvp/README.md | 8 ++--- mario-rl-mvp/src/record_video.py | 1 + mario-rl-mvp/src/train_ppo.py | 7 ++++- mario-rl-mvp/src/utils.py | 54 ++++++++++++++++++++++++-------- 4 files changed, 52 insertions(+), 18 deletions(-) diff --git a/mario-rl-mvp/README.md b/mario-rl-mvp/README.md index e5d2bc3..9b4e674 100644 --- a/mario-rl-mvp/README.md +++ b/mario-rl-mvp/README.md @@ -143,13 +143,13 @@ tensorboard --logdir artifacts/logs --port 6006 加载最新模型,跑 N 个 episode,输出平均指标: ```bash -python -m src.eval --episodes 5 +python -m src.eval --episodes 5 --stochastic ``` 可指定模型: ```bash -python -m src.eval --model-path artifacts/models/latest_model.zip --episodes 10 +python -m src.eval --model-path artifacts/models/latest_model.zip --episodes 10 --stochastic ``` 注意:`eval.py` 默认 `--movement auto`,会按模型动作维度自动匹配 `right_only/simple`,避免动作空间不一致导致 `KeyError`。 @@ -174,13 +174,13 @@ _total_timesteps = 150000 默认录制约 10 秒 mp4 到 `artifacts/videos/`: ```bash -python -m src.record_video --duration-sec 10 --fps 30 +python -m src.record_video --duration-sec 10 --fps 30 --stochastic ``` 可指定输出路径: ```bash -python -m src.record_video --output artifacts/videos/demo.mp4 --duration-sec 10 +python -m src.record_video --output artifacts/videos/demo.mp4 --stochastic --duration-sec 10 ``` 注意:`record_video.py` 默认 `--movement auto`,会按模型自动匹配动作空间。 diff --git a/mario-rl-mvp/src/record_video.py b/mario-rl-mvp/src/record_video.py index 8bfffa6..49638e1 100644 --- a/mario-rl-mvp/src/record_video.py +++ b/mario-rl-mvp/src/record_video.py @@ -114,6 +114,7 @@ def main() -> None: output_path.parent.mkdir(parents=True, exist_ok=True) model = PPO.load(str(model_path)) + print(f"[video] model={model_path}") movement = resolve_movement(args.movement, model) print(f"[video] movement={movement} reward_mode={reward_mode}") env = make_mario_env( diff --git a/mario-rl-mvp/src/train_ppo.py b/mario-rl-mvp/src/train_ppo.py index a809d95..5ea4b18 100644 --- a/mario-rl-mvp/src/train_ppo.py +++ b/mario-rl-mvp/src/train_ppo.py @@ -45,7 +45,12 @@ def parse_args() -> argparse.Namespace: parser.add_argument("--clip-range", type=float, default=0.2) parser.add_argument("--save-freq", type=int, default=50_000, help="Checkpoint frequency in env steps.") - parser.add_argument("--device", type=str, default="auto", choices=["auto", "cpu", "mps"]) + parser.add_argument( + "--device", + type=str, + default="auto", + help="Torch device: auto | cpu | mps | cuda | cuda:0 ...", + ) parser.add_argument("--clip-reward", action="store_true", help="Enable reward clipping to {-1,0,1}.") parser.add_argument("--run-name", type=str, default="", help="Optional custom run name.") parser.add_argument( diff --git a/mario-rl-mvp/src/utils.py b/mario-rl-mvp/src/utils.py index fe62204..d4ce232 100644 --- a/mario-rl-mvp/src/utils.py +++ b/mario-rl-mvp/src/utils.py @@ -45,21 +45,49 @@ def resolve_torch_device(requested: str = "auto") -> Tuple[str, str]: if requested == "cpu": return "cpu", "User requested CPU." - if requested not in {"auto", "mps"}: - return "cpu", f"Unknown device '{requested}', fallback to CPU." + def _probe_cuda(device_name: str = "cuda") -> Tuple[str, str]: + if not torch.cuda.is_available(): + return "cpu", "CUDA unavailable, using CPU." + try: + x = torch.ones(8, device=device_name) + _ = (x * 2).cpu().numpy() + gpu_name = torch.cuda.get_device_name(torch.device(device_name)) + return device_name, f"CUDA is available ({gpu_name})." + except Exception as exc: # pragma: no cover - hardware dependent + return "cpu", f"CUDA check failed ({exc}), fallback to CPU." - mps_available = hasattr(torch.backends, "mps") and torch.backends.mps.is_available() - if not mps_available: - if requested == "mps": - return "cpu", "MPS requested but unavailable, fallback to CPU." - return "cpu", "MPS unavailable, using CPU." + def _probe_mps() -> Tuple[str, str]: + mps_available = hasattr(torch.backends, "mps") and torch.backends.mps.is_available() + if not mps_available: + return "cpu", "MPS unavailable, using CPU." + try: + x = torch.ones(8, device="mps") + _ = (x * 2).cpu().numpy() + return "mps", "MPS is available and passed a quick tensor sanity check." + except Exception as exc: # pragma: no cover - hardware dependent + return "cpu", f"MPS check failed ({exc}), fallback to CPU." - try: - x = torch.ones(8, device="mps") - _ = (x * 2).cpu().numpy() - return "mps", "MPS is available and passed a quick tensor sanity check." - except Exception as exc: # pragma: no cover - hardware dependent - return "cpu", f"MPS check failed ({exc}), fallback to CPU." + if requested.startswith("cuda"): + return _probe_cuda(requested) + + if requested == "mps": + device, msg = _probe_mps() + if device == "cpu": + return device, "MPS requested but unavailable, fallback to CPU." + return device, msg + + if requested == "auto": + device, msg = _probe_cuda("cuda") + if device != "cpu": + return device, msg + + device, msg = _probe_mps() + if device != "cpu": + return device, msg + + return "cpu", "Neither CUDA nor MPS is available, using CPU." + + return "cpu", f"Unknown device '{requested}', fallback to CPU." def latest_model_path(models_dir: Path) -> Optional[Path]: