feat: improve device handling and add stochastic option

This commit is contained in:
2026-02-12 19:13:12 +08:00
parent d23de69b9a
commit 71008dfb72
4 changed files with 52 additions and 18 deletions

View File

@@ -143,13 +143,13 @@ tensorboard --logdir artifacts/logs --port 6006
加载最新模型,跑 N 个 episode输出平均指标 加载最新模型,跑 N 个 episode输出平均指标
```bash ```bash
python -m src.eval --episodes 5 python -m src.eval --episodes 5 --stochastic
``` ```
可指定模型: 可指定模型:
```bash ```bash
python -m src.eval --model-path artifacts/models/latest_model.zip --episodes 10 python -m src.eval --model-path artifacts/models/latest_model.zip --episodes 10 --stochastic
``` ```
注意:`eval.py` 默认 `--movement auto`,会按模型动作维度自动匹配 `right_only/simple`,避免动作空间不一致导致 `KeyError` 注意:`eval.py` 默认 `--movement auto`,会按模型动作维度自动匹配 `right_only/simple`,避免动作空间不一致导致 `KeyError`
@@ -174,13 +174,13 @@ _total_timesteps = 150000
默认录制约 10 秒 mp4 到 `artifacts/videos/` 默认录制约 10 秒 mp4 到 `artifacts/videos/`
```bash ```bash
python -m src.record_video --duration-sec 10 --fps 30 python -m src.record_video --duration-sec 10 --fps 30 --stochastic
``` ```
可指定输出路径: 可指定输出路径:
```bash ```bash
python -m src.record_video --output artifacts/videos/demo.mp4 --duration-sec 10 python -m src.record_video --output artifacts/videos/demo.mp4 --stochastic --duration-sec 10
``` ```
注意:`record_video.py` 默认 `--movement auto`,会按模型自动匹配动作空间。 注意:`record_video.py` 默认 `--movement auto`,会按模型自动匹配动作空间。

View File

@@ -114,6 +114,7 @@ def main() -> None:
output_path.parent.mkdir(parents=True, exist_ok=True) output_path.parent.mkdir(parents=True, exist_ok=True)
model = PPO.load(str(model_path)) model = PPO.load(str(model_path))
print(f"[video] model={model_path}")
movement = resolve_movement(args.movement, model) movement = resolve_movement(args.movement, model)
print(f"[video] movement={movement} reward_mode={reward_mode}") print(f"[video] movement={movement} reward_mode={reward_mode}")
env = make_mario_env( env = make_mario_env(

View File

@@ -45,7 +45,12 @@ def parse_args() -> argparse.Namespace:
parser.add_argument("--clip-range", type=float, default=0.2) parser.add_argument("--clip-range", type=float, default=0.2)
parser.add_argument("--save-freq", type=int, default=50_000, help="Checkpoint frequency in env steps.") parser.add_argument("--save-freq", type=int, default=50_000, help="Checkpoint frequency in env steps.")
parser.add_argument("--device", type=str, default="auto", choices=["auto", "cpu", "mps"]) parser.add_argument(
"--device",
type=str,
default="auto",
help="Torch device: auto | cpu | mps | cuda | cuda:0 ...",
)
parser.add_argument("--clip-reward", action="store_true", help="Enable reward clipping to {-1,0,1}.") parser.add_argument("--clip-reward", action="store_true", help="Enable reward clipping to {-1,0,1}.")
parser.add_argument("--run-name", type=str, default="", help="Optional custom run name.") parser.add_argument("--run-name", type=str, default="", help="Optional custom run name.")
parser.add_argument( parser.add_argument(

View File

@@ -45,21 +45,49 @@ def resolve_torch_device(requested: str = "auto") -> Tuple[str, str]:
if requested == "cpu": if requested == "cpu":
return "cpu", "User requested CPU." return "cpu", "User requested CPU."
if requested not in {"auto", "mps"}: def _probe_cuda(device_name: str = "cuda") -> Tuple[str, str]:
return "cpu", f"Unknown device '{requested}', fallback to CPU." if not torch.cuda.is_available():
return "cpu", "CUDA unavailable, using CPU."
try:
x = torch.ones(8, device=device_name)
_ = (x * 2).cpu().numpy()
gpu_name = torch.cuda.get_device_name(torch.device(device_name))
return device_name, f"CUDA is available ({gpu_name})."
except Exception as exc: # pragma: no cover - hardware dependent
return "cpu", f"CUDA check failed ({exc}), fallback to CPU."
mps_available = hasattr(torch.backends, "mps") and torch.backends.mps.is_available() def _probe_mps() -> Tuple[str, str]:
if not mps_available: mps_available = hasattr(torch.backends, "mps") and torch.backends.mps.is_available()
if requested == "mps": if not mps_available:
return "cpu", "MPS requested but unavailable, fallback to CPU." return "cpu", "MPS unavailable, using CPU."
return "cpu", "MPS unavailable, using CPU." try:
x = torch.ones(8, device="mps")
_ = (x * 2).cpu().numpy()
return "mps", "MPS is available and passed a quick tensor sanity check."
except Exception as exc: # pragma: no cover - hardware dependent
return "cpu", f"MPS check failed ({exc}), fallback to CPU."
try: if requested.startswith("cuda"):
x = torch.ones(8, device="mps") return _probe_cuda(requested)
_ = (x * 2).cpu().numpy()
return "mps", "MPS is available and passed a quick tensor sanity check." if requested == "mps":
except Exception as exc: # pragma: no cover - hardware dependent device, msg = _probe_mps()
return "cpu", f"MPS check failed ({exc}), fallback to CPU." if device == "cpu":
return device, "MPS requested but unavailable, fallback to CPU."
return device, msg
if requested == "auto":
device, msg = _probe_cuda("cuda")
if device != "cpu":
return device, msg
device, msg = _probe_mps()
if device != "cpu":
return device, msg
return "cpu", "Neither CUDA nor MPS is available, using CPU."
return "cpu", f"Unknown device '{requested}', fallback to CPU."
def latest_model_path(models_dir: Path) -> Optional[Path]: def latest_model_path(models_dir: Path) -> Optional[Path]: