feat: improve device handling and add stochastic option
@@ -143,13 +143,13 @@ tensorboard --logdir artifacts/logs --port 6006
 Load the latest model, run N episodes, and print the averaged metrics:
 
 ```bash
-python -m src.eval --episodes 5
+python -m src.eval --episodes 5 --stochastic
 ```
 
 A specific model can also be given:
 
 ```bash
-python -m src.eval --model-path artifacts/models/latest_model.zip --episodes 10
+python -m src.eval --model-path artifacts/models/latest_model.zip --episodes 10 --stochastic
 ```
 
 Note: `eval.py` defaults to `--movement auto`, which automatically matches `right_only/simple` to the model's action dimension, avoiding the `KeyError` caused by a mismatched action space.
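For context, `--stochastic` in Stable-Baselines3 scripts normally means actions are sampled from the policy distribution rather than taken greedily. A minimal sketch of how the flag is presumably wired into the evaluation loop, assuming a single environment with the classic Gym step API; the actual loop in `eval.py` may differ:

```python
from stable_baselines3 import PPO

def run_episode(env, model: PPO, stochastic: bool) -> float:
    """Roll out one episode and return its total reward."""
    obs = env.reset()
    done, total_reward = False, 0.0
    while not done:
        # deterministic=True picks the mode of the policy;
        # --stochastic flips this so actions are sampled each step.
        action, _state = model.predict(obs, deterministic=not stochastic)
        obs, reward, done, _info = env.step(action)
        total_reward += float(reward)
    return total_reward
```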
@@ -174,13 +174,13 @@ _total_timesteps = 150000
 By default, an mp4 of about 10 seconds is recorded to `artifacts/videos/`:
 
 ```bash
-python -m src.record_video --duration-sec 10 --fps 30
+python -m src.record_video --duration-sec 10 --fps 30 --stochastic
 ```
 
 An output path can also be given:
 
 ```bash
-python -m src.record_video --output artifacts/videos/demo.mp4 --duration-sec 10
+python -m src.record_video --output artifacts/videos/demo.mp4 --stochastic --duration-sec 10
 ```
 
 Note: `record_video.py` defaults to `--movement auto` and automatically matches the action space to the model.
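The `--duration-sec`/`--fps` pair implies roughly `duration_sec * fps` frames end up in the file. A rough sketch of such a recording loop, assuming `imageio` with its ffmpeg backend and the classic Gym `render(mode="rgb_array")` API; `record_video.py` itself may be wired differently:

```python
import imageio

def record(env, model, output: str, duration_sec: int = 10, fps: int = 30, stochastic: bool = False) -> None:
    """Write about duration_sec * fps rendered frames to an mp4 file."""
    writer = imageio.get_writer(output, fps=fps)  # mp4 output needs imageio-ffmpeg installed
    obs = env.reset()
    try:
        for _ in range(duration_sec * fps):
            action, _state = model.predict(obs, deterministic=not stochastic)
            obs, _reward, done, _info = env.step(action)
            writer.append_data(env.render(mode="rgb_array"))
            if done:
                obs = env.reset()
    finally:
        writer.close()
```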
@@ -114,6 +114,7 @@ def main() -> None:
     output_path.parent.mkdir(parents=True, exist_ok=True)
 
     model = PPO.load(str(model_path))
+    print(f"[video] model={model_path}")
     movement = resolve_movement(args.movement, model)
     print(f"[video] movement={movement} reward_mode={reward_mode}")
     env = make_mario_env(
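`resolve_movement` itself is not part of this diff; per the README note above, it picks the movement set whose size matches the loaded model's action dimension. A plausible sketch of that helper (hypothetical, the repository's version may differ):

```python
from gym_super_mario_bros.actions import RIGHT_ONLY, SIMPLE_MOVEMENT
from stable_baselines3 import PPO

def resolve_movement(requested: str, model: PPO) -> str:
    """Map --movement auto to the set matching the model's Discrete action count."""
    if requested != "auto":
        return requested
    n_actions = int(model.action_space.n)  # the action space is stored with the SB3 model
    if n_actions == len(RIGHT_ONLY):       # 5 actions
        return "right_only"
    if n_actions == len(SIMPLE_MOVEMENT):  # 7 actions
        return "simple"
    raise ValueError(f"Cannot infer a movement set for {n_actions} actions")
```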
@@ -45,7 +45,12 @@ def parse_args() -> argparse.Namespace:
     parser.add_argument("--clip-range", type=float, default=0.2)
 
     parser.add_argument("--save-freq", type=int, default=50_000, help="Checkpoint frequency in env steps.")
-    parser.add_argument("--device", type=str, default="auto", choices=["auto", "cpu", "mps"])
+    parser.add_argument(
+        "--device",
+        type=str,
+        default="auto",
+        help="Torch device: auto | cpu | mps | cuda | cuda:0 ...",
+    )
     parser.add_argument("--clip-reward", action="store_true", help="Enable reward clipping to {-1,0,1}.")
     parser.add_argument("--run-name", type=str, default="", help="Optional custom run name.")
     parser.add_argument(
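Dropping the `choices` list means any Torch device string (`cuda:1`, for example) is accepted by argparse and validated later by `resolve_torch_device`. A hedged sketch of how the flag presumably flows into training; the PPO construction line is illustrative rather than copied from the repository:

```python
from stable_baselines3 import PPO

args = parse_args()
device, reason = resolve_torch_device(args.device)
print(f"[train] device={device} ({reason})")

# SB3 accepts torch device strings such as "cpu", "mps", "cuda" or "cuda:0".
model = PPO("CnnPolicy", env, device=device, verbose=1)  # env built elsewhere, e.g. make_mario_env(...)
```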
@@ -45,21 +45,49 @@ def resolve_torch_device(requested: str = "auto") -> Tuple[str, str]:
     if requested == "cpu":
         return "cpu", "User requested CPU."
 
-    if requested not in {"auto", "mps"}:
-        return "cpu", f"Unknown device '{requested}', fallback to CPU."
-
-    mps_available = hasattr(torch.backends, "mps") and torch.backends.mps.is_available()
-    if not mps_available:
-        if requested == "mps":
-            return "cpu", "MPS requested but unavailable, fallback to CPU."
-        return "cpu", "MPS unavailable, using CPU."
-
-    try:
-        x = torch.ones(8, device="mps")
-        _ = (x * 2).cpu().numpy()
-        return "mps", "MPS is available and passed a quick tensor sanity check."
-    except Exception as exc:  # pragma: no cover - hardware dependent
-        return "cpu", f"MPS check failed ({exc}), fallback to CPU."
+    def _probe_cuda(device_name: str = "cuda") -> Tuple[str, str]:
+        if not torch.cuda.is_available():
+            return "cpu", "CUDA unavailable, using CPU."
+        try:
+            x = torch.ones(8, device=device_name)
+            _ = (x * 2).cpu().numpy()
+            gpu_name = torch.cuda.get_device_name(torch.device(device_name))
+            return device_name, f"CUDA is available ({gpu_name})."
+        except Exception as exc:  # pragma: no cover - hardware dependent
+            return "cpu", f"CUDA check failed ({exc}), fallback to CPU."
+
+    def _probe_mps() -> Tuple[str, str]:
+        mps_available = hasattr(torch.backends, "mps") and torch.backends.mps.is_available()
+        if not mps_available:
+            return "cpu", "MPS unavailable, using CPU."
+        try:
+            x = torch.ones(8, device="mps")
+            _ = (x * 2).cpu().numpy()
+            return "mps", "MPS is available and passed a quick tensor sanity check."
+        except Exception as exc:  # pragma: no cover - hardware dependent
+            return "cpu", f"MPS check failed ({exc}), fallback to CPU."
+
+    if requested.startswith("cuda"):
+        return _probe_cuda(requested)
+
+    if requested == "mps":
+        device, msg = _probe_mps()
+        if device == "cpu":
+            return device, "MPS requested but unavailable, fallback to CPU."
+        return device, msg
+
+    if requested == "auto":
+        device, msg = _probe_cuda("cuda")
+        if device != "cpu":
+            return device, msg
+
+        device, msg = _probe_mps()
+        if device != "cpu":
+            return device, msg
+
+        return "cpu", "Neither CUDA nor MPS is available, using CPU."
+
+    return "cpu", f"Unknown device '{requested}', fallback to CPU."
 
 
 def latest_model_path(models_dir: Path) -> Optional[Path]:
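Since every hardware probe is wrapped in `try/except` and the unknown-device branch is plain string handling, the fallback behavior can be verified without a GPU. A small, hardware-independent test one could add (hypothetical; `src.utils` is an assumed module path for the helper above):

```python
from src.utils import resolve_torch_device  # assumed location of the helper above

def test_cpu_and_unknown_devices_fall_back_to_cpu() -> None:
    # An explicit CPU request is honored verbatim.
    assert resolve_torch_device("cpu") == ("cpu", "User requested CPU.")
    # Anything outside cpu / mps / cuda* / auto falls back to CPU with an explanatory message.
    device, reason = resolve_torch_device("tpu")
    assert device == "cpu"
    assert "Unknown device" in reason
```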