Gym Interface

import gymnasium as gym
import matplotlib.pyplot as plt
import newton
import numpy as np
import warp as wp
from gymnasium import spaces
from lwmr.utils import create_viewer_viser
from tqdm.auto import trange

wp.config.quiet = True
/Users/ajcd2020/Documents/Repositories/anthonyjclark/simer-tutorial/2026-icra/.venv/lib/python3.14/site-packages/tqdm/auto.py:21: TqdmWarning: IProgress not found. Please update jupyter and ipywidgets. See https://ipywidgets.readthedocs.io/en/stable/user_install.html
  from .autonotebook import tqdm as notebook_tqdm
class RotatingCubeEnv(gym.Env):
    """Gymnasium environment for a box spinning on a single world-anchored revolute joint.

    Each step takes one normalized velocity command in ``[-1, 1]``, scales it by
    ``max_target_velocity`` and writes it to the joint's velocity target, then
    advances the Newton simulation (MuJoCo solver) by ``frame_step`` using
    ``sim_substeps`` solver steps. The observation is ``[joint position, joint
    velocity]``; the reward is a quadratic penalty on the gap between the joint
    velocity and ``desired_velocity`` plus a small action penalty.

    Args:
        render_mode: ``"viser"`` creates a viser viewer and renders after every
            step; any other value disables rendering.
        frame_step: Simulated time advanced per environment step.
        sim_substeps: Solver steps per environment step.
        max_episode_steps: Step count at which ``truncated`` becomes True.
        max_target_velocity: Velocity target produced by an action of 1.0.
        desired_velocity: Joint velocity the reward is centered on.
        target_kd: Gain passed as ``target_kd`` to the joint's velocity actuator.
        use_graph: Replay the substep loop from a captured CUDA graph when the
            model and the default Warp device are CUDA and ``sim_substeps`` is even.
        seed: Stored in ``self._seed``; this class draws no random numbers itself.
    """

    metadata = {"render_modes": ["viser", "none"], "render_fps": 60}

    def __init__(
        self,
        render_mode=None,
        frame_step=1.0 / 60.0,
        sim_substeps=4,
        max_episode_steps=400,
        max_target_velocity=10.0,
        desired_velocity=8.0,
        target_kd=100.0,
        use_graph=False,
        seed=None,
    ):
        super().__init__()

        self.render_mode = render_mode
        self.frame_step = float(frame_step)
        self.sim_substeps = int(sim_substeps)
        # Solver time step: one frame split evenly across the substeps.
        self.time_step = self.frame_step / self.sim_substeps

        self.max_episode_steps = int(max_episode_steps)
        self.max_target_velocity = float(max_target_velocity)
        self.desired_velocity = float(desired_velocity)
        self.target_kd = float(target_kd)
        self.use_graph = bool(use_graph)

        self.sim_time = 0.0
        self.step_count = 0

        self.viewer = None
        self.graph = None

        # Cleared here; super().reset(seed=...) in reset() reseeds it.
        self.np_random = None
        self._seed = seed

        # One normalized velocity command.
        self.action_space = spaces.Box(
            low=np.array([-1.0], dtype=np.float32),
            high=np.array([1.0], dtype=np.float32),
            dtype=np.float32,
        )

        # Observation: [joint position, joint velocity].
        self.observation_space = spaces.Box(
            low=np.array([-np.inf, -np.inf], dtype=np.float32),
            high=np.array([np.inf, np.inf], dtype=np.float32),
            dtype=np.float32,
        )

        self._build_model()

    # ---------------------------------------------------------------------
    # region Model
    # ---------------------------------------------------------------------
    def _build_model(self):
        """Build the Newton model, initialize the simulation, and create the viewer."""
        builder = newton.ModelBuilder()

        builder.add_ground_plane()

        # Anchor the joint 1.0 above the world origin along z.
        xform = wp.transform(p=wp.vec3(0.0, 0.0, 1.0))

        body = builder.add_link(mass=1.0)

        joint = builder.add_joint_revolute(
            parent=-1,  # -1: the joint's parent is the world
            child=body,
            parent_xform=xform,
            axis=wp.vec3(0.0, 0.0, 1.0),
            actuator_mode=newton.JointTargetMode.VELOCITY,
            target_kd=self.target_kd,
        )

        builder.add_articulation([joint])
        builder.add_shape_box(body)

        self.model = builder.finalize()

        # Joint coordinates (joint_q) and joint DOFs (joint_qd / joint_target_vel)
        # have separate start offsets; keep one index for each.
        self.joint_q_index = builder.joint_q_start[joint]
        self.joint_qd_index = builder.joint_qd_start[joint]
        # Backward-compatible alias: this attribute always held the DOF index.
        self.joint_index = self.joint_qd_index

        self._init_simulation()

        if self.render_mode == "viser":
            # Never leave a previous viewer (and its server) running.
            if self.viewer is not None:
                self.viewer.close()
            self.viewer = create_viewer_viser("spinning_cube", self.model, quiet=False, overwrite=False)

    def _init_simulation(self):
        """(Re)create states, control, contacts, solver and CUDA graph for ``self.model``.

        Used both at construction and by ``reset``, so restarting an episode does not
        rebuild the model or touch the viewer.
        """
        self.state_0 = self.model.state()
        self.state_1 = self.model.state()
        self.control = self.model.control()
        self.contacts = self.model.contacts()

        self.solver = newton.solvers.SolverMuJoCo(self.model)

        # Host-side staging buffer for the velocity targets written in step().
        self.joint_target_vels = np.zeros(self.control.joint_target_vel.shape, dtype=np.float32)

        # The graph records the buffers state_0/state_1 refer to during capture, but
        # replay does not repeat _simulate's Python-side swap. Only with an even
        # substep count do the swaps cancel, so self.state_0 keeps pointing at the
        # buffer the replayed graph both reads first and writes last.
        use_graph = (
            self.use_graph
            and self.sim_substeps % 2 == 0
            and self.model.device.is_cuda
            and wp.get_device().is_cuda
        )
        if use_graph:
            with wp.ScopedCapture() as capture:
                self._simulate()
            self.graph = capture.graph
        else:
            self.graph = None

    # ---------------------------------------------------------------------
    # region Gym
    # ---------------------------------------------------------------------
    def reset(self, *, seed=None, options=None):
        """Start a new episode and return ``(observation, info)``.

        Simulation buffers and the solver are re-initialized only if at least one step
        ran since the last initialization; the model and viewer are reused.
        ``options`` is accepted for Gymnasium compatibility and ignored.

        NOTE(review): ``sim_time`` restarts at 0, so frames rendered in successive
        episodes reuse the same timestamps in the viewer — confirm this is acceptable.
        """
        super().reset(seed=seed)

        # Record the seed even on the very first reset.
        if seed is not None:
            self._seed = seed

        if self.step_count > 0:
            self._init_simulation()

        self.sim_time = 0.0
        self.step_count = 0

        return self._get_obs(), {}

    def step(self, action):
        """Apply one normalized velocity command and advance one frame.

        Returns:
            ``(obs, reward, terminated, truncated, info)``; ``terminated`` is always
            False and ``truncated`` becomes True after ``max_episode_steps`` steps.
        """
        # Flatten so scalars and (1,)-shaped actions are both accepted.
        action = np.asarray(action, dtype=np.float32).reshape(-1)
        action = np.clip(action, self.action_space.low, self.action_space.high)  # type: ignore

        target_vel = float(action[0]) * self.max_target_velocity
        self._set_target_velocity(target_vel)

        if self.graph is not None:
            wp.capture_launch(self.graph)
        else:
            self._simulate()

        self.sim_time += self.frame_step
        self.step_count += 1

        obs = self._get_obs()

        reward, reward_info = self._compute_reward(action=action, target_vel=target_vel, obs=obs)

        terminated = False
        truncated = self.step_count >= self.max_episode_steps

        info = {
            "sim_time": self.sim_time,
            "step_count": self.step_count,
            "target_velocity": target_vel,
            **reward_info,
        }

        if self.render_mode == "viser":
            self.render()

        return obs, reward, terminated, truncated, info

    def render(self):
        """Log the current state to the viewer, if one exists."""
        if self.viewer is not None:
            self.viewer.begin_frame(self.sim_time)
            self.viewer.log_state(self.state_0)
            self.viewer.end_frame()

    def close(self):
        """Close the viewer (if any); safe to call more than once."""
        if self.viewer is not None:
            self.viewer.close()
            self.viewer = None

    # ---------------------------------------------------------------------
    # Simulation helpers
    # ---------------------------------------------------------------------
    def _simulate(self):
        """Advance one frame as ``sim_substeps`` solver steps of ``time_step``."""
        for _ in range(self.sim_substeps):
            self.state_0.clear_forces()
            self.model.collide(self.state_0, self.contacts)
            self.solver.step(self.state_0, self.state_1, self.control, self.contacts, self.time_step)
            # Double buffering: the freshly written state becomes the input.
            self.state_0, self.state_1 = self.state_1, self.state_0

    def _set_target_velocity(self, target_vel):
        """Write ``target_vel`` to the cube joint's DOF; all other targets are zeroed."""
        self.joint_target_vels.fill(0.0)
        self.joint_target_vels[self.joint_qd_index] = target_vel
        self.control.joint_target_vel.assign(self.joint_target_vels)

    def _get_obs(self):
        """Return ``[q, qd]`` for the cube joint as float32.

        Raises:
            RuntimeError: if either value is NaN or infinite.
        """
        joint_q = self.state_0.joint_q.numpy()
        joint_qd = self.state_0.joint_qd.numpy()

        # Positions are indexed by coordinate offset, velocities by DOF offset.
        q = float(joint_q[self.joint_q_index]) if joint_q.size > 0 else 0.0
        qd = float(joint_qd[self.joint_qd_index]) if joint_qd.size > 0 else 0.0

        obs = np.array([q, qd], dtype=np.float32)

        if not np.all(np.isfinite(obs)):
            raise RuntimeError(f"Non-finite observation: {obs}")

        return obs

    def _compute_reward(self, action, target_vel, obs):
        """Return ``(reward, info)`` for the velocity-tracking task.

        ``reward = -0.01 * (qd - desired_velocity)**2 - 0.001 * sum(action**2)``.
        ``target_vel`` is accepted for interface symmetry but not used.
        """
        q, qd = float(obs[0]), float(obs[1])

        # Velocity tracking error.
        vel_error = qd - self.desired_velocity

        # Quadratic penalty: 0 at the desired velocity, unbounded below.
        tracking_reward = -0.01 * vel_error * vel_error

        # Small action penalty.
        action_penalty = -0.001 * float(np.sum(np.square(action)))

        reward = tracking_reward + action_penalty

        info = {
            "joint_position": q,
            "joint_velocity": qd,
            "desired_velocity": self.desired_velocity,
            "velocity_error": vel_error,
            "tracking_reward": tracking_reward,
            "action_penalty": action_penalty,
            "reward": reward,
        }

        return float(reward), info
env = RotatingCubeEnv(render_mode="viser")

obs, info = env.reset()

obs_history = [obs.copy()]

# Roll out exactly one episode; take the length from the env instead of hard-coding it.
num_steps = env.max_episode_steps

time = 0.0
for step in trange(num_steps):
    # Open-loop command oscillating at 0.5 Hz between 0 and 1.
    vel = np.sin(2 * np.pi * 0.5 * time) * 0.5 + 0.5
    action = np.array([vel], dtype=np.float32)

    obs, reward, terminated, truncated, info = env.step(action)
    obs_history.append(obs.copy())
    time += env.frame_step

    # Skip the reset after the final step: it is unnecessary and re-initializes
    # the environment that is inspected right after the loop.
    if (terminated or truncated) and step + 1 < num_steps:
        obs, info = env.reset()

# env.close() is deliberately not called so the viewer stays available here.
env.unwrapped.viewer.show_notebook()  # type: ignore
Recording to docs/_static/spinning_cube_04.viser...
╭────── viser (listening *:8080) ───────╮
│             ╷                         │
│   HTTP      │ http://localhost:8080   │
│   Websocket │ ws://localhost:8080     │
│             ╵                         │
╰───────────────────────────────────────╯

  0%|          | 0/400 [00:00<?, ?it/s]
  0%|          | 1/400 [00:00<01:23,  4.79it/s]
  4%|▍         | 15/400 [00:00<00:06, 58.68it/s]
  7%|▋         | 29/400 [00:00<00:04, 87.50it/s]
 11%|█         | 43/400 [00:00<00:03, 104.64it/s]
 14%|█▍        | 57/400 [00:00<00:02, 115.28it/s]
 18%|█▊        | 70/400 [00:00<00:02, 119.47it/s]
 21%|██        | 84/400 [00:00<00:02, 124.65it/s]
 24%|██▍       | 98/400 [00:00<00:02, 128.29it/s]
 28%|██▊       | 112/400 [00:01<00:02, 130.12it/s]
 32%|███▏      | 126/400 [00:01<00:02, 131.40it/s]
 35%|███▌      | 140/400 [00:01<00:01, 133.16it/s]
 38%|███▊      | 154/400 [00:01<00:01, 134.16it/s]
 42%|████▏     | 168/400 [00:01<00:01, 134.65it/s]
 46%|████▌     | 182/400 [00:01<00:01, 135.42it/s]
 49%|████▉     | 196/400 [00:01<00:01, 134.62it/s]
 52%|█████▎    | 210/400 [00:01<00:01, 134.16it/s]
 56%|█████▌    | 224/400 [00:01<00:01, 134.46it/s]
 60%|█████▉    | 238/400 [00:01<00:01, 135.11it/s]
 63%|██████▎   | 252/400 [00:02<00:01, 135.61it/s]
 66%|██████▋   | 266/400 [00:02<00:00, 135.16it/s]
 70%|███████   | 280/400 [00:02<00:00, 135.84it/s]
 74%|███████▎  | 294/400 [00:02<00:00, 136.19it/s]
 77%|███████▋  | 308/400 [00:02<00:00, 136.41it/s]
 80%|████████  | 322/400 [00:02<00:00, 136.59it/s]
 84%|████████▍ | 336/400 [00:02<00:00, 136.20it/s]
 88%|████████▊ | 350/400 [00:02<00:00, 136.47it/s]
 91%|█████████ | 364/400 [00:02<00:00, 135.87it/s]
 94%|█████████▍| 378/400 [00:02<00:00, 136.12it/s]
 98%|█████████▊| 392/400 [00:03<00:00, 136.33it/s]
Recording to docs/_static/spinning_cube_04.viser...
╭────── viser (listening *:8081) ───────╮
│             ╷                         │
│   HTTP      │ http://localhost:8081   │
│   Websocket │ ws://localhost:8081     │
│             ╵                         │
╰───────────────────────────────────────╯

100%|██████████| 400/400 [00:03<00:00, 124.72it/s]
# Plot both recorded observation channels against the step index.
history = np.array(obs_history)
for column, label in enumerate(("joint_position", "joint_velocity")):
    plt.plot(history[:, column], label=label)
plt.legend()
plt.title("Observation History")
plt.xlabel("Step");