feat: improve device handling and add stochastic option
This commit is contained in:
@@ -114,6 +114,7 @@ def main() -> None:
|
||||
output_path.parent.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
model = PPO.load(str(model_path))
|
||||
print(f"[video] model={model_path}")
|
||||
movement = resolve_movement(args.movement, model)
|
||||
print(f"[video] movement={movement} reward_mode={reward_mode}")
|
||||
env = make_mario_env(
|
||||
|
||||
@@ -45,7 +45,12 @@ def parse_args() -> argparse.Namespace:
|
||||
parser.add_argument("--clip-range", type=float, default=0.2)
|
||||
|
||||
parser.add_argument("--save-freq", type=int, default=50_000, help="Checkpoint frequency in env steps.")
|
||||
parser.add_argument("--device", type=str, default="auto", choices=["auto", "cpu", "mps"])
|
||||
parser.add_argument(
|
||||
"--device",
|
||||
type=str,
|
||||
default="auto",
|
||||
help="Torch device: auto | cpu | mps | cuda | cuda:0 ...",
|
||||
)
|
||||
parser.add_argument("--clip-reward", action="store_true", help="Enable reward clipping to {-1,0,1}.")
|
||||
parser.add_argument("--run-name", type=str, default="", help="Optional custom run name.")
|
||||
parser.add_argument(
|
||||
|
||||
@@ -45,21 +45,49 @@ def resolve_torch_device(requested: str = "auto") -> Tuple[str, str]:
|
||||
if requested == "cpu":
|
||||
return "cpu", "User requested CPU."
|
||||
|
||||
if requested not in {"auto", "mps"}:
|
||||
return "cpu", f"Unknown device '{requested}', fallback to CPU."
|
||||
def _probe_cuda(device_name: str = "cuda") -> Tuple[str, str]:
|
||||
if not torch.cuda.is_available():
|
||||
return "cpu", "CUDA unavailable, using CPU."
|
||||
try:
|
||||
x = torch.ones(8, device=device_name)
|
||||
_ = (x * 2).cpu().numpy()
|
||||
gpu_name = torch.cuda.get_device_name(torch.device(device_name))
|
||||
return device_name, f"CUDA is available ({gpu_name})."
|
||||
except Exception as exc: # pragma: no cover - hardware dependent
|
||||
return "cpu", f"CUDA check failed ({exc}), fallback to CPU."
|
||||
|
||||
mps_available = hasattr(torch.backends, "mps") and torch.backends.mps.is_available()
|
||||
if not mps_available:
|
||||
if requested == "mps":
|
||||
return "cpu", "MPS requested but unavailable, fallback to CPU."
|
||||
return "cpu", "MPS unavailable, using CPU."
|
||||
def _probe_mps() -> Tuple[str, str]:
|
||||
mps_available = hasattr(torch.backends, "mps") and torch.backends.mps.is_available()
|
||||
if not mps_available:
|
||||
return "cpu", "MPS unavailable, using CPU."
|
||||
try:
|
||||
x = torch.ones(8, device="mps")
|
||||
_ = (x * 2).cpu().numpy()
|
||||
return "mps", "MPS is available and passed a quick tensor sanity check."
|
||||
except Exception as exc: # pragma: no cover - hardware dependent
|
||||
return "cpu", f"MPS check failed ({exc}), fallback to CPU."
|
||||
|
||||
try:
|
||||
x = torch.ones(8, device="mps")
|
||||
_ = (x * 2).cpu().numpy()
|
||||
return "mps", "MPS is available and passed a quick tensor sanity check."
|
||||
except Exception as exc: # pragma: no cover - hardware dependent
|
||||
return "cpu", f"MPS check failed ({exc}), fallback to CPU."
|
||||
if requested.startswith("cuda"):
|
||||
return _probe_cuda(requested)
|
||||
|
||||
if requested == "mps":
|
||||
device, msg = _probe_mps()
|
||||
if device == "cpu":
|
||||
return device, "MPS requested but unavailable, fallback to CPU."
|
||||
return device, msg
|
||||
|
||||
if requested == "auto":
|
||||
device, msg = _probe_cuda("cuda")
|
||||
if device != "cpu":
|
||||
return device, msg
|
||||
|
||||
device, msg = _probe_mps()
|
||||
if device != "cpu":
|
||||
return device, msg
|
||||
|
||||
return "cpu", "Neither CUDA nor MPS is available, using CPU."
|
||||
|
||||
return "cpu", f"Unknown device '{requested}', fallback to CPU."
|
||||
|
||||
|
||||
def latest_model_path(models_dir: Path) -> Optional[Path]:
|
||||
|
||||
Reference in New Issue
Block a user