feat: initial mario rl mvp
This commit is contained in:
1
mario-rl-mvp/src/__init__.py
Normal file
1
mario-rl-mvp/src/__init__.py
Normal file
@@ -0,0 +1 @@
|
||||
"""Mario RL MVP package."""
|
||||
337
mario-rl-mvp/src/env.py
Normal file
337
mario-rl-mvp/src/env.py
Normal file
@@ -0,0 +1,337 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from collections import deque
|
||||
from typing import Any, Callable, Deque, Dict, Optional, Tuple
|
||||
|
||||
import cv2
|
||||
import gym
|
||||
import gym_super_mario_bros
|
||||
import numpy as np
|
||||
from gym.spaces import Box
|
||||
from gym_super_mario_bros.actions import RIGHT_ONLY, SIMPLE_MOVEMENT
|
||||
from nes_py.wrappers import JoypadSpace
|
||||
|
||||
|
||||
def reset_compat(env: gym.Env, seed: Optional[int] = None) -> Tuple[np.ndarray, Dict[str, Any]]:
    """Reset *env* under either the old or new Gym API.

    New-style environments accept ``reset(seed=...)`` and return
    ``(obs, info)``; legacy ones take no seed and return only ``obs``.
    Always returns an ``(obs, info)`` pair (``info`` defaults to ``{}``).
    """
    try:
        outcome = env.reset(seed=seed)
    except TypeError:
        # Legacy API: reset() takes no seed, so seed via env.seed() when possible.
        outcome = env.reset()
        if seed is not None and hasattr(env, "seed"):
            env.seed(seed)

    if isinstance(outcome, tuple) and len(outcome) == 2:
        obs, info = outcome
        return obs, info
    return outcome, {}
|
||||
|
||||
|
||||
def step_compat(env: gym.Env, action: Any) -> Tuple[np.ndarray, float, bool, bool, Dict[str, Any]]:
    """Step *env* and normalize the result to the 5-tuple Gym API.

    Accepts both the legacy 4-tuple ``(obs, reward, done, info)`` and the
    new 5-tuple ``(obs, reward, terminated, truncated, info)`` conventions.
    Raises RuntimeError for anything else.
    """
    outcome = env.step(action)
    if isinstance(outcome, tuple):
        if len(outcome) == 5:
            obs, rew, term, trunc, info = outcome
            return obs, float(rew), bool(term), bool(trunc), info
        if len(outcome) == 4:
            # Legacy API: map `done` onto `terminated`, never `truncated`.
            obs, rew, done, info = outcome
            return obs, float(rew), bool(done), False, info
    raise RuntimeError(f"Unexpected step return format: {type(outcome)} / {outcome}")
|
||||
|
||||
|
||||
class SkipFrame(gym.Wrapper):
    """Repeat each action for ``skip`` consecutive frames.

    Rewards accrued over the skipped frames are summed and only the last
    observation is returned. The loop ends early if any intermediate frame
    terminates or truncates the episode.
    """

    def __init__(self, env: gym.Env, skip: int = 4):
        super().__init__(env)
        # Guard: with skip < 1 the loop in step() would never run and the
        # wrapper would silently return a None observation.
        if skip < 1:
            raise ValueError(f"skip must be >= 1, got {skip}")
        self._skip = skip

    def step(self, action: Any):
        """Apply *action* for up to ``skip`` frames; return summed reward."""
        total_reward = 0.0
        terminated = False
        truncated = False
        info: Dict[str, Any] = {}
        obs = None
        for _ in range(self._skip):
            obs, reward, terminated, truncated, info = step_compat(self.env, action)
            total_reward += reward
            if terminated or truncated:
                break
        return obs, total_reward, terminated, truncated, info
|
||||
|
||||
|
||||
class PreprocessFrame(gym.ObservationWrapper):
    """Downscale frames to grayscale ``height`` x ``width`` uint8 images."""

    def __init__(self, env: gym.Env, width: int = 84, height: int = 84):
        super().__init__(env)
        self.width = width
        self.height = height
        self.observation_space = Box(
            low=0, high=255, shape=(height, width, 1), dtype=np.uint8
        )

    def observation(self, observation: np.ndarray) -> np.ndarray:
        """Convert a single RGB or grayscale frame to (H, W, 1) uint8."""
        if observation.ndim == 2:
            gray = observation
        elif observation.ndim == 3 and observation.shape[2] == 3:
            gray = cv2.cvtColor(observation, cv2.COLOR_RGB2GRAY)
        else:
            # Fall back to dropping singleton axes (e.g. an (H, W, 1) input).
            gray = np.squeeze(observation)
        scaled = cv2.resize(gray, (self.width, self.height), interpolation=cv2.INTER_AREA)
        return scaled[:, :, None].astype(np.uint8)
|
||||
|
||||
|
||||
class ChannelLastFrameStack(gym.Wrapper):
    """Stack the last ``num_stack`` frames on the channel axis: (H, W, C*num_stack)."""

    def __init__(self, env: gym.Env, num_stack: int = 4):
        super().__init__(env)
        self.num_stack = num_stack
        # Bounded buffer: appending past capacity drops the oldest frame.
        self.frames: Deque[np.ndarray] = deque(maxlen=num_stack)

        obs_space = env.observation_space
        assert isinstance(obs_space, Box), "Frame stack requires Box observation space."
        height, width, channels = obs_space.shape
        self.observation_space = Box(
            low=0,
            high=255,
            shape=(height, width, channels * num_stack),
            dtype=np.uint8,
        )

    def reset(self, *, seed: Optional[int] = None, options: Optional[dict] = None):
        """Reset and pre-fill the buffer with copies of the first frame."""
        del options
        first_obs, info = reset_compat(self.env, seed=seed)
        self.frames.clear()
        self.frames.extend(first_obs for _ in range(self.num_stack))
        return self._get_observation(), info

    def step(self, action: Any):
        obs, reward, terminated, truncated, info = step_compat(self.env, action)
        self.frames.append(obs)
        return self._get_observation(), reward, terminated, truncated, info

    def _get_observation(self) -> np.ndarray:
        # Buffer is always full after reset(), so the output shape is stable.
        assert len(self.frames) == self.num_stack
        return np.concatenate(list(self.frames), axis=2)
|
||||
|
||||
|
||||
class TransposeObservation(gym.ObservationWrapper):
    """Reorder observations from HWC to CHW layout for CNN policies."""

    def __init__(self, env: gym.Env):
        super().__init__(env)
        space = env.observation_space
        assert isinstance(space, Box), "TransposeObservation requires Box observation space."
        height, width, channels = space.shape
        self.observation_space = Box(
            low=0, high=255, shape=(channels, height, width), dtype=space.dtype
        )

    def observation(self, observation: np.ndarray) -> np.ndarray:
        """Move the channel axis to the front and coerce to uint8."""
        return observation.transpose(2, 0, 1).astype(np.uint8)
|
||||
|
||||
|
||||
class ClipRewardEnv(gym.RewardWrapper):
    """Clip every reward to its sign: -1.0, 0.0, or +1.0 (Atari-style)."""

    def reward(self, reward):
        sign = np.sign(reward)
        return float(sign)
|
||||
|
||||
|
||||
class ProgressRewardEnv(gym.Wrapper):
    """Reward shaping focused on moving right and avoiding local traps.

    Terms added to the base reward each step:
      * ``progress_scale * delta_x`` when moving right.
      * ``milestone_bonus`` each time x crosses a ``milestone_interval`` mark.
      * ``-backward_penalty_scale * |delta_x|`` when moving left.
      * ``-stall_penalty`` per step once ``stall_steps`` consecutive steps
        pass without rightward progress.
      * Early truncation plus ``-no_progress_terminate_penalty`` after
        ``no_progress_terminate_steps`` steps without rightward progress.
      * ``flag_bonus`` on level clear; ``death_penalty`` on other terminations.

    NOTE(review): x positions are read from ``info["x_pos"]``; a missing key
    is treated as x = 0 — confirm gym-super-mario-bros always provides it.
    """

    def __init__(
        self,
        env: gym.Env,
        progress_scale: float = 0.02,
        death_penalty: float = -50.0,
        flag_bonus: float = 100.0,
        stall_penalty: float = 0.05,
        stall_steps: int = 40,
        backward_penalty_scale: float = 0.01,
        milestone_interval: int = 32,
        milestone_bonus: float = 1.0,
        no_progress_terminate_steps: int = 120,
        no_progress_terminate_penalty: float = 20.0,
    ):
        super().__init__(env)
        self.progress_scale = progress_scale
        self.death_penalty = death_penalty
        self.flag_bonus = flag_bonus
        self.stall_penalty = stall_penalty
        self.stall_steps = stall_steps
        self.backward_penalty_scale = backward_penalty_scale
        self.milestone_interval = milestone_interval
        self.milestone_bonus = milestone_bonus
        self.no_progress_terminate_steps = no_progress_terminate_steps
        self.no_progress_terminate_penalty = no_progress_terminate_penalty
        # Episode-local tracking state; re-anchored in reset().
        self._last_x_pos: Optional[float] = None
        self._best_x_pos = 0.0
        self._stall_count = 0
        self._next_milestone_x = float(milestone_interval)

    def reset(self, *, seed: Optional[int] = None, options: Optional[dict] = None):
        """Reset the env and re-anchor progress tracking to the spawn x."""
        del options
        obs, info = reset_compat(self.env, seed=seed)
        self._last_x_pos = float(info.get("x_pos", 0.0))
        self._best_x_pos = self._last_x_pos
        self._stall_count = 0
        if self.milestone_interval > 0:
            # Next milestone is the first interval boundary past the spawn x.
            k = int(self._best_x_pos // self.milestone_interval) + 1
            self._next_milestone_x = float(k * self.milestone_interval)
        return obs, info

    def step(self, action: Any):
        """Step the env and return the shaped reward (see class docstring)."""
        obs, reward, terminated, truncated, info = step_compat(self.env, action)
        x_pos = float(info.get("x_pos", 0.0))

        # No anchor yet (e.g. a legacy reset that returned no info dict).
        if self._last_x_pos is None:
            delta_x = 0.0
        else:
            delta_x = x_pos - self._last_x_pos
        self._last_x_pos = x_pos

        shaped_reward = float(reward)
        if delta_x > 0:
            shaped_reward += self.progress_scale * delta_x
            self._stall_count = 0
            if x_pos > self._best_x_pos:
                self._best_x_pos = x_pos
            if self.milestone_interval > 0 and self.milestone_bonus != 0.0:
                # A large jump can cross several milestones in one step.
                while x_pos >= self._next_milestone_x:
                    shaped_reward += self.milestone_bonus
                    self._next_milestone_x += self.milestone_interval
        else:
            self._stall_count += 1
            if delta_x < 0:
                shaped_reward -= self.backward_penalty_scale * abs(delta_x)

        if self._stall_count >= self.stall_steps:
            shaped_reward -= self.stall_penalty

        if (
            self.no_progress_terminate_steps > 0
            and not terminated
            and not truncated
            and self._stall_count >= self.no_progress_terminate_steps
        ):
            # Cut episodes that are stuck so rollouts are not wasted.
            truncated = True
            shaped_reward -= self.no_progress_terminate_penalty
            info["terminated_by_stall"] = True

        if terminated or truncated:
            if bool(info.get("flag_get", False)):
                shaped_reward += self.flag_bonus
            elif terminated:
                # Any non-flag termination is treated as a death.
                shaped_reward += self.death_penalty

        return obs, shaped_reward, terminated, truncated, info
|
||||
|
||||
|
||||
def get_action_set(name: str):
    """Map a movement name ('right_only' or 'simple') to its action list.

    Matching is case-insensitive and ignores surrounding whitespace.
    Raises ValueError for unknown names.
    """
    key = name.lower().strip()
    if key == "right_only":
        return RIGHT_ONLY
    if key == "simple":
        return SIMPLE_MOVEMENT
    raise ValueError(f"Unsupported movement='{key}'. Use one of: right_only, simple")
|
||||
|
||||
|
||||
def make_mario_env(
    env_id: str = "SuperMarioBros-1-1-v0",
    seed: int = 0,
    movement: str = "right_only",
    reward_mode: str = "raw",
    clip_reward: bool = False,
    frame_skip: int = 4,
    render_mode: Optional[str] = None,
    progress_scale: float = 0.02,
    death_penalty: float = -50.0,
    flag_bonus: float = 100.0,
    stall_penalty: float = 0.05,
    stall_steps: int = 40,
    backward_penalty_scale: float = 0.01,
    milestone_interval: int = 32,
    milestone_bonus: float = 1.0,
    no_progress_terminate_steps: int = 120,
    no_progress_terminate_penalty: float = 20.0,
) -> gym.Env:
    """Build a fully wrapped Mario env: joypad -> frame skip -> reward -> 84x84 CHW stack.

    ``reward_mode`` is one of "raw", "clip", "progress"; ``clip_reward`` is a
    legacy switch that upgrades "raw" to "clip". The trailing keyword
    arguments configure ProgressRewardEnv and are ignored unless
    ``reward_mode == "progress"``. Raises ValueError on unknown modes.
    """
    kwargs: Dict[str, Any] = {}
    if render_mode is not None:
        kwargs["render_mode"] = render_mode

    # Newer gym-super-mario-bros accepts apply_api_compatibility; older
    # releases raise TypeError, so fall back to the plain constructor.
    try:
        env = gym_super_mario_bros.make(env_id, apply_api_compatibility=True, **kwargs)
    except TypeError:
        env = gym_super_mario_bros.make(env_id, **kwargs)

    env = JoypadSpace(env, get_action_set(movement))
    env = SkipFrame(env, skip=frame_skip)

    mode = reward_mode.lower().strip()
    if clip_reward and mode == "raw":
        mode = "clip"
    if mode == "clip":
        env = ClipRewardEnv(env)
    elif mode == "progress":
        env = ProgressRewardEnv(
            env=env,
            progress_scale=progress_scale,
            death_penalty=death_penalty,
            flag_bonus=flag_bonus,
            stall_penalty=stall_penalty,
            stall_steps=stall_steps,
            backward_penalty_scale=backward_penalty_scale,
            milestone_interval=milestone_interval,
            milestone_bonus=milestone_bonus,
            no_progress_terminate_steps=no_progress_terminate_steps,
            no_progress_terminate_penalty=no_progress_terminate_penalty,
        )
    elif mode != "raw":
        raise ValueError(f"Unsupported reward_mode='{reward_mode}'. Use one of: raw, clip, progress")

    # Observation pipeline: grayscale 84x84 -> 4-frame stack -> CHW for a CNN policy.
    env = PreprocessFrame(env, width=84, height=84)
    env = ChannelLastFrameStack(env, num_stack=4)
    env = TransposeObservation(env)

    # Seed once so each env subprocess/dummy env has deterministic startup.
    reset_compat(env, seed=seed)
    if hasattr(env.action_space, "seed"):
        env.action_space.seed(seed)
    return env
|
||||
|
||||
|
||||
def make_env_fn(
    env_id: str,
    seed: int,
    movement: str,
    reward_mode: str,
    clip_reward: bool,
    frame_skip: int,
    progress_scale: float,
    death_penalty: float,
    flag_bonus: float,
    stall_penalty: float,
    stall_steps: int,
    backward_penalty_scale: float,
    milestone_interval: int,
    milestone_bonus: float,
    no_progress_terminate_steps: int,
    no_progress_terminate_penalty: float,
) -> Callable[[], gym.Env]:
    """Return a zero-argument env factory suitable for SB3 vectorized envs.

    All configuration is captured by closure; calling the returned factory
    builds one fresh env via make_mario_env.
    """

    def _build() -> gym.Env:
        # render_mode stays None: training workers run headless.
        return make_mario_env(
            env_id=env_id,
            seed=seed,
            movement=movement,
            reward_mode=reward_mode,
            clip_reward=clip_reward,
            frame_skip=frame_skip,
            render_mode=None,
            progress_scale=progress_scale,
            death_penalty=death_penalty,
            flag_bonus=flag_bonus,
            stall_penalty=stall_penalty,
            stall_steps=stall_steps,
            backward_penalty_scale=backward_penalty_scale,
            milestone_interval=milestone_interval,
            milestone_bonus=milestone_bonus,
            no_progress_terminate_steps=no_progress_terminate_steps,
            no_progress_terminate_penalty=no_progress_terminate_penalty,
        )

    return _build
|
||||
162
mario-rl-mvp/src/eval.py
Normal file
162
mario-rl-mvp/src/eval.py
Normal file
@@ -0,0 +1,162 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import argparse
|
||||
import json
|
||||
from pathlib import Path
|
||||
from statistics import mean
|
||||
|
||||
import numpy as np
|
||||
from stable_baselines3 import PPO
|
||||
|
||||
from src.env import get_action_set, make_mario_env, reset_compat, step_compat
|
||||
from src.utils import ensure_artifact_paths, latest_model_path, seed_everything
|
||||
|
||||
|
||||
def parse_args() -> argparse.Namespace:
    """Parse CLI flags for evaluation.

    Reward-shaping flags must mirror the training configuration so the
    reported per-episode rewards are comparable to training rewards.
    """
    parser = argparse.ArgumentParser(description="Evaluate trained Mario PPO agent.")
    parser.add_argument("--model-path", type=str, default="", help="Path to .zip model. If empty, use latest model.")
    parser.add_argument("--env-id", type=str, default="SuperMarioBros-1-1-v0")
    # "auto" infers the movement preset from the checkpoint's action count.
    parser.add_argument("--movement", type=str, default="auto", choices=["auto", "right_only", "simple"])
    parser.add_argument("--episodes", type=int, default=5)
    parser.add_argument("--max-steps", type=int, default=5_000)
    parser.add_argument("--seed", type=int, default=42)
    parser.add_argument("--frame-skip", type=int, default=4)
    # Reward shaping knobs (used only when --reward-mode progress).
    parser.add_argument("--reward-mode", type=str, default="raw", choices=["raw", "clip", "progress"])
    parser.add_argument("--progress-scale", type=float, default=0.02)
    parser.add_argument("--death-penalty", type=float, default=-50.0)
    parser.add_argument("--flag-bonus", type=float, default=100.0)
    parser.add_argument("--stall-penalty", type=float, default=0.05)
    parser.add_argument("--stall-steps", type=int, default=40)
    parser.add_argument("--backward-penalty-scale", type=float, default=0.01)
    parser.add_argument("--milestone-interval", type=int, default=32)
    parser.add_argument("--milestone-bonus", type=float, default=1.0)
    parser.add_argument("--no-progress-terminate-steps", type=int, default=120)
    parser.add_argument("--no-progress-terminate-penalty", type=float, default=20.0)
    parser.add_argument("--clip-reward", action="store_true")
    parser.add_argument("--stochastic", action="store_true", help="Use stochastic policy (deterministic=False).")
    return parser.parse_args()
|
||||
|
||||
|
||||
def _action_count_for_movement(movement: str) -> int:
    """Number of discrete actions exposed by the given movement preset."""
    actions = get_action_set(movement)
    return len(actions)
|
||||
|
||||
|
||||
def _model_action_count(model: PPO):
|
||||
if hasattr(model, "action_space") and hasattr(model.action_space, "n"):
|
||||
return int(model.action_space.n)
|
||||
action_net = getattr(getattr(model, "policy", None), "action_net", None)
|
||||
out_features = getattr(action_net, "out_features", None)
|
||||
if out_features is not None:
|
||||
return int(out_features)
|
||||
return None
|
||||
|
||||
|
||||
def resolve_movement(movement_arg: str, model: PPO) -> str:
    """Pick a movement preset consistent with the model's action count.

    'auto' infers the preset from the checkpoint; an explicit preset is
    validated against the model and rejected (ValueError) on mismatch.
    """
    model_action_n = _model_action_count(model)

    if movement_arg == "auto":
        if model_action_n is not None:
            for candidate in ("right_only", "simple"):
                if _action_count_for_movement(candidate) == model_action_n:
                    return candidate
        # Unknown action count or no preset matched: safest default.
        return "right_only"

    if model_action_n is not None:
        expected_n = _action_count_for_movement(movement_arg)
        if expected_n != model_action_n:
            raise ValueError(
                f"movement='{movement_arg}' has {expected_n} actions, but model expects {model_action_n}. "
                "Use --movement auto or pass the matching movement."
            )
    return movement_arg
|
||||
|
||||
|
||||
def resolve_model_path(user_path: str) -> Path:
    """Resolve an explicit model path, or fall back to the newest artifact.

    Raises FileNotFoundError when the explicit path does not exist or no
    trained model artifact is available.
    """
    if not user_path:
        paths = ensure_artifact_paths()
        latest = latest_model_path(paths.models)
        if latest is None:
            raise FileNotFoundError("No model found under artifacts/models. Please run training first.")
        return latest

    candidate = Path(user_path).expanduser().resolve()
    if not candidate.exists():
        raise FileNotFoundError(f"Model not found: {candidate}")
    return candidate
|
||||
|
||||
|
||||
def main() -> None:
    """Evaluate a trained PPO checkpoint and print per-episode + summary stats."""
    args = parse_args()
    seed_everything(args.seed)
    # --clip-reward is a legacy switch; it only upgrades the default "raw" mode.
    reward_mode = "clip" if args.clip_reward and args.reward_mode == "raw" else args.reward_mode

    model_path = resolve_model_path(args.model_path)
    print(f"[eval] model={model_path}")
    model = PPO.load(str(model_path))
    # The movement preset must match the checkpoint's action-space size.
    movement = resolve_movement(args.movement, model)
    print(f"[eval] movement={movement} reward_mode={reward_mode}")

    env = make_mario_env(
        env_id=args.env_id,
        seed=args.seed,
        movement=movement,
        reward_mode=reward_mode,
        clip_reward=args.clip_reward,
        frame_skip=args.frame_skip,
        render_mode=None,
        progress_scale=args.progress_scale,
        death_penalty=args.death_penalty,
        flag_bonus=args.flag_bonus,
        stall_penalty=args.stall_penalty,
        stall_steps=args.stall_steps,
        backward_penalty_scale=args.backward_penalty_scale,
        milestone_interval=args.milestone_interval,
        milestone_bonus=args.milestone_bonus,
        no_progress_terminate_steps=args.no_progress_terminate_steps,
        no_progress_terminate_penalty=args.no_progress_terminate_penalty,
    )

    # Per-episode metric accumulators.
    rewards = []
    max_x_positions = []
    clear_flags = []

    for ep in range(1, args.episodes + 1):
        # Vary the seed per episode: runs stay reproducible but not identical.
        obs, info = reset_compat(env, seed=args.seed + ep)
        done = False
        ep_reward = 0.0
        ep_max_x = float(info.get("x_pos", 0.0))
        flag_get = False
        step_count = 0

        while not done and step_count < args.max_steps:
            action, _ = model.predict(obs, deterministic=not args.stochastic)
            if isinstance(action, np.ndarray):
                # SB3 returns a 0-d array; the env expects a plain int.
                action = int(action.item())
            obs, reward, terminated, truncated, info = step_compat(env, action)
            ep_reward += float(reward)
            ep_max_x = max(ep_max_x, float(info.get("x_pos", 0.0)))
            flag_get = flag_get or bool(info.get("flag_get", False))
            done = terminated or truncated
            step_count += 1

        rewards.append(ep_reward)
        max_x_positions.append(ep_max_x)
        clear_flags.append(1.0 if flag_get else 0.0)
        print(
            f"[episode {ep}] reward={ep_reward:.2f} max_x={ep_max_x:.1f} "
            f"clear={flag_get} steps={step_count}"
        )

    summary = {
        "episodes": args.episodes,
        "avg_reward": mean(rewards) if rewards else 0.0,
        "avg_max_x_pos": mean(max_x_positions) if max_x_positions else 0.0,
        "clear_rate": mean(clear_flags) if clear_flags else 0.0,
    }
    print("[summary]", json.dumps(summary, ensure_ascii=False, indent=2))

    env.close()


if __name__ == "__main__":
    main()
|
||||
200
mario-rl-mvp/src/record_video.py
Normal file
200
mario-rl-mvp/src/record_video.py
Normal file
@@ -0,0 +1,200 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import argparse
|
||||
from datetime import datetime
|
||||
from pathlib import Path
|
||||
|
||||
import imageio.v2 as imageio
|
||||
import numpy as np
|
||||
from stable_baselines3 import PPO
|
||||
|
||||
from src.env import get_action_set, make_mario_env, reset_compat, step_compat
|
||||
from src.utils import ensure_artifact_paths, latest_model_path, seed_everything
|
||||
|
||||
|
||||
def parse_args() -> argparse.Namespace:
    """Parse CLI flags for headless video recording.

    The env/reward flags should mirror training so the replay reflects the
    policy's training-time dynamics; fps/duration control the output clip.
    """
    parser = argparse.ArgumentParser(description="Record Mario agent rollout to mp4 (headless).")
    parser.add_argument("--model-path", type=str, default="", help="Path to .zip model. Empty means latest model.")
    parser.add_argument("--env-id", type=str, default="SuperMarioBros-1-1-v0")
    # "auto" infers the movement preset from the checkpoint's action count.
    parser.add_argument("--movement", type=str, default="auto", choices=["auto", "right_only", "simple"])
    parser.add_argument("--seed", type=int, default=42)
    parser.add_argument("--frame-skip", type=int, default=4)
    # Reward shaping knobs (used only when --reward-mode progress).
    parser.add_argument("--reward-mode", type=str, default="raw", choices=["raw", "clip", "progress"])
    parser.add_argument("--progress-scale", type=float, default=0.02)
    parser.add_argument("--death-penalty", type=float, default=-50.0)
    parser.add_argument("--flag-bonus", type=float, default=100.0)
    parser.add_argument("--stall-penalty", type=float, default=0.05)
    parser.add_argument("--stall-steps", type=int, default=40)
    parser.add_argument("--backward-penalty-scale", type=float, default=0.01)
    parser.add_argument("--milestone-interval", type=int, default=32)
    parser.add_argument("--milestone-bonus", type=float, default=1.0)
    parser.add_argument("--no-progress-terminate-steps", type=int, default=120)
    parser.add_argument("--no-progress-terminate-penalty", type=float, default=20.0)
    parser.add_argument("--clip-reward", action="store_true")
    # Output clip parameters.
    parser.add_argument("--fps", type=int, default=30)
    parser.add_argument("--duration-sec", type=int, default=10)
    parser.add_argument("--max-steps", type=int, default=5_000)
    parser.add_argument("--stochastic", action="store_true")
    parser.add_argument("--output", type=str, default="", help="Output mp4 path.")
    return parser.parse_args()
|
||||
|
||||
|
||||
def _action_count_for_movement(movement: str) -> int:
    """Size of the discrete action set for *movement*."""
    action_set = get_action_set(movement)
    return len(action_set)
|
||||
|
||||
|
||||
def _model_action_count(model: PPO):
|
||||
if hasattr(model, "action_space") and hasattr(model.action_space, "n"):
|
||||
return int(model.action_space.n)
|
||||
action_net = getattr(getattr(model, "policy", None), "action_net", None)
|
||||
out_features = getattr(action_net, "out_features", None)
|
||||
if out_features is not None:
|
||||
return int(out_features)
|
||||
return None
|
||||
|
||||
|
||||
def resolve_movement(movement_arg: str, model: PPO) -> str:
    """Choose or validate the movement preset against the model's action head.

    'auto' scans known presets for a matching action count; an explicit
    preset raises ValueError when it disagrees with the checkpoint.
    """
    model_action_n = _model_action_count(model)

    if movement_arg == "auto":
        if model_action_n is None:
            return "right_only"
        match = next(
            (c for c in ("right_only", "simple") if _action_count_for_movement(c) == model_action_n),
            None,
        )
        return match if match is not None else "right_only"

    if model_action_n is not None:
        expected_n = _action_count_for_movement(movement_arg)
        if expected_n != model_action_n:
            raise ValueError(
                f"movement='{movement_arg}' has {expected_n} actions, but model expects {model_action_n}. "
                "Use --movement auto or pass the matching movement."
            )
    return movement_arg
|
||||
|
||||
|
||||
def resolve_model(user_path: str) -> Path:
    """Resolve the checkpoint to replay: explicit path or the newest artifact.

    Raises FileNotFoundError when the path does not exist or no trained
    model artifact is available.
    """
    if not user_path:
        paths = ensure_artifact_paths()
        latest = latest_model_path(paths.models)
        if latest is None:
            raise FileNotFoundError("No model found under artifacts/models. Please run training first.")
        return latest

    resolved = Path(user_path).expanduser().resolve()
    if not resolved.exists():
        raise FileNotFoundError(f"Model not found: {resolved}")
    return resolved
|
||||
|
||||
|
||||
def resolve_output_path(user_output: str) -> Path:
    """Expand an explicit output path, or generate a timestamped default."""
    if user_output:
        return Path(user_output).expanduser().resolve()

    paths = ensure_artifact_paths()
    stamp = datetime.now().strftime("%Y%m%d_%H%M%S")
    default_name = f"mario_replay_{stamp}.mp4"
    return (paths.videos / default_name).resolve()
|
||||
|
||||
|
||||
def save_frames_fallback(frames, output_path: Path) -> Path:
    """Dump frames as numbered PNGs into a directory named after *output_path*.

    The directory is the output path with its extension stripped; it is
    created if missing. Returns the directory path.
    """
    frame_dir = output_path.with_suffix("")
    frame_dir.mkdir(parents=True, exist_ok=True)
    for index, frame in enumerate(frames):
        target = frame_dir / f"frame_{index:06d}.png"
        imageio.imwrite(target, frame)
    return frame_dir
|
||||
|
||||
|
||||
def main() -> None:
    """Roll out the trained agent headlessly and write the frames to mp4.

    Records until ``fps * duration_sec`` frames are captured or
    ``max_steps`` env steps elapse, restarting episodes as needed. Falls
    back to a PNG frame sequence if the mp4 writer is unavailable.
    """
    args = parse_args()
    seed_everything(args.seed)
    # --clip-reward is a legacy switch; it only upgrades the default "raw" mode.
    reward_mode = "clip" if args.clip_reward and args.reward_mode == "raw" else args.reward_mode

    model_path = resolve_model(args.model_path)
    output_path = resolve_output_path(args.output)
    output_path.parent.mkdir(parents=True, exist_ok=True)

    model = PPO.load(str(model_path))
    # The movement preset must match the checkpoint's action-space size.
    movement = resolve_movement(args.movement, model)
    print(f"[video] movement={movement} reward_mode={reward_mode}")
    env = make_mario_env(
        env_id=args.env_id,
        seed=args.seed,
        movement=movement,
        reward_mode=reward_mode,
        clip_reward=args.clip_reward,
        frame_skip=args.frame_skip,
        render_mode="rgb_array",
        progress_scale=args.progress_scale,
        death_penalty=args.death_penalty,
        flag_bonus=args.flag_bonus,
        stall_penalty=args.stall_penalty,
        stall_steps=args.stall_steps,
        backward_penalty_scale=args.backward_penalty_scale,
        milestone_interval=args.milestone_interval,
        milestone_bonus=args.milestone_bonus,
        no_progress_terminate_steps=args.no_progress_terminate_steps,
        no_progress_terminate_penalty=args.no_progress_terminate_penalty,
    )

    obs, _ = reset_compat(env, seed=args.seed)
    frames = []

    first_frame = env.render()
    if first_frame is not None:
        # nes-py may reuse the same frame buffer; copy to avoid aliasing all frames.
        frames.append(first_frame.copy())

    target_frames = max(1, args.fps * args.duration_sec)
    done = False
    step_count = 0
    reward_sum = 0.0
    episode_count = 1

    while len(frames) < target_frames and step_count < args.max_steps:
        action, _ = model.predict(obs, deterministic=not args.stochastic)
        if isinstance(action, np.ndarray):
            # SB3 returns a 0-d array; the env expects a plain int.
            action = int(action.item())
        obs, reward, terminated, truncated, info = step_compat(env, action)
        reward_sum += float(reward)
        step_count += 1
        done = terminated or truncated

        frame = env.render()
        if frame is not None:
            frames.append(frame.copy())

        if done:
            # Start a new episode with a fresh seed and capture its first frame.
            obs, _ = reset_compat(env, seed=args.seed + episode_count)
            episode_count += 1
            frame = env.render()
            if frame is not None and len(frames) < target_frames:
                frames.append(frame.copy())
            done = False

    if not frames:
        raise RuntimeError("No frames captured. Check render_mode support and environment setup.")

    try:
        writer = imageio.get_writer(str(output_path), fps=args.fps, codec="libx264", quality=8)
        for frame in frames:
            writer.append_data(frame)
        writer.close()
        print(f"[video] Saved mp4: {output_path}")
    except Exception as exc:
        # No libx264/ffmpeg available: keep the raw frames so nothing is lost.
        frame_dir = save_frames_fallback(frames, output_path)
        print(f"[warn] mp4 write failed: {exc}")
        print(f"[fallback] Saved frame sequence: {frame_dir}")
        print(
            "[fallback] Convert frames to mp4 with: "
            f"ffmpeg -framerate {args.fps} -i {frame_dir}/frame_%06d.png -c:v libx264 -pix_fmt yuv420p {output_path}"
        )

    print(
        f"[stats] frames={len(frames)} approx_sec={len(frames)/max(args.fps,1):.2f} "
        f"steps={step_count} reward_sum={reward_sum:.2f} episodes={episode_count}"
    )
    env.close()


if __name__ == "__main__":
    main()
|
||||
300
mario-rl-mvp/src/train_ppo.py
Normal file
300
mario-rl-mvp/src/train_ppo.py
Normal file
@@ -0,0 +1,300 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import argparse
|
||||
import shutil
|
||||
from datetime import datetime
|
||||
from pathlib import Path
|
||||
|
||||
from stable_baselines3 import PPO
|
||||
from stable_baselines3.common.callbacks import BaseCallback, CallbackList, CheckpointCallback
|
||||
from stable_baselines3.common.save_util import load_from_zip_file
|
||||
from stable_baselines3.common.vec_env import DummyVecEnv, VecMonitor
|
||||
from torch.utils.tensorboard import SummaryWriter
|
||||
|
||||
from src.env import make_env_fn
|
||||
from src.utils import ensure_artifact_paths, resolve_torch_device, seed_everything, write_latest_pointer
|
||||
|
||||
|
||||
def parse_args() -> argparse.Namespace:
    """Parse CLI flags for PPO training on Super Mario Bros."""
    parser = argparse.ArgumentParser(description="Train PPO agent on NES Super Mario Bros.")
    # Environment and reward-shaping configuration.
    parser.add_argument("--env-id", type=str, default="SuperMarioBros-1-1-v0")
    parser.add_argument("--movement", type=str, default="right_only", choices=["right_only", "simple"])
    parser.add_argument("--total-timesteps", type=int, default=1_000_000)
    parser.add_argument("--seed", type=int, default=42)
    parser.add_argument("--n-envs", type=int, default=4)
    parser.add_argument("--frame-skip", type=int, default=4)
    parser.add_argument("--reward-mode", type=str, default="raw", choices=["raw", "clip", "progress"])
    parser.add_argument("--progress-scale", type=float, default=0.02)
    parser.add_argument("--death-penalty", type=float, default=-50.0)
    parser.add_argument("--flag-bonus", type=float, default=100.0)
    parser.add_argument("--stall-penalty", type=float, default=0.05)
    parser.add_argument("--stall-steps", type=int, default=40)
    parser.add_argument("--backward-penalty-scale", type=float, default=0.01)
    parser.add_argument("--milestone-interval", type=int, default=32)
    parser.add_argument("--milestone-bonus", type=float, default=1.0)
    parser.add_argument("--no-progress-terminate-steps", type=int, default=120)
    parser.add_argument("--no-progress-terminate-penalty", type=float, default=20.0)

    # PPO hyperparameters.
    parser.add_argument("--learning-rate", type=float, default=2.5e-4)
    parser.add_argument("--n-steps", type=int, default=128)
    parser.add_argument("--batch-size", type=int, default=256)
    parser.add_argument("--n-epochs", type=int, default=4)
    parser.add_argument("--gamma", type=float, default=0.99)
    parser.add_argument("--gae-lambda", type=float, default=0.95)
    parser.add_argument("--ent-coef", type=float, default=0.01)
    parser.add_argument("--clip-range", type=float, default=0.2)

    # Run management, checkpointing, and warm-start options.
    parser.add_argument("--save-freq", type=int, default=50_000, help="Checkpoint frequency in env steps.")
    parser.add_argument("--device", type=str, default="auto", choices=["auto", "cpu", "mps"])
    parser.add_argument("--clip-reward", action="store_true", help="Enable reward clipping to {-1,0,1}.")
    parser.add_argument("--run-name", type=str, default="", help="Optional custom run name.")
    parser.add_argument(
        "--init-model-path",
        type=str,
        default="",
        help="Optional .zip model path to initialize weights before training.",
    )
    parser.add_argument(
        "--allow-partial-init",
        action="store_true",
        help=(
            "Allow partial init when checkpoint is from a different action space "
            "(e.g. right_only -> simple). Compatible policy layers will be loaded, "
            "incompatible layers (like action head) are skipped."
        ),
    )
    parser.add_argument("--progress-bar", action="store_true", help="Enable SB3 progress bar (requires tqdm/rich).")
    return parser.parse_args()
|
||||
|
||||
|
||||
def load_partial_policy_weights(model: PPO, init_model_path: Path, device: str) -> tuple[int, int]:
    """Copy shape-compatible policy tensors from a checkpoint into *model*.

    Tensors whose name is absent from the current policy or whose shape differs
    (e.g. the action head after an action-space change) are skipped.

    Args:
        model: PPO model whose policy receives the weights.
        init_model_path: path to an SB3 ``.zip`` checkpoint.
        device: torch device string used when deserializing the checkpoint.

    Returns:
        ``(loaded_tensor_count, skipped_tensor_count)``.

    Raises:
        RuntimeError: if the checkpoint holds no policy parameters, or none of
            them are compatible with the current policy.
    """
    _, params, _ = load_from_zip_file(str(init_model_path), device=device)
    if params is None or "policy" not in params:
        raise RuntimeError(f"No policy parameters found in checkpoint: {init_model_path}")

    src_state = params["policy"]
    dst_state = model.policy.state_dict()

    # Keep only tensors that exist in the destination under the same name with
    # an identical shape; everything else counts as skipped.
    compatible_state = {
        name: tensor
        for name, tensor in src_state.items()
        if name in dst_state and dst_state[name].shape == tensor.shape
    }
    skipped_count = len(src_state) - len(compatible_state)

    if not compatible_state:
        raise RuntimeError("No compatible policy tensors found for partial init.")

    # strict=False: destination tensors missing from the checkpoint keep their
    # freshly initialized values.
    model.policy.load_state_dict(compatible_state, strict=False)
    return len(compatible_state), skipped_count
|
||||
|
||||
|
||||
class EpisodeEndLoggingCallback(BaseCallback):
    """Log per-episode terminal diagnostics to stdout and TensorBoard.

    Tracks, per vectorized sub-env, the maximum ``x_pos`` reached during the
    current episode. When an episode ends, classifies the terminal reason
    (death / no_progress / timeout / clear), records rolling means through the
    SB3 logger, and writes per-episode scalars to a dedicated SummaryWriter.
    """

    # Stable numeric encoding of the terminal reason, for a single TB scalar.
    REASON_TO_CODE = {"death": 0, "no_progress": 1, "timeout": 2, "clear": 3}

    def __init__(self, n_envs: int, tb_log_dir: Path):
        """
        Args:
            n_envs: number of parallel sub-envs in the vectorized env.
            tb_log_dir: directory for this callback's own SummaryWriter.
        """
        super().__init__(verbose=0)
        self.n_envs = n_envs
        self.tb_log_dir = tb_log_dir
        self.tb_writer: SummaryWriter | None = None
        self.episode_count = 0
        # Running max of x_pos within the current episode, one slot per sub-env.
        # max(n_envs, 1) keeps at least one slot even if n_envs == 0.
        self.episode_max_x_pos = [0.0 for _ in range(max(n_envs, 1))]

    @staticmethod
    def _resolve_done_reason(info: dict) -> str:
        """Classify why an episode ended, checking the most specific signals first."""
        if bool(info.get("flag_get", False)):
            return "clear"
        if bool(info.get("terminated_by_stall", False)):
            return "no_progress"
        if bool(info.get("TimeLimit.truncated", False)):
            return "timeout"
        # The in-game clock reaching zero also counts as a timeout.
        try:
            if float(info.get("time", 1.0)) <= 0.0:
                return "timeout"
        except (TypeError, ValueError):
            pass
        return "death"

    def _on_training_start(self) -> None:
        # Lazily create the writer so the directory exists only for real runs.
        self.tb_log_dir.mkdir(parents=True, exist_ok=True)
        self.tb_writer = SummaryWriter(log_dir=str(self.tb_log_dir))

    def _on_step(self) -> bool:
        """Update per-env progress; on episode end, emit logs. Never stops training."""
        infos = self.locals.get("infos")
        dones = self.locals.get("dones")
        if infos is None or dones is None:
            return True

        for env_idx, (info, done) in enumerate(zip(infos, dones)):
            if env_idx >= len(self.episode_max_x_pos):
                continue
            x_pos = float(info.get("x_pos", 0.0))
            if x_pos > self.episode_max_x_pos[env_idx]:
                self.episode_max_x_pos[env_idx] = x_pos

            if not bool(done):
                continue

            self.episode_count += 1
            max_x_pos = self.episode_max_x_pos[env_idx]
            flag_get = 1.0 if bool(info.get("flag_get", False)) else 0.0
            done_reason = self._resolve_done_reason(info)
            done_reason_code = float(self.REASON_TO_CODE[done_reason])
            episode_step = self.episode_count

            self.logger.record_mean("rollout/episode_max_x_pos", max_x_pos)
            self.logger.record_mean("rollout/flag_get", flag_get)
            # Record a 0/1 indicator for EVERY reason so each rollout/done_reason_*
            # mean is the fraction of episodes ending that way. (Recording only the
            # occurring reason with 1.0, as before, makes every such mean read
            # exactly 1.0 — an uninformative metric.)
            for reason in self.REASON_TO_CODE:
                self.logger.record_mean(
                    f"rollout/done_reason_{reason}", 1.0 if reason == done_reason else 0.0
                )

            if self.tb_writer is not None:
                self.tb_writer.add_scalar("episode_end/episode_max_x_pos", max_x_pos, episode_step)
                self.tb_writer.add_scalar("episode_end/flag_get", flag_get, episode_step)
                self.tb_writer.add_scalar("episode_end/done_reason_code", done_reason_code, episode_step)
                self.tb_writer.add_scalar("episode_end/done_reason_death", 1.0 if done_reason == "death" else 0.0, episode_step)
                self.tb_writer.add_scalar(
                    "episode_end/done_reason_no_progress", 1.0 if done_reason == "no_progress" else 0.0, episode_step
                )
                self.tb_writer.add_scalar("episode_end/done_reason_timeout", 1.0 if done_reason == "timeout" else 0.0, episode_step)
                self.tb_writer.add_scalar("episode_end/done_reason_clear", 1.0 if done_reason == "clear" else 0.0, episode_step)
                self.tb_writer.flush()

            print(
                f"[episode_end] ep={episode_step} env={env_idx} reason={done_reason} "
                f"max_x={max_x_pos:.1f} flag_get={bool(flag_get)}"
            )
            # Reset the tracker for this sub-env's next episode.
            self.episode_max_x_pos[env_idx] = 0.0

        return True

    def _on_training_end(self) -> None:
        # Flush and release the writer; None-guard makes a double call harmless.
        if self.tb_writer is not None:
            self.tb_writer.close()
            self.tb_writer = None
|
||||
|
||||
|
||||
def main() -> None:
    """Train a PPO agent on Super Mario Bros.

    Flow: parse CLI args -> seed RNGs -> build run directories -> construct the
    vectorized env -> build (and optionally warm-start) the PPO model -> train
    with checkpoint + episode-end-logging callbacks -> save the final model and
    refresh the "latest model" pointer under the shared models directory.
    """
    args = parse_args()
    seed_everything(args.seed)
    # --clip-reward is a shorthand: it only overrides an untouched "raw" reward mode.
    reward_mode = "clip" if args.clip_reward and args.reward_mode == "raw" else args.reward_mode

    paths = ensure_artifact_paths()
    timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
    # Env ids may contain "/"; replace it to keep run names filesystem-safe.
    env_slug = args.env_id.replace("/", "_")
    run_name = args.run_name.strip() or f"ppo_{env_slug}_{timestamp}"

    run_model_dir = paths.models / run_name
    run_log_dir = paths.logs / run_name
    run_tb_dir = run_log_dir / "tb"
    run_model_dir.mkdir(parents=True, exist_ok=True)
    # Creating the tb dir also creates run_log_dir (parents=True).
    run_tb_dir.mkdir(parents=True, exist_ok=True)

    device, device_msg = resolve_torch_device(args.device)
    print(f"[device] {device} | {device_msg}")
    print(f"[run] {run_name}")
    print(
        f"[config] env_id={args.env_id}, movement={args.movement}, reward_mode={reward_mode}, "
        f"clip_reward={args.clip_reward}, n_envs={args.n_envs}"
    )

    # One env factory per worker; each worker gets a distinct seed (seed + i)
    # so rollouts across sub-envs decorrelate.
    env_fns = [
        make_env_fn(
            env_id=args.env_id,
            seed=args.seed + i,
            movement=args.movement,
            reward_mode=reward_mode,
            clip_reward=args.clip_reward,
            frame_skip=args.frame_skip,
            progress_scale=args.progress_scale,
            death_penalty=args.death_penalty,
            flag_bonus=args.flag_bonus,
            stall_penalty=args.stall_penalty,
            stall_steps=args.stall_steps,
            backward_penalty_scale=args.backward_penalty_scale,
            milestone_interval=args.milestone_interval,
            milestone_bonus=args.milestone_bonus,
            no_progress_terminate_steps=args.no_progress_terminate_steps,
            no_progress_terminate_penalty=args.no_progress_terminate_penalty,
        )
        for i in range(args.n_envs)
    ]
    vec_env = DummyVecEnv(env_fns)
    # VecMonitor records per-episode stats to monitor.csv for later analysis.
    vec_env = VecMonitor(vec_env, filename=str(run_log_dir / "monitor.csv"))

    model = PPO(
        policy="CnnPolicy",
        env=vec_env,
        learning_rate=args.learning_rate,
        n_steps=args.n_steps,
        batch_size=args.batch_size,
        n_epochs=args.n_epochs,
        gamma=args.gamma,
        gae_lambda=args.gae_lambda,
        ent_coef=args.ent_coef,
        clip_range=args.clip_range,
        tensorboard_log=str(run_tb_dir),
        seed=args.seed,
        verbose=1,
        device=device,
    )

    # Optional warm start from a previous checkpoint.
    if args.init_model_path.strip():
        init_model_path = Path(args.init_model_path).expanduser().resolve()
        if not init_model_path.exists():
            raise FileNotFoundError(f"Init model not found: {init_model_path}")
        try:
            model.set_parameters(str(init_model_path), exact_match=False, device=device)
            print(f"[init] Loaded initial weights from: {init_model_path}")
        except RuntimeError as exc:
            # set_parameters raises when tensor shapes mismatch — typically a
            # checkpoint trained with a different action space (action head size).
            if not args.allow_partial_init:
                raise RuntimeError(
                    f"{exc}\n"
                    "Hint: checkpoint and current env may use different action spaces. "
                    "Try one of:\n"
                    "1) keep the same --movement as the checkpoint;\n"
                    "2) remove --init-model-path and train from scratch;\n"
                    "3) add --allow-partial-init to load compatible layers only."
                ) from exc
            # Fall back to loading only shape-compatible policy tensors.
            loaded_count, skipped_count = load_partial_policy_weights(model, init_model_path, device=device)
            print(
                f"[init] Partial init from {init_model_path}: "
                f"loaded_tensors={loaded_count}, skipped_tensors={skipped_count}"
            )

    callback = CheckpointCallback(
        # --save-freq is expressed in single-env steps; the callback counts
        # vectorized steps, so divide by n_envs (floored to at least 1).
        save_freq=max(args.save_freq // max(args.n_envs, 1), 1),
        save_path=str(run_model_dir),
        name_prefix="ppo_mario_ckpt",
        save_replay_buffer=False,
        save_vecnormalize=False,
    )
    episode_end_logging_callback = EpisodeEndLoggingCallback(
        n_envs=args.n_envs,
        tb_log_dir=run_tb_dir / "episode_end",
    )

    try:
        model.learn(
            total_timesteps=args.total_timesteps,
            callback=CallbackList([callback, episode_end_logging_callback]),
            tb_log_name="ppo",
            progress_bar=args.progress_bar,
        )
    finally:
        # Always close the envs, even on Ctrl-C or a training error.
        vec_env.close()

    final_model = run_model_dir / "final_model"
    model.save(str(final_model))  # SB3 appends ".zip" to this path

    # Mirror the final model to a stable "latest" location and record its path.
    latest_model = paths.models / "latest_model.zip"
    shutil.copy2(str(final_model) + ".zip", latest_model)
    pointer = write_latest_pointer(paths.models, latest_model)

    print(f"[done] Final model: {final_model}.zip")
    print(f"[done] Latest model pointer: {pointer}")
|
||||
|
||||
|
||||
# Script entry point: run training when executed directly.
if __name__ == "__main__":
    main()
|
||||
75
mario-rl-mvp/src/utils.py
Normal file
75
mario-rl-mvp/src/utils.py
Normal file
@@ -0,0 +1,75 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import random
|
||||
from dataclasses import dataclass
|
||||
from pathlib import Path
|
||||
from typing import Optional, Tuple
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
|
||||
|
||||
@dataclass(frozen=True)
class ArtifactPaths:
    """Resolved artifact directory layout, as produced by ensure_artifact_paths()."""

    root: Path    # the artifacts/ directory itself (not the project root)
    models: Path  # artifacts/models
    videos: Path  # artifacts/videos
    logs: Path    # artifacts/logs
|
||||
|
||||
|
||||
def project_root() -> Path:
    """Return the directory two levels above this file (the package root)."""
    return Path(__file__).resolve().parent.parent
|
||||
|
||||
|
||||
def ensure_artifact_paths(root: Optional[Path] = None) -> ArtifactPaths:
    """Create (if missing) the artifact directory tree under *root* and return it.

    Defaults *root* to project_root(). Note that the returned ``root`` field is
    the ``artifacts/`` directory itself, not the *root* argument.
    """
    base = root or project_root()
    artifacts = base / "artifacts"
    subdirs = {name: artifacts / name for name in ("models", "videos", "logs")}
    # mkdir is idempotent here: parents=True + exist_ok=True.
    artifacts.mkdir(parents=True, exist_ok=True)
    for directory in subdirs.values():
        directory.mkdir(parents=True, exist_ok=True)
    return ArtifactPaths(root=artifacts, **subdirs)
|
||||
|
||||
|
||||
def seed_everything(seed: int) -> None:
    """Seed all RNGs used in training (python, numpy, torch, and CUDA if present)."""
    for seeder in (random.seed, np.random.seed, torch.manual_seed):
        seeder(seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(seed)
|
||||
|
||||
|
||||
def resolve_torch_device(requested: str = "auto") -> Tuple[str, str]:
    """Pick a torch device string ("cpu" or "mps") plus a human-readable reason.

    "auto" prefers MPS when available and passing a small smoke test; any value
    other than "auto"/"mps"/"cpu" falls back to CPU. Input is case/whitespace
    insensitive.
    """
    choice = requested.lower().strip()
    if choice == "cpu":
        return "cpu", "User requested CPU."

    if choice not in ("auto", "mps"):
        return "cpu", f"Unknown device '{choice}', fallback to CPU."

    backend = getattr(torch.backends, "mps", None)
    if backend is None or not backend.is_available():
        message = (
            "MPS requested but unavailable, fallback to CPU."
            if choice == "mps"
            else "MPS unavailable, using CPU."
        )
        return "cpu", message

    # Some macOS/torch combinations report MPS as available yet fail on first
    # use, so round-trip a tiny tensor through the device before committing.
    try:
        probe = torch.ones(8, device="mps")
        _ = (probe * 2).cpu().numpy()
    except Exception as exc:  # pragma: no cover - hardware dependent
        return "cpu", f"MPS check failed ({exc}), fallback to CPU."
    return "mps", "MPS is available and passed a quick tensor sanity check."
|
||||
|
||||
|
||||
def latest_model_path(models_dir: Path) -> Optional[Path]:
    """Return the most recently modified .zip anywhere under *models_dir*, or None."""
    newest: Optional[Path] = None
    newest_mtime = float("-inf")
    for candidate in models_dir.rglob("*.zip"):
        if not candidate.is_file():
            continue
        mtime = candidate.stat().st_mtime
        if mtime > newest_mtime:
            newest, newest_mtime = candidate, mtime
    return newest
|
||||
|
||||
|
||||
def write_latest_pointer(models_dir: Path, model_path: Path) -> Path:
    """Record the absolute path of *model_path* in LATEST_MODEL.txt.

    Overwrites any existing pointer file and returns its path.
    """
    pointer_file = models_dir / "LATEST_MODEL.txt"
    contents = f"{model_path.resolve()}\n"
    pointer_file.write_text(contents, encoding="utf-8")
    return pointer_file
|
||||
Reference in New Issue
Block a user