diff --git a/.github/workflows/examples.yml b/.github/workflows/examples.yml index 14df548019..1d58431e64 100644 --- a/.github/workflows/examples.yml +++ b/.github/workflows/examples.yml @@ -51,7 +51,6 @@ jobs: run: | pip install --upgrade pip setuptools wheel pip install torch --index-url https://download.pytorch.org/whl/cpu - pip install -e '.[dev]' pynput - name: Get gstaichi version id: gstaichi_version diff --git a/examples/IPC_Solver/ipc_arm_cloth.py b/examples/IPC_Solver/ipc_arm_cloth.py index d2a7018f11..a8fe239940 100644 --- a/examples/IPC_Solver/ipc_arm_cloth.py +++ b/examples/IPC_Solver/ipc_arm_cloth.py @@ -14,46 +14,21 @@ ; - Roll Right (Rotate around X) u - Reset Scene space - Press to close gripper, release to open gripper -esc - Quit + +Plus all default viewer controls (press 'i' to see them) """ -import random -import threading import argparse -import numpy as np import csv import os from datetime import datetime -from pynput import keyboard -from scipy.spatial.transform import Rotation as R + +import numpy as np from huggingface_hub import snapshot_download +from scipy.spatial.transform import Rotation as R import genesis as gs - - -class KeyboardDevice: - def __init__(self): - self.pressed_keys = set() - self.lock = threading.Lock() - self.listener = keyboard.Listener(on_press=self.on_press, on_release=self.on_release) - - def start(self): - self.listener.start() - - def stop(self): - self.listener.stop() - self.listener.join() - - def on_press(self, key: keyboard.Key): - with self.lock: - self.pressed_keys.add(key) - - def on_release(self, key: keyboard.Key): - with self.lock: - self.pressed_keys.discard(key) - - def get_cmd(self): - return self.pressed_keys +from genesis.ext.pyrender.interaction.keybindings import KeyAction, Keybind def build_scene(use_ipc=False, show_viewer=False, enable_ipc_gui=False): @@ -172,14 +147,14 @@ def build_scene(use_ipc=False, show_viewer=False, enable_ipc_gui=False): return scene, entities -def run_sim(scene, 
entities, clients, mode="interactive", trajectory_file=None): +def run_sim(scene, entities, mode="interactive", trajectory_file=None): robot = entities["robot"] target_entity = entities["target"] robot_init_pos = np.array([0.5, 0, 0.55]) robot_init_R = R.from_euler("y", np.pi) - target_pos = robot_init_pos.copy() - target_R = robot_init_R + target_pos = [robot_init_pos.copy()] # Use list for mutability in closures + target_R = [robot_init_R] # Use list for mutability in closures n_dofs = robot.n_dofs motors_dof = np.arange(n_dofs - 2) @@ -190,18 +165,97 @@ def run_sim(scene, entities, clients, mode="interactive", trajectory_file=None): trajectory = [] recording = mode == "record" + # Gripper state (use list for mutability in closures) + gripper_closed = [False] + + # Control parameters + dpos = 0.002 + drot = 0.01 + def reset_scene(): - nonlocal target_pos, target_R - target_pos = robot_init_pos.copy() - target_R = robot_init_R - target_quat = target_R.as_quat(scalar_first=True) - target_entity.set_qpos(np.concatenate([target_pos, target_quat])) - q = robot.inverse_kinematics(link=ee_link, pos=target_pos, quat=target_quat) + target_pos[0] = robot_init_pos.copy() + target_R[0] = robot_init_R + target_quat = target_R[0].as_quat(scalar_first=True) + target_entity.set_qpos(np.concatenate([target_pos[0], target_quat])) + q = robot.inverse_kinematics(link=ee_link, pos=target_pos[0], quat=target_quat) robot.set_qpos(q[:-2], motors_dof) # entities["cube"].set_pos((random.uniform(0.2, 0.4), random.uniform(-0.2, 0.2), 0.05)) # entities["cube"].set_quat(R.from_euler("z", random.uniform(0, np.pi * 2)).as_quat(scalar_first=True)) + # Define movement callbacks + def move_forward(): + target_pos[0][0] -= dpos + + def move_backward(): + target_pos[0][0] += dpos + + def move_left(): + target_pos[0][1] -= dpos + + def move_right(): + target_pos[0][1] += dpos + + def move_up(): + target_pos[0][2] += dpos + + def move_down(): + target_pos[0][2] -= dpos + + def yaw_left(): + 
target_R[0] = R.from_euler("z", drot) * target_R[0] + + def yaw_right(): + target_R[0] = R.from_euler("z", -drot) * target_R[0] + + def pitch_up(): + target_R[0] = R.from_euler("y", drot) * target_R[0] + + def pitch_down(): + target_R[0] = R.from_euler("y", -drot) * target_R[0] + + def roll_left(): + target_R[0] = R.from_euler("x", drot) * target_R[0] + + def roll_right(): + target_R[0] = R.from_euler("x", -drot) * target_R[0] + + def close_gripper(): + gripper_closed[0] = True + + def open_gripper(): + gripper_closed[0] = False + + # Register keybindings (only for interactive and record modes) + if mode in ["interactive", "record"]: + from pyglet.window import key + + scene.viewer.register_keybinds( + ( + Keybind(key_code=key.UP, key_action=KeyAction.HOLD, name="move_forward", callback_func=move_forward), + Keybind( + key_code=key.DOWN, key_action=KeyAction.HOLD, name="move_backward", callback_func=move_backward + ), + Keybind(key_code=key.LEFT, key_action=KeyAction.HOLD, name="move_left", callback_func=move_left), + Keybind(key_code=key.RIGHT, key_action=KeyAction.HOLD, name="move_right", callback_func=move_right), + Keybind(key_code=key.N, key_action=KeyAction.HOLD, name="move_up", callback_func=move_up), + Keybind(key_code=key.M, key_action=KeyAction.HOLD, name="move_down", callback_func=move_down), + Keybind(key_code=key.J, key_action=KeyAction.HOLD, name="yaw_left", callback_func=yaw_left), + Keybind(key_code=key.K, key_action=KeyAction.HOLD, name="yaw_right", callback_func=yaw_right), + Keybind(key_code=key.I, key_action=KeyAction.HOLD, name="pitch_up", callback_func=pitch_up), + Keybind(key_code=key.O, key_action=KeyAction.HOLD, name="pitch_down", callback_func=pitch_down), + Keybind(key_code=key.L, key_action=KeyAction.HOLD, name="roll_left", callback_func=roll_left), + Keybind(key_code=key.SEMICOLON, key_action=KeyAction.HOLD, name="roll_right", callback_func=roll_right), + Keybind(key_code=key.U, key_action=KeyAction.HOLD, name="reset_scene", 
callback_func=reset_scene), + Keybind( + key_code=key.SPACE, key_action=KeyAction.PRESS, name="close_gripper", callback_func=close_gripper + ), + Keybind( + key_code=key.SPACE, key_action=KeyAction.RELEASE, name="open_gripper", callback_func=open_gripper + ), + ) + ) + # Load trajectory if in playback mode if mode == "playback": if not trajectory_file or not os.path.exists(trajectory_file): @@ -233,7 +287,7 @@ def reset_scene(): print(f"\nMode: {mode.upper()}") if mode == "record": - print("Recording trajectory... Press ESC to stop and save.") + print("Recording trajectory...") elif mode == "playback": print("Playing back trajectory...") @@ -249,99 +303,57 @@ def reset_scene(): print("l/;\t- Roll Left/Right (Rotate around X axis)") print("u\t- Reset Scene") print("space\t- Press to close gripper, release to open gripper") - print("esc\t- Quit") + if mode in ["interactive", "record"]: + print("\nPlus all default viewer controls (press 'i' to see them)") # reset scene before starting teleoperation reset_scene() # start teleoperation or playback - stop = False step_count = 0 - while not stop: - if mode == "playback": - # Playback mode: replay recorded trajectory - if step_count < len(trajectory): - step_data = trajectory[step_count] - target_pos = step_data["target_pos"] - target_R = R.from_quat(step_data["target_quat"]) - is_close_gripper = step_data["gripper_closed"] - step_count += 1 - print(f"\rPlayback step: {step_count}/{len(trajectory)}", end="") - # Check if user wants to stop playback - pressed_keys = clients["keyboard"].pressed_keys.copy() - stop = keyboard.Key.esc in pressed_keys + try: + while True: + if mode == "playback": + # Playback mode: replay recorded trajectory + if step_count < len(trajectory): + step_data = trajectory[step_count] + target_pos[0] = step_data["target_pos"] + target_R[0] = R.from_quat(step_data["target_quat"]) + gripper_closed[0] = step_data["gripper_closed"] + step_count += 1 + print(f"\rPlayback step: 
{step_count}/{len(trajectory)}", end="") + else: + print("\nPlayback finished!") + break else: - print("\nPlayback finished!") - break - else: - # Interactive or recording mode - pressed_keys = clients["keyboard"].pressed_keys.copy() - - # reset scene: - reset_flag = False - reset_flag |= keyboard.KeyCode.from_char("u") in pressed_keys - if reset_flag: - reset_scene() - - # stop teleoperation - stop = keyboard.Key.esc in pressed_keys - - # get ee target pose - is_close_gripper = False - dpos = 0.002 - drot = 0.01 - for key in pressed_keys: - if key == keyboard.Key.up: - target_pos[0] -= dpos - elif key == keyboard.Key.down: - target_pos[0] += dpos - elif key == keyboard.Key.right: - target_pos[1] += dpos - elif key == keyboard.Key.left: - target_pos[1] -= dpos - elif key == keyboard.KeyCode.from_char("n"): - target_pos[2] += dpos - elif key == keyboard.KeyCode.from_char("m"): - target_pos[2] -= dpos - elif key == keyboard.KeyCode.from_char("j"): - target_R = R.from_euler("z", drot) * target_R - elif key == keyboard.KeyCode.from_char("k"): - target_R = R.from_euler("z", -drot) * target_R - elif key == keyboard.KeyCode.from_char("i"): - target_R = R.from_euler("y", drot) * target_R - elif key == keyboard.KeyCode.from_char("o"): - target_R = R.from_euler("y", -drot) * target_R - elif key == keyboard.KeyCode.from_char("l"): - target_R = R.from_euler("x", drot) * target_R - elif key == keyboard.KeyCode.from_char(";"): - target_R = R.from_euler("x", -drot) * target_R - elif key == keyboard.Key.space: - is_close_gripper = True - - # Record current state if recording - if recording: - step_data = { - "target_pos": target_pos.copy(), - "target_quat": target_R.as_quat(), # x,y,z,w format - "gripper_closed": is_close_gripper, - "step": step_count, - } - trajectory.append(step_data) - - # control arm - target_quat = target_R.as_quat(scalar_first=True) - target_entity.set_qpos(np.concatenate([target_pos, target_quat])) - q, err = robot.inverse_kinematics(link=ee_link, 
pos=target_pos, quat=target_quat, return_error=True) - robot.control_dofs_position(q[:-2], motors_dof) - # control gripper - if is_close_gripper: - robot.control_dofs_force(np.array([-1.0, -1.0]), fingers_dof) - else: - robot.control_dofs_force(np.array([1.0, 1.0]), fingers_dof) + # Interactive or recording mode + # Movement is handled by keybinding callbacks + # Record current state if recording + if recording: + step_data = { + "target_pos": target_pos[0].copy(), + "target_quat": target_R[0].as_quat(), # x,y,z,w format + "gripper_closed": gripper_closed[0], + "step": step_count, + } + trajectory.append(step_data) + + # control arm + target_quat = target_R[0].as_quat(scalar_first=True) + target_entity.set_qpos(np.concatenate([target_pos[0], target_quat])) + q, err = robot.inverse_kinematics(link=ee_link, pos=target_pos[0], quat=target_quat, return_error=True) + robot.control_dofs_position(q[:-2], motors_dof) + # control gripper + if gripper_closed[0]: + robot.control_dofs_force(np.array([-1.0, -1.0]), fingers_dof) + else: + robot.control_dofs_force(np.array([1.0, 1.0]), fingers_dof) - scene.step() - step_count += 1 + scene.step() + step_count += 1 + except KeyboardInterrupt: + gs.logger.info("Simulation interrupted, exiting.") # Save trajectory if recording if recording and len(trajectory) > 0: @@ -437,12 +449,8 @@ def main(): elif not os.path.isabs(trajectory_file): trajectory_file = os.path.join(traj_dir, os.path.basename(trajectory_file)) - clients = dict() - clients["keyboard"] = KeyboardDevice() - clients["keyboard"].start() - scene, entities = build_scene(use_ipc=args.ipc, show_viewer=args.vis, enable_ipc_gui=False) - run_sim(scene, entities, clients, mode=args.mode, trajectory_file=trajectory_file) + run_sim(scene, entities, mode=args.mode, trajectory_file=trajectory_file) if __name__ == "__main__": diff --git a/examples/drone/interactive_drone.py b/examples/drone/interactive_drone.py index cdc6713d6e..f18f353f10 100644 --- 
a/examples/drone/interactive_drone.py +++ b/examples/drone/interactive_drone.py @@ -1,11 +1,9 @@ import os -import time -import threading -from pynput import keyboard import numpy as np import genesis as gs +from genesis.ext.pyrender.interaction.keybindings import KeyAction, Keybind class DroneController: @@ -13,100 +11,57 @@ def __init__(self): self.thrust = 14475.8 # Base hover RPM - constant hover self.rotation_delta = 200.0 # Differential RPM for rotation self.thrust_delta = 10.0 # Amount to change thrust by when accelerating/decelerating - self.running = True self.rpms = [self.thrust] * 4 - self.pressed_keys = set() - - def on_press(self, key): - try: - if key == keyboard.Key.esc: - self.running = False - return False - self.pressed_keys.add(key) - print(f"Key pressed: {key}") - except AttributeError: - pass - - def on_release(self, key): - try: - self.pressed_keys.discard(key) - except KeyError: - pass def update_thrust(self): - # Store previous RPMs for debugging - prev_rpms = self.rpms.copy() - # Reset RPMs to hover thrust self.rpms = [self.thrust] * 4 + return self.rpms - # Acceleration (Spacebar) - All rotors spin faster - if keyboard.Key.space in self.pressed_keys: - self.thrust += self.thrust_delta - self.rpms = [self.thrust] * 4 - print("Accelerating") - - # Deceleration (Left Shift) - All rotors spin slower - if keyboard.Key.shift in self.pressed_keys: - self.thrust -= self.thrust_delta - self.rpms = [self.thrust] * 4 - print("Decelerating") - - # Forward (North) - Front rotors spin faster - if keyboard.Key.up in self.pressed_keys: - self.rpms[0] += self.rotation_delta # Front left - self.rpms[1] += self.rotation_delta # Front right - self.rpms[2] -= self.rotation_delta # Back left - self.rpms[3] -= self.rotation_delta # Back right - print("Moving Forward") - - # Backward (South) - Back rotors spin faster - if keyboard.Key.down in self.pressed_keys: - self.rpms[0] -= self.rotation_delta # Front left - self.rpms[1] -= self.rotation_delta # Front right 
- self.rpms[2] += self.rotation_delta # Back left - self.rpms[3] += self.rotation_delta # Back right - print("Moving Backward") - - # Left (West) - Left rotors spin faster - if keyboard.Key.left in self.pressed_keys: - self.rpms[0] -= self.rotation_delta # Front left - self.rpms[2] -= self.rotation_delta # Back left - self.rpms[1] += self.rotation_delta # Front right - self.rpms[3] += self.rotation_delta # Back right - print("Moving Left") - - # Right (East) - Right rotors spin faster - if keyboard.Key.right in self.pressed_keys: - self.rpms[0] += self.rotation_delta # Front left - self.rpms[2] += self.rotation_delta # Back left - self.rpms[1] -= self.rotation_delta # Front right - self.rpms[3] -= self.rotation_delta # Back right - print("Moving Right") - + def move_forward(self): + """Front rotors spin faster""" + self.rpms[0] += self.rotation_delta # Front left + self.rpms[1] += self.rotation_delta # Front right + self.rpms[2] -= self.rotation_delta # Back left + self.rpms[3] -= self.rotation_delta # Back right self.rpms = np.clip(self.rpms, 0, 25000) - # Debug print if any RPMs changed - if not np.array_equal(prev_rpms, self.rpms): - print(f"RPMs changed from {prev_rpms} to {self.rpms}") - - return self.rpms - + def move_backward(self): + """Back rotors spin faster""" + self.rpms[0] -= self.rotation_delta # Front left + self.rpms[1] -= self.rotation_delta # Front right + self.rpms[2] += self.rotation_delta # Back left + self.rpms[3] += self.rotation_delta # Back right + self.rpms = np.clip(self.rpms, 0, 25000) -def run_sim(scene, drone, controller): - while controller.running: - # Update drone with current RPMs - rpms = controller.update_thrust() - drone.set_propellels_rpm(rpms) + def move_left(self): + """Left rotors spin faster""" + self.rpms[0] -= self.rotation_delta # Front left + self.rpms[2] -= self.rotation_delta # Back left + self.rpms[1] += self.rotation_delta # Front right + self.rpms[3] += self.rotation_delta # Back right + self.rpms = 
np.clip(self.rpms, 0, 25000) - # Update physics - scene.step(refresh_visualizer=False) + def move_right(self): + """Right rotors spin faster""" + print("move right") + self.rpms[0] += self.rotation_delta # Front left + self.rpms[2] += self.rotation_delta # Back left + self.rpms[1] -= self.rotation_delta # Front right + self.rpms[3] -= self.rotation_delta # Back right + self.rpms = np.clip(self.rpms, 0, 25000) - # Limit simulation rate - time.sleep(1.0 / scene.viewer.max_FPS) + def accelerate(self): + """All rotors spin faster""" + self.thrust += self.thrust_delta + self.rpms = [self.thrust] * 4 + self.rpms = np.clip(self.rpms, 0, 25000) - if "PYTEST_VERSION" in os.environ: - break + def decelerate(self): + """All rotors spin slower""" + self.thrust -= self.thrust_delta + self.rpms = [self.thrust] * 4 + self.rpms = np.clip(self.rpms, 0, 25000) def main(): @@ -145,14 +100,36 @@ def main(): # Initialize controller controller = DroneController() - # Start keyboard listener. - # Note that instantiating the listener after building the scene causes segfault on MacOS. 
- listener = keyboard.Listener(on_press=controller.on_press, on_release=controller.on_release) - listener.start() - # Build scene scene.build() + # Register keybindings + from pyglet.window import key + + scene.viewer.register_keybinds( + ( + Keybind( + key_code=key.UP, key_action=KeyAction.HOLD, name="move_forward", callback_func=controller.move_forward + ), + Keybind( + key_code=key.DOWN, + key_action=KeyAction.HOLD, + name="move_backward", + callback_func=controller.move_backward, + ), + Keybind(key_code=key.LEFT, key_action=KeyAction.HOLD, name="move_left", callback_func=controller.move_left), + Keybind( + key_code=key.RIGHT, key_action=KeyAction.HOLD, name="move_right", callback_func=controller.move_right + ), + Keybind( + key_code=key.SPACE, key_action=KeyAction.HOLD, name="accelerate", callback_func=controller.accelerate + ), + Keybind( + key_code=key.LSHIFT, key_action=KeyAction.HOLD, name="decelerate", callback_func=controller.decelerate + ), + ) + ) + # Print control instructions print("\nDrone Controls:") print("↑ - Move Forward (North)") @@ -161,19 +138,25 @@ def main(): print("→ - Move Right (East)") print("space - Increase RPM") print("shift - Decrease RPM") - print("ESC - Quit\n") + print("\nPlus all default viewer controls (press 'i' to see them)\n") print("Initial hover RPM:", controller.thrust) - # Run simulation in another thread - threading.Thread(target=run_sim, args=(scene, drone, controller)).start() - if "PYTEST_VERSION" not in os.environ: - scene.viewer.run() - + # Run simulation try: - listener.stop() - except NotImplementedError: - # Dummy backend does not implement stop - pass + while True: + # Update drone with current RPMs + rpms = controller.update_thrust() + drone.set_propellels_rpm(rpms) + + # Update physics + scene.step() + + if "PYTEST_VERSION" in os.environ: + break + except KeyboardInterrupt: + gs.logger.info("Simulation interrupted, exiting.") + finally: + gs.logger.info("Simulation finished.") if __name__ == "__main__": diff 
--git a/examples/keyboard_teleop.py b/examples/keyboard_teleop.py index d741b1f0b0..eaf59d3db4 100644 --- a/examples/keyboard_teleop.py +++ b/examples/keyboard_teleop.py @@ -11,48 +11,19 @@ u - Reset Scene space - Press to close gripper, release to open gripper esc - Quit + +Plus all default viewer controls (press 'i' to see them) """ -import os import random -import threading -import genesis as gs import numpy as np -from pynput import keyboard from scipy.spatial.transform import Rotation as R +import genesis as gs +from genesis.ext.pyrender.interaction.keybindings import KeyAction, Keybind -class KeyboardDevice: - def __init__(self): - self.pressed_keys = set() - self.lock = threading.Lock() - self.listener = keyboard.Listener(on_press=self.on_press, on_release=self.on_release) - - def start(self): - self.listener.start() - - def stop(self): - try: - self.listener.stop() - except NotImplementedError: - # Dummy backend does not implement stop - pass - self.listener.join() - - def on_press(self, key: keyboard.Key): - with self.lock: - self.pressed_keys.add(key) - - def on_release(self, key: keyboard.Key): - with self.lock: - self.pressed_keys.discard(key) - - def get_cmd(self): - return self.pressed_keys - - -def build_scene(): +if __name__ == "__main__": ########################## init ########################## gs.init(precision="32", logging_level="info", backend=gs.cpu) np.set_printoptions(precision=7, suppress=True) @@ -80,19 +51,19 @@ def build_scene(): ) ########################## entities ########################## - entities = dict() - entities["plane"] = scene.add_entity( + plane = scene.add_entity( gs.morphs.Plane(), ) - entities["robot"] = scene.add_entity( + robot = scene.add_entity( material=gs.materials.Rigid(gravity_compensation=1), morph=gs.morphs.MJCF( file="xml/franka_emika_panda/panda.xml", euler=(0, 0, 0), ), ) - entities["cube"] = scene.add_entity( + + cube = scene.add_entity( material=gs.materials.Rigid(rho=300), morph=gs.morphs.Box( 
pos=(0.5, 0.0, 0.07), @@ -101,7 +72,7 @@ def build_scene(): surface=gs.surfaces.Default(color=(0.5, 1, 0.5)), ) - entities["target"] = scene.add_entity( + target = scene.add_entity( gs.morphs.Mesh( file="meshes/axis.obj", scale=0.15, @@ -113,115 +84,114 @@ def build_scene(): ########################## build ########################## scene.build() - return scene, entities - - -def run_sim(scene, entities, clients): - robot = entities["robot"] - target_entity = entities["target"] - + # Initialize robot control state robot_init_pos = np.array([0.5, 0, 0.55]) robot_init_R = R.from_euler("y", np.pi) - target_pos = robot_init_pos.copy() - target_R = robot_init_R + # Get DOF indices n_dofs = robot.n_dofs motors_dof = np.arange(n_dofs - 2) fingers_dof = np.arange(n_dofs - 2, n_dofs) ee_link = robot.get_link("hand") - def reset_scene(): - nonlocal target_pos, target_R - target_pos = robot_init_pos.copy() - target_R = robot_init_R - target_quat = target_R.as_quat(scalar_first=True) - target_entity.set_qpos(np.concatenate([target_pos, target_quat])) + # Initialize target pose + target_pos = robot_init_pos.copy() + target_R = [robot_init_R] # Use list to make it mutable in closures + + # Control parameters + dpos = 0.002 + drot = 0.01 + + # Helper function to reset robot + def reset_robot(): + """Reset robot and cube to initial positions.""" + target_pos[:] = robot_init_pos.copy() + target_R[0] = robot_init_R + target_quat = target_R[0].as_quat(scalar_first=True) + target.set_qpos(np.concatenate([target_pos, target_quat])) q = robot.inverse_kinematics(link=ee_link, pos=target_pos, quat=target_quat) robot.set_qpos(q[:-2], motors_dof) - entities["cube"].set_pos((random.uniform(0.2, 0.4), random.uniform(-0.2, 0.2), 0.05)) - entities["cube"].set_quat(R.from_euler("z", random.uniform(0, np.pi * 2)).as_quat(scalar_first=True)) - - print("\nKeyboard Controls:") - print("↑\t- Move Forward (North)") - print("↓\t- Move Backward (South)") - print("←\t- Move Left (West)") - print("→\t- 
Move Right (East)") - print("n\t- Move Up") - print("m\t- Move Down") - print("j\t- Rotate Counterclockwise") - print("k\t- Rotate Clockwise") - print("u\t- Reset Scene") - print("space\t- Press to close gripper, release to open gripper") - print("esc\t- Quit") - - # reset scen before starting teleoperation - reset_scene() - - # start teleoperation - stop = False - while not stop: - pressed_keys = clients["keyboard"].pressed_keys.copy() - - # reset scene: - reset_flag = False - reset_flag |= keyboard.KeyCode.from_char("u") in pressed_keys - if reset_flag: - reset_scene() - - # stop teleoperation - stop = keyboard.Key.esc in pressed_keys - - # get ee target pose - is_close_gripper = False - dpos = 0.002 - drot = 0.01 - for key in pressed_keys: - if key == keyboard.Key.up: - target_pos[0] -= dpos - elif key == keyboard.Key.down: - target_pos[0] += dpos - elif key == keyboard.Key.right: - target_pos[1] += dpos - elif key == keyboard.Key.left: - target_pos[1] -= dpos - elif key == keyboard.KeyCode.from_char("n"): - target_pos[2] += dpos - elif key == keyboard.KeyCode.from_char("m"): - target_pos[2] -= dpos - elif key == keyboard.KeyCode.from_char("j"): - target_R = R.from_euler("z", drot) * target_R - elif key == keyboard.KeyCode.from_char("k"): - target_R = R.from_euler("z", -drot) * target_R - elif key == keyboard.Key.space: - is_close_gripper = True - - # control arm - target_quat = target_R.as_quat(scalar_first=True) - target_entity.set_qpos(np.concatenate([target_pos, target_quat])) - q, _err = robot.inverse_kinematics(link=ee_link, pos=target_pos, quat=target_quat, return_error=True) - robot.control_dofs_position(q[:-2], motors_dof) - - # control gripper - if is_close_gripper: - robot.control_dofs_force(np.array([-1.0, -1.0]), fingers_dof) - else: - robot.control_dofs_force(np.array([1.0, 1.0]), fingers_dof) - - scene.step() - - if "PYTEST_VERSION" in os.environ: - break - - -def main(): - clients = dict() - clients["keyboard"] = KeyboardDevice() - 
clients["keyboard"].start() - - scene, entities = build_scene() - run_sim(scene, entities, clients) - + # Randomize cube position + cube.set_pos((random.uniform(0.2, 0.4), random.uniform(-0.2, 0.2), 0.05)) + cube.set_quat(R.from_euler("z", random.uniform(0, np.pi * 2)).as_quat(scalar_first=True)) + + # Initialize robot pose + reset_robot() + + # Robot teleoperation callback functions + def move_forward(): + target_pos[0] -= dpos + + def move_backward(): + target_pos[0] += dpos + + def move_left(): + target_pos[1] -= dpos + + def move_right(): + target_pos[1] += dpos + + def move_up(): + target_pos[2] += dpos + + def move_down(): + target_pos[2] -= dpos + + def rotate_ccw(): + target_R[0] = R.from_euler("z", drot) * target_R[0] + + def rotate_cw(): + target_R[0] = R.from_euler("z", -drot) * target_R[0] + + def close_gripper(): + robot.control_dofs_force(np.array([-1.0, -1.0]), fingers_dof) + + def open_gripper(): + robot.control_dofs_force(np.array([1.0, 1.0]), fingers_dof) + + # Register robot teleoperation keybindings + from pyglet.window import key + + scene.viewer.register_keybinds( + ( + Keybind(key_code=key.UP, key_action=KeyAction.HOLD, name="move_forward", callback_func=move_forward), + Keybind(key_code=key.DOWN, key_action=KeyAction.HOLD, name="move_backward", callback_func=move_backward), + Keybind(key_code=key.LEFT, key_action=KeyAction.HOLD, name="move_left", callback_func=move_left), + Keybind(key_code=key.RIGHT, key_action=KeyAction.HOLD, name="move_right", callback_func=move_right), + Keybind(key_code=key.N, key_action=KeyAction.HOLD, name="move_up", callback_func=move_up), + Keybind(key_code=key.M, key_action=KeyAction.HOLD, name="move_down", callback_func=move_down), + Keybind(key_code=key.J, key_action=KeyAction.HOLD, name="rotate_ccw", callback_func=rotate_ccw), + Keybind(key_code=key.K, key_action=KeyAction.HOLD, name="rotate_cw", callback_func=rotate_cw), + Keybind(key_code=key.U, key_action=KeyAction.HOLD, name="reset_scene", 
callback_func=reset_robot), + Keybind( + key_code=key.SPACE, + name="close_gripper", + callback_func=close_gripper, + key_action=KeyAction.PRESS, + ), + Keybind( + key_code=key.SPACE, + name="open_gripper", + callback_func=open_gripper, + key_action=KeyAction.RELEASE, + ), + ) + ) -if __name__ == "__main__": - main() + ########################## run simulation ########################## + try: + while True: + # Update target entity visualization + target_quat = target_R[0].as_quat(scalar_first=True) + target.set_qpos(np.concatenate([target_pos, target_quat])) + + # Control arm with inverse kinematics + q, err = robot.inverse_kinematics(link=ee_link, pos=target_pos, quat=target_quat, return_error=True) + robot.control_dofs_position(q[:-2], motors_dof) + + scene.step() + except KeyboardInterrupt: + gs.logger.info("Simulation interrupted, exiting.") + finally: + gs.logger.info("Simulation finished.") diff --git a/examples/sensors/lidar_teleop.py b/examples/sensors/lidar_teleop.py index 5336114018..6f5f0c8623 100644 --- a/examples/sensors/lidar_teleop.py +++ b/examples/sensors/lidar_teleop.py @@ -1,28 +1,16 @@ import argparse import os -import threading import numpy as np import genesis as gs +from genesis.ext.pyrender.interaction.keybindings import KeyAction, Keybind from genesis.utils.geom import euler_to_quat -IS_PYNPUT_AVAILABLE = False -try: - from pynput import keyboard - - IS_PYNPUT_AVAILABLE = True -except ImportError: - pass - # Position and angle increments for keyboard teleop control KEY_DPOS = 0.1 KEY_DANGLE = 0.1 -# Movement when no keyboard control is available -MOVE_RADIUS = 1.0 -MOVE_RATE = 1.0 / 100.0 - # Number of obstacles to create in a ring around the robot NUM_CYLINDERS = 8 NUM_BOXES = 6 @@ -30,35 +18,6 @@ BOX_RING_RADIUS = 5.0 -class KeyboardDevice: - def __init__(self): - self.pressed_keys = set() - self.lock = threading.Lock() - self.listener = keyboard.Listener(on_press=self.on_press, on_release=self.on_release) - - def start(self): - 
self.listener.start() - - def stop(self): - try: - self.listener.stop() - except NotImplementedError: - # Dummy backend does not implement stop - pass - self.listener.join() - - def on_press(self, key: "keyboard.Key"): - with self.lock: - self.pressed_keys.add(key) - - def on_release(self, key: "keyboard.Key"): - with self.lock: - self.pressed_keys.discard(key) - - def get_cmd(self): - return self.pressed_keys - - def main(): parser = argparse.ArgumentParser(description="Genesis LiDAR/Depth Camera Visualization with Keyboard Teleop") parser.add_argument("-B", "--n_envs", type=int, default=0, help="Number of environments to replicate") @@ -69,12 +28,6 @@ def main(): ) args = parser.parse_args() - if IS_PYNPUT_AVAILABLE: - kb = KeyboardDevice() - kb.start() - else: - print("Keyboard teleop is disabled since pynput is not installed. To install, run `pip install pynput`.") - gs.init(backend=gs.cpu if args.cpu else gs.gpu, precision="32", logging_level="info") scene = gs.Scene( @@ -169,17 +122,7 @@ def main(): scene.build(n_envs=args.n_envs) - if IS_PYNPUT_AVAILABLE: - # Avoid using same keys as interactive viewer keyboard controls - print("Keyboard Controls:") - print("[↑/↓/←/→]: Move XY") - print("[j/k]: Down/Up") - print("[n/m]: Roll CCW/CW") - print("[,/.]: Pitch Up/Down") - print("[o/p]: Yaw CCW/CW") - print("[\\]: Reset") - print("[esc]: Quit") - + # Initialize pose state init_pos = np.array([0.0, 0.0, 0.35], dtype=np.float32) init_euler = np.array([0.0, 0.0, 0.0], dtype=np.float32) @@ -193,48 +136,81 @@ def apply_pose_to_all_envs(pos_np: np.ndarray, quat_np: np.ndarray): robot.set_pos(pos_np) robot.set_quat(quat_np) + # Define control callbacks + def reset_pose(): + target_pos[:] = init_pos + target_euler[:] = init_euler + + def move_forward(): + target_pos[0] += KEY_DPOS + + def move_backward(): + target_pos[0] -= KEY_DPOS + + def move_right(): + target_pos[1] -= KEY_DPOS + + def move_left(): + target_pos[1] += KEY_DPOS + + def move_down(): + target_pos[2] -= 
KEY_DPOS + + def move_up(): + target_pos[2] += KEY_DPOS + + def roll_ccw(): + target_euler[0] += KEY_DANGLE + + def roll_cw(): + target_euler[0] -= KEY_DANGLE + + def pitch_up(): + target_euler[1] += KEY_DANGLE + + def pitch_down(): + target_euler[1] -= KEY_DANGLE + + def yaw_ccw(): + target_euler[2] += KEY_DANGLE + + def yaw_cw(): + target_euler[2] -= KEY_DANGLE + + # Register keybindings + from pyglet.window import key + + scene.viewer.register_keybinds( + ( + Keybind(key_code=key.UP, key_action=KeyAction.HOLD, name="move_forward", callback_func=move_forward), + Keybind(key_code=key.DOWN, key_action=KeyAction.HOLD, name="move_backward", callback_func=move_backward), + Keybind(key_code=key.RIGHT, key_action=KeyAction.HOLD, name="move_right", callback_func=move_right), + Keybind(key_code=key.LEFT, key_action=KeyAction.HOLD, name="move_left", callback_func=move_left), + Keybind(key_code=key.J, key_action=KeyAction.HOLD, name="move_down", callback_func=move_down), + Keybind(key_code=key.K, key_action=KeyAction.HOLD, name="move_up", callback_func=move_up), + Keybind(key_code=key.N, key_action=KeyAction.HOLD, name="roll_ccw", callback_func=roll_ccw), + Keybind(key_code=key.M, key_action=KeyAction.HOLD, name="roll_cw", callback_func=roll_cw), + Keybind(key_code=key.COMMA, key_action=KeyAction.HOLD, name="pitch_up", callback_func=pitch_up), + Keybind(key_code=key.PERIOD, key_action=KeyAction.HOLD, name="pitch_down", callback_func=pitch_down), + Keybind(key_code=key.O, key_action=KeyAction.HOLD, name="yaw_ccw", callback_func=yaw_ccw), + Keybind(key_code=key.P, key_action=KeyAction.HOLD, name="yaw_cw", callback_func=yaw_cw), + Keybind(key_code=key.BACKSLASH, key_action=KeyAction.HOLD, name="reset", callback_func=reset_pose), + ) + ) + + # Print controls + print("Keyboard Controls:") + print("[↑/↓/←/→]: Move XY") + print("[j/k]: Down/Up") + print("[n/m]: Roll CCW/CW") + print("[,/.]: Pitch Up/Down") + print("[o/p]: Yaw CCW/CW") + print("[\\]: Reset") + 
apply_pose_to_all_envs(target_pos, euler_to_quat(target_euler)) try: while True: - if IS_PYNPUT_AVAILABLE: - pressed = kb.pressed_keys.copy() - if keyboard.Key.esc in pressed: - break - if keyboard.KeyCode.from_char("\\") in pressed: - target_pos[:] = init_pos - target_euler[:] = init_euler - - if keyboard.Key.up in pressed: - target_pos[0] += KEY_DPOS - if keyboard.Key.down in pressed: - target_pos[0] -= KEY_DPOS - if keyboard.Key.right in pressed: - target_pos[1] -= KEY_DPOS - if keyboard.Key.left in pressed: - target_pos[1] += KEY_DPOS - if keyboard.KeyCode.from_char("j") in pressed: - target_pos[2] -= KEY_DPOS - if keyboard.KeyCode.from_char("k") in pressed: - target_pos[2] += KEY_DPOS - - if keyboard.KeyCode.from_char("n") in pressed: - target_euler[0] += KEY_DANGLE # roll CCW around +X - if keyboard.KeyCode.from_char("m") in pressed: - target_euler[0] -= KEY_DANGLE # roll CW around +X - if keyboard.KeyCode.from_char(",") in pressed: - target_euler[1] += KEY_DANGLE # pitch up around +Y - if keyboard.KeyCode.from_char(".") in pressed: - target_euler[1] -= KEY_DANGLE # pitch down around +Y - if keyboard.KeyCode.from_char("o") in pressed: - target_euler[2] += KEY_DANGLE # yaw CCW around +Z - if keyboard.KeyCode.from_char("p") in pressed: - target_euler[2] -= KEY_DANGLE # yaw CW around +Z - else: - # move in a circle if no keyboard control - target_pos[0] = MOVE_RADIUS * np.cos(scene.t * MOVE_RATE) - target_pos[1] = MOVE_RADIUS * np.sin(scene.t * MOVE_RATE) - apply_pose_to_all_envs(target_pos, euler_to_quat(target_euler)) scene.step() diff --git a/examples/viewer_plugin/mesh_point_selector.py b/examples/viewer_plugin/mesh_point_selector.py new file mode 100644 index 0000000000..140b58bfd6 --- /dev/null +++ b/examples/viewer_plugin/mesh_point_selector.py @@ -0,0 +1,47 @@ +from huggingface_hub import snapshot_download + +import genesis as gs + +if __name__ == "__main__": + + gs.init(backend=gs.gpu) + + scene = gs.Scene( + sim_options=gs.options.SimOptions( + 
gravity=(0.0, 0.0, 0.0), + ), + viewer_options=gs.options.ViewerOptions( + camera_pos=(0.5, 0.2, 1.0), + camera_lookat=(0.0, 0.0, 1.0), + camera_fov=40, + viewer_plugin=gs.options.viewer_plugins.MeshPointSelectorPlugin( + sphere_radius=0.004, + grid_snap=(-1.0, 0.01, 0.01), + ), + ), + profiling_options=gs.options.ProfilingOptions( + show_FPS=False, + ), + show_viewer=True, + ) + + asset_path = snapshot_download( + repo_id="Genesis-Intelligence/assets", + allow_patterns="allegro_hand/*", + repo_type="dataset", + ) + hand = scene.add_entity( + morph=gs.morphs.URDF( + file=f"{asset_path}/allegro_hand/allegro_hand_right_glb.urdf", + collision=True, + pos=(0.0, 0.0, 1.0), + euler=(0.0, 0.0, 0.0), + fixed=True, + merge_fixed_links=False, + ), + ) + + scene.build() + + while True: + scene.step() diff --git a/examples/viewer_plugin/mouse_spring.py b/examples/viewer_plugin/mouse_spring.py new file mode 100644 index 0000000000..d88647e131 --- /dev/null +++ b/examples/viewer_plugin/mouse_spring.py @@ -0,0 +1,43 @@ +import math + +import genesis as gs + +if __name__ == "__main__": + + gs.init(backend=gs.gpu) + + scene = gs.Scene( + viewer_options=gs.options.ViewerOptions( + camera_pos=(3.5, 0.0, 2.5), + camera_lookat=(0.0, 0.0, 0.5), + camera_fov=40, + viewer_plugin=gs.options.viewer_plugins.MouseSpringPlugin(), + ), + profiling_options=gs.options.ProfilingOptions( + show_FPS=False, + ), + show_viewer=True, + ) + + scene.add_entity(gs.morphs.Plane()) + + sphere = scene.add_entity( + morph=gs.morphs.Sphere( + pos=(-0.3, -0.3, 0), + radius=0.1, + ), + ) + for i in range(20): + angle = i * (2 * math.pi / 20) + radius = 0.5 + i * 0.1 + cube = scene.add_entity( + morph=gs.morphs.Box( + pos=(radius * math.cos(angle), radius * math.sin(angle), 0.1 + i * 0.1), + size=(0.2, 0.2, 0.2), + ), + ) + + scene.build() + + while True: + scene.step() diff --git a/genesis/engine/sensors/raycaster.py b/genesis/engine/sensors/raycaster.py index 03a9a93db2..0209080199 100644 --- 
a/genesis/engine/sensors/raycaster.py +++ b/genesis/engine/sensors/raycaster.py @@ -8,11 +8,13 @@ import genesis as gs import genesis.utils.array_class as array_class +from genesis.engine.bvh import AABB, LBVH, STACK_SIZE from genesis.options.sensors import ( Raycaster as RaycasterOptions, +) +from genesis.options.sensors import ( RaycastPattern, ) -from genesis.engine.bvh import AABB, LBVH, STACK_SIZE from genesis.utils.geom import ( ti_normalize, ti_transform_by_quat, @@ -21,6 +23,7 @@ transform_by_trans_quat, ) from genesis.utils.misc import concat_with_tensor, make_tensor_field +from genesis.utils.raycast import kernel_update_aabbs, ray_aabb_intersection, ray_triangle_intersection from genesis.vis.rasterizer_context import RasterizerContext from .base_sensor import ( @@ -36,119 +39,6 @@ from genesis.utils.ring_buffer import TensorRingBuffer -@ti.func -def ray_triangle_intersection(ray_start, ray_dir, v0, v1, v2): - """ - Moller-Trumbore ray-triangle intersection. - - Returns: vec4(t, u, v, hit) where hit=1.0 if intersection found, 0.0 otherwise - """ - result = ti.Vector.zero(gs.ti_float, 4) - - edge1 = v1 - v0 - edge2 = v2 - v0 - - # Begin calculating determinant - also used to calculate u parameter - h = ray_dir.cross(edge2) - a = edge1.dot(h) - - # Check all conditions in sequence without early returns - valid = True - - t = gs.ti_float(0.0) - u = gs.ti_float(0.0) - v = gs.ti_float(0.0) - f = gs.ti_float(0.0) - s = ti.Vector.zero(gs.ti_float, 3) - q = ti.Vector.zero(gs.ti_float, 3) - - # If determinant is near zero, ray lies in plane of triangle - if ti.abs(a) < gs.EPS: - valid = False - - if valid: - f = 1.0 / a - s = ray_start - v0 - u = f * s.dot(h) - - if u < 0.0 or u > 1.0: - valid = False - - if valid: - q = s.cross(edge1) - v = f * ray_dir.dot(q) - - if v < 0.0 or u + v > 1.0: - valid = False - - if valid: - # At this stage we can compute t to find out where the intersection point is on the line - t = f * edge2.dot(q) - - # Ray intersection - if t <= 
gs.EPS: - valid = False - - if valid: - result = ti.math.vec4(t, u, v, 1.0) - - return result - - -@ti.func -def ray_aabb_intersection(ray_start, ray_dir, aabb_min, aabb_max): - """ - Fast ray-AABB intersection test. - Returns the t value of intersection, or -1.0 if no intersection. - """ - result = -1.0 - - # Use the slab method for ray-AABB intersection - sign = ti.select(ray_dir >= 0.0, 1.0, -1.0) - ray_dir = sign * ti.max(ti.abs(ray_dir), gs.EPS) - inv_dir = 1.0 / ray_dir - - t1 = (aabb_min - ray_start) * inv_dir - t2 = (aabb_max - ray_start) * inv_dir - - tmin = ti.min(t1, t2) - tmax = ti.max(t1, t2) - - t_near = ti.max(tmin.x, tmin.y, tmin.z, 0.0) - t_far = ti.min(tmax.x, tmax.y, tmax.z) - - # Check if ray intersects AABB - if t_near <= t_far: - result = t_near - - return result - - -@ti.kernel -def kernel_update_aabbs( - free_verts_state: array_class.VertsState, - fixed_verts_state: array_class.VertsState, - verts_info: array_class.VertsInfo, - faces_info: array_class.FacesInfo, - aabb_state: ti.template(), -): - for i_b, i_f in ti.ndrange(free_verts_state.pos.shape[1], faces_info.verts_idx.shape[0]): - aabb_state.aabbs[i_b, i_f].min.fill(ti.math.inf) - aabb_state.aabbs[i_b, i_f].max.fill(-ti.math.inf) - - for i in ti.static(range(3)): - i_v = faces_info.verts_idx[i_f][i] - i_fv = verts_info.verts_state_idx[i_v] - if verts_info.is_fixed[i_v]: - pos_v = fixed_verts_state.pos[i_fv] - aabb_state.aabbs[i_b, i_f].min = ti.min(aabb_state.aabbs[i_b, i_f].min, pos_v) - aabb_state.aabbs[i_b, i_f].max = ti.max(aabb_state.aabbs[i_b, i_f].max, pos_v) - else: - pos_v = free_verts_state.pos[i_fv, i_b] - aabb_state.aabbs[i_b, i_f].min = ti.min(aabb_state.aabbs[i_b, i_f].min, pos_v) - aabb_state.aabbs[i_b, i_f].max = ti.max(aabb_state.aabbs[i_b, i_f].max, pos_v) - - @ti.kernel def kernel_cast_rays( fixed_verts_state: array_class.VertsState, @@ -195,8 +85,7 @@ def kernel_cast_rays( ray_dir_local = ti.math.vec3(ray_directions[i_p, 0], ray_directions[i_p, 1], 
ray_directions[i_p, 2]) ray_direction_world = ti_normalize(ti_transform_by_quat(ray_dir_local, link_quat), gs.EPS) - # --- 2. BVH Traversal --- - # FIXME: this duplicates the logic in LBVH.query() which also does traversal + # --- 2. BVH Traversal for ray intersection --- max_range = max_ranges[i_s] hit_face = -1 diff --git a/genesis/ext/pyrender/interaction/__init__.py b/genesis/ext/pyrender/interaction/__init__.py new file mode 100644 index 0000000000..ea9fb562fd --- /dev/null +++ b/genesis/ext/pyrender/interaction/__init__.py @@ -0,0 +1,11 @@ +from .keybindings import KeyAction, Keybind, Keybindings, get_keycode_string +from .plugins.default_controls import DefaultControls +from .plugins.mesh_point_selector import MeshPointSelectorPlugin +from .plugins.mouse_spring import MouseSpringPlugin +from .viewer_plugin import ( + EVENT_HANDLE_STATE, + EVENT_HANDLED, + VIEWER_PLUGIN_MAP, + ViewerPlugin, + register_viewer_plugin, +) diff --git a/genesis/ext/pyrender/interaction/keybindings.py b/genesis/ext/pyrender/interaction/keybindings.py new file mode 100644 index 0000000000..8bb8ce4e7e --- /dev/null +++ b/genesis/ext/pyrender/interaction/keybindings.py @@ -0,0 +1,116 @@ +from enum import IntEnum +from typing import Callable, NamedTuple + +KEY_STRING_TO_CHAR = { + "backslash": "\\", + "slash": "/", + "comma": ",", + "period": ".", + "bracketleft": "[", + "bracketright": "]", + "semicolon": ";", + "minus": "-", + "equal": "=", +} + +class KeyAction(IntEnum): + PRESS = 0 + HOLD = 1 + RELEASE = 2 + +def get_key_hash(key_code: int, modifiers: int | None, action: KeyAction) -> int: + """Generate a unique hash for a key combination. + + Parameters + ---------- + key_code : int + The key code from pyglet. + modifiers : int | None + The modifier keys pressed. + action : KeyAction + The type of key action (press, hold, release). + + Returns + ------- + int + A unique hash for this key combination. 
+ """ + return hash((key_code, modifiers, action)) + +def get_keycode_string(key_code: int) -> str: + from pyglet.window.key import symbol_string + symbol = symbol_string(key_code).lower() + if symbol in KEY_STRING_TO_CHAR: + return KEY_STRING_TO_CHAR[symbol] + return symbol + +class Keybind(NamedTuple): + key_code: int + name: str = "" + key_action: KeyAction = KeyAction.PRESS + callback_func: Callable[[], None] | None = None + modifiers: int | None = None + args: tuple = () + kwargs: dict = {} + + def key_hash(self) -> int: + """Generate a unique hash for the keybind based on key code and modifiers.""" + return get_key_hash(self.key_code, self.modifiers, self.key_action) + +class Keybindings: + + def __init__(self, keybinds: tuple[Keybind] = ()): + self._keybinds_map: dict[int, Keybind] = {kb.key_hash(): kb for kb in keybinds} + + def register(self, keybind: Keybind) -> None: + if keybind.key_hash() in self._keybinds_map: + existing_kb = self._keybinds_map[keybind.key_hash()] + raise ValueError( + f"Key '{get_keycode_string(keybind.key_code)}' is already assigned to '{existing_kb.name}'." 
+ ) + self._keybinds_map[keybind.key_hash()] = keybind + + def rebind(self, name: str, new_key_code: int | None, new_modifiers: int | None = None, new_key_action: KeyAction | None = None) -> None: + for kb in self._keybinds_map.values(): + if kb.name == name: + new_kb = Keybind( + name=kb.name, + key_code=new_key_code or kb.key_code, + key_action=new_key_action or kb.key_action, + modifiers=new_modifiers or kb.modifiers, + callback_func=kb.callback_func, + args=kb.args, + kwargs=kb.kwargs, + ) + del self._keybinds_map[kb.key_hash()] + self._keybinds_map[new_kb.key_hash()] = new_kb + return + raise ValueError(f"No keybind found with name '{name}'.") + + def get(self, key: int, modifiers: int, key_action: KeyAction) -> Keybind | None: + key_hash = get_key_hash(key, modifiers, key_action) + if key_hash in self._keybinds_map: + return self._keybinds_map[key_hash] + + # Try ignoring modifiers (for keybinds where modifiers=None) + key_hash_no_mods = get_key_hash(key, None, key_action) + if key_hash_no_mods in self._keybinds_map: + return self._keybinds_map[key_hash_no_mods] + + return None + + def get_by_name(self, name: str) -> Keybind | None: + for kb in self._keybinds_map.values(): + if kb.name == name: + return kb + return None + + @property + def keys(self) -> tuple[str]: + """Return a list of all registered keys as ASCII characters.""" + return tuple(get_keycode_string(kb.key_code) for kb in self._keybinds_map.values()) + + @property + def keybinds(self) -> tuple[Keybind]: + """Return a tuple of all registered Keybinds.""" + return tuple(self._keybinds_map.values()) \ No newline at end of file diff --git a/genesis/ext/pyrender/interaction/mouse_spring.py b/genesis/ext/pyrender/interaction/mouse_spring.py deleted file mode 100644 index f7a692dd15..0000000000 --- a/genesis/ext/pyrender/interaction/mouse_spring.py +++ /dev/null @@ -1,95 +0,0 @@ -from typing import TYPE_CHECKING - -import torch - -from .ray import Plane, Ray, RayHit -from .vec3 import Pose, Quat, Vec3, 
Color - -if TYPE_CHECKING: - from genesis.engine.entities.rigid_entity.rigid_link import RigidLink - - -MOUSE_SPRING_POSITION_CORRECTION_FACTOR = 1.0 -MOUSE_SPRING_VELOCITY_CORRECTION_FACTOR = 1.0 - - -class MouseSpring: - def __init__(self) -> None: - self.held_link: "RigidLink | None" = None - self.held_point_in_local: Vec3 | None = None - self.prev_control_point: Vec3 | None = None - - def attach(self, picked_link: "RigidLink", control_point: Vec3) -> None: - # for now, we just pick the first geometry - self.held_link = picked_link - pose: Pose = Pose.from_link(self.held_link) - self.held_point_in_local = pose.inverse_transform_point(control_point) - self.prev_control_point = control_point - - def detach(self) -> None: - self.held_link = None - - def apply_force(self, control_point: Vec3, delta_time: float) -> None: - # note when threaded: apply_force is called before attach! - # note2: that was before we added a lock to ViewerInteraction; this migth be fixed now - if not self.held_link: - return - - self.prev_control_point = control_point - - # do simple force on COM only: - link: "RigidLink" = self.held_link - lin_vel: Vec3 = Vec3.from_tensor(link.get_vel()) - ang_vel: Vec3 = Vec3.from_tensor(link.get_ang()) - link_pose: Pose = Pose.from_link(link) - held_point_in_world: Vec3 = link_pose.transform_point(self.held_point_in_local) - - # note: we should assert earlier that link inertial_pos/quat are not None - # todo: verify inertial_pos/quat are stored in local frame - link_T_principal: Pose = Pose(Vec3.from_arraylike(link.inertial_pos), Quat.from_arraylike(link.inertial_quat)) - world_T_principal: Pose = link_pose * link_T_principal - - arm_in_principal: Vec3 = link_T_principal.inverse_transform_point(self.held_point_in_local) # for non-spherical inertia - arm_in_world: Vec3 = world_T_principal.rot * arm_in_principal # for spherical inertia - - pos_err_v: Vec3 = control_point - held_point_in_world - inv_mass: float = float(1.0 / link.get_mass() if 
link.get_mass() > 0.0 else 0.0) - inv_spherical_inertia: float = float(1.0 / link.inertial_i[0, 0] if link.inertial_i[0, 0] > 0.0 else 0.0) - - inv_dt: float = 1.0 / delta_time - tau: float = MOUSE_SPRING_POSITION_CORRECTION_FACTOR - damp: float = MOUSE_SPRING_VELOCITY_CORRECTION_FACTOR - - total_impulse: Vec3 = Vec3.zero() - total_torque_impulse: Vec3 = Vec3.zero() - - for i in range(3*4): - body_point_vel: Vec3 = lin_vel + ang_vel.cross(arm_in_world) - vel_err_v: Vec3 = Vec3.zero() - body_point_vel - - dir: Vec3 = Vec3.zero() - dir.v[i % 3] = 1.0 - pos_err: float = dir.dot(pos_err_v) - vel_err: float = dir.dot(vel_err_v) - error: float = tau * pos_err * inv_dt + damp * vel_err - - arm_x_dir: Vec3 = arm_in_world.cross(dir) - virtual_mass: float = 1.0 / (inv_mass + arm_x_dir.sqr_magnitude() * inv_spherical_inertia + 1e-24) - impulse: float = error * virtual_mass - - lin_vel += impulse * inv_mass * dir - ang_vel += impulse * inv_spherical_inertia * arm_x_dir - total_impulse.v[i % 3] += impulse - total_torque_impulse += impulse * arm_x_dir - - # Apply the new force - total_force = total_impulse * inv_dt - total_torque = total_torque_impulse * inv_dt - force_tensor: torch.Tensor = total_force.as_tensor()[None] - torque_tensor: torch.Tensor = total_torque.as_tensor()[None] - link.solver.apply_links_external_force(force_tensor, (link.idx,), ref='link_com', local=False) - link.solver.apply_links_external_torque(torque_tensor, (link.idx,), ref='link_com', local=False) - - @property - def is_attached(self) -> bool: - return self.held_link is not None diff --git a/genesis/ext/pyrender/interaction/plugins/__init__.py b/genesis/ext/pyrender/interaction/plugins/__init__.py new file mode 100644 index 0000000000..a30d1981fa --- /dev/null +++ b/genesis/ext/pyrender/interaction/plugins/__init__.py @@ -0,0 +1,11 @@ +from .default_controls import DefaultControls +from .help_text import HelpTextPlugin +from .mesh_point_selector import MeshPointSelectorPlugin +from .mouse_spring 
import MouseSpringPlugin + +__all__ = [ + "DefaultControls", + "HelpTextPlugin", + "MeshPointSelectorPlugin", + "MouseSpringPlugin", +] diff --git a/genesis/ext/pyrender/interaction/plugins/default_controls.py b/genesis/ext/pyrender/interaction/plugins/default_controls.py new file mode 100644 index 0000000000..48fbb91479 --- /dev/null +++ b/genesis/ext/pyrender/interaction/plugins/default_controls.py @@ -0,0 +1,159 @@ +import os +from typing import TYPE_CHECKING + +import pyglet + +import genesis as gs +from genesis.options.viewer_plugins import DefaultControlsPlugin as DefaultControlsOptions + +from ..keybindings import Keybind +from ..viewer_plugin import register_viewer_plugin +from .help_text import HelpTextPlugin + +if TYPE_CHECKING: + from genesis.engine.scene import Scene + from genesis.ext.pyrender.node import Node + +INSTR_KEYBIND_NAME = "toggle_instructions" + +@register_viewer_plugin(DefaultControlsOptions) +class DefaultControls(HelpTextPlugin): + """ + Default keyboard controls for the Genesis viewer. + + This plugin handles the standard viewer keyboard shortcuts for recording, changing render modes, etc. 
+ """ + + def __init__( + self, + viewer, + options=None, + camera: "Node" = None, + scene: "Scene" = None, + viewport_size: tuple[int, int] = None, + ): + super().__init__(viewer, options, camera, scene, viewport_size) + + self.viewer.register_keybinds(( + Keybind(key_code=pyglet.window.key.R, name="record_video", callback_func=self._toggle_record_video), + Keybind(key_code=pyglet.window.key.S, name="save_image", callback_func=self._save_image), + Keybind(key_code=pyglet.window.key.Z, name="reset_camera", callback_func=self._reset_camera), + Keybind(key_code=pyglet.window.key.A, name="camera_rotation", callback_func=self._toggle_camera_rotation), + Keybind(key_code=pyglet.window.key.H, name="shadow", callback_func=self._toggle_shadow), + Keybind(key_code=pyglet.window.key.F, name="face_normals", callback_func=self._toggle_face_normals), + Keybind(key_code=pyglet.window.key.V, name="vertex_normals", callback_func=self._toggle_vertex_normals), + Keybind(key_code=pyglet.window.key.W, name="world_frame", callback_func=self._toggle_world_frame), + Keybind(key_code=pyglet.window.key.L, name="link_frame", callback_func=self._toggle_link_frame), + Keybind(key_code=pyglet.window.key.D, name="wireframe", callback_func=self._toggle_wireframe), + Keybind(key_code=pyglet.window.key.C, name="camera_frustum", callback_func=self._toggle_camera_frustum), + Keybind(key_code=pyglet.window.key.P, name="reload_shader", callback_func=self._reload_shader), + Keybind(key_code=pyglet.window.key.F11, name="fullscreen_mode", callback_func=self._toggle_fullscreen), + )) + + def _toggle_camera_rotation(self): + self.viewer.viewer_flags["rotate"] = not self.viewer.viewer_flags["rotate"] + if self.viewer.viewer_flags["rotate"]: + self.set_message_text("Rotation On") + else: + self.set_message_text("Rotation Off") + + def _toggle_fullscreen(self): + self.viewer.viewer_flags["fullscreen"] = not self.viewer.viewer_flags["fullscreen"] + 
self.viewer.set_fullscreen(self.viewer.viewer_flags["fullscreen"]) + self.viewer.activate() + if self.viewer.viewer_flags["fullscreen"]: + self.set_message_text("Fullscreen On") + else: + self.set_message_text("Fullscreen Off") + + def _toggle_shadow(self): + self.viewer.render_flags["shadows"] = not self.viewer.render_flags["shadows"] + if self.viewer.render_flags["shadows"]: + self.set_message_text("Shadows On") + else: + self.set_message_text("Shadows Off") + + def _toggle_world_frame(self): + if not self.viewer.gs_context.world_frame_shown: + self.viewer.gs_context.on_world_frame() + self.set_message_text("World Frame On") + else: + self.viewer.gs_context.off_world_frame() + self.set_message_text("World Frame Off") + + def _toggle_link_frame(self): + if not self.viewer.gs_context.link_frame_shown: + self.viewer.gs_context.on_link_frame() + self.set_message_text("Link Frame On") + else: + self.viewer.gs_context.off_link_frame() + self.set_message_text("Link Frame Off") + + def _toggle_camera_frustum(self): + if not self.viewer.gs_context.camera_frustum_shown: + self.viewer.gs_context.on_camera_frustum() + self.set_message_text("Camera Frustum On") + else: + self.viewer.gs_context.off_camera_frustum() + self.set_message_text("Camera Frustum Off") + + def _toggle_face_normals(self): + self.viewer.render_flags["face_normals"] = not self.viewer.render_flags["face_normals"] + if self.viewer.render_flags["face_normals"]: + self.set_message_text("Face Normals On") + else: + self.set_message_text("Face Normals Off") + + def _toggle_vertex_normals(self): + self.viewer.render_flags["vertex_normals"] = not self.viewer.render_flags["vertex_normals"] + if self.viewer.render_flags["vertex_normals"]: + self.set_message_text("Vert Normals On") + else: + self.set_message_text("Vert Normals Off") + + def _toggle_record_video(self): + if self.viewer.viewer_flags["record"]: + self.viewer.save_video() + self.viewer.set_caption(self.viewer.viewer_flags["window_title"]) + else: + # 
Importing moviepy is very slow and not used very often. Let's delay import. + from moviepy.video.io.ffmpeg_writer import FFMPEG_VideoWriter + + self.viewer._video_recorder = FFMPEG_VideoWriter( + filename=os.path.join(gs.utils.misc.get_cache_dir(), "tmp_video.mp4"), + fps=self.viewer.viewer_flags["refresh_rate"], + size=self.viewer.viewport_size, + ) + self.viewer.set_caption("{} (RECORDING)".format(self.viewer.viewer_flags["window_title"])) + self.viewer.viewer_flags["record"] = not self.viewer.viewer_flags["record"] + + def _save_image(self): + self.viewer._save_image() + + def _toggle_wireframe(self): + if self.viewer.render_flags["flip_wireframe"]: + self.viewer.render_flags["flip_wireframe"] = False + self.viewer.render_flags["all_wireframe"] = True + self.viewer.render_flags["all_solid"] = False + self.set_message_text("All Wireframe") + elif self.viewer.render_flags["all_wireframe"]: + self.viewer.render_flags["flip_wireframe"] = False + self.viewer.render_flags["all_wireframe"] = False + self.viewer.render_flags["all_solid"] = True + self.set_message_text("All Solid") + elif self.viewer.render_flags["all_solid"]: + self.viewer.render_flags["flip_wireframe"] = False + self.viewer.render_flags["all_wireframe"] = False + self.viewer.render_flags["all_solid"] = False + self.set_message_text("Default Wireframe") + else: + self.viewer.render_flags["flip_wireframe"] = True + self.viewer.render_flags["all_wireframe"] = False + self.viewer.render_flags["all_solid"] = False + self.set_message_text("Flip Wireframe") + + def _reset_camera(self): + self.viewer._reset_view() + + def _reload_shader(self): + self.viewer._renderer.reload_program() \ No newline at end of file diff --git a/genesis/ext/pyrender/interaction/plugins/help_text.py b/genesis/ext/pyrender/interaction/plugins/help_text.py new file mode 100644 index 0000000000..ff891f8fbd --- /dev/null +++ b/genesis/ext/pyrender/interaction/plugins/help_text.py @@ -0,0 +1,104 @@ +from typing import TYPE_CHECKING + 
+import numpy as np +import pyglet +from typing_extensions import override + +from genesis.options.viewer_plugins import HelpTextPlugin as HelpTextPluginOptions + +from ...constants import TEXT_PADDING, TextAlign +from ..keybindings import Keybind, get_keycode_string +from ..viewer_plugin import ViewerPlugin, register_viewer_plugin + +if TYPE_CHECKING: + from genesis.engine.scene import Scene + from genesis.ext.pyrender.node import Node + from genesis.options.viewer_plugins import ViewerPlugin as ViewerPluginOptions + +INSTR_KEYBIND_NAME = "toggle_instructions" + +@register_viewer_plugin(HelpTextPluginOptions) +class HelpTextPlugin(ViewerPlugin): + """ + Default keyboard controls for the Genesis viewer. + """ + + def __init__( + self, + viewer, + options: "ViewerPluginOptions", + camera: "Node" = None, + scene: "Scene" = None, + viewport_size: tuple[int, int] = None, + ): + super().__init__(viewer, options, camera, scene, viewport_size) + + self.viewer.register_keybinds(( + Keybind(key_code=pyglet.window.key.I, name=INSTR_KEYBIND_NAME, callback_func=self._toggle_instructions), + )) + self._collapse_instructions = True + self._instr_texts: tuple[list[str], list[str]] = ([], []) + self._update_instr_texts() + + self._message_text = None + self._ticks_till_fade = 2.0 / 3.0 * self.viewer.viewer_flags["refresh_rate"] + self._message_opac = 1.0 + self._ticks_till_fade + + def _update_instr_texts(self): + self.instr_key_str = get_keycode_string(self.viewer._keybindings.get_by_name(INSTR_KEYBIND_NAME).key_code) + kb_texts = [ + f"{'[' + get_keycode_string(kb.key_code):>{7}}]: " + + kb.name.replace('_', ' ') for kb in self.viewer._keybindings.keybinds if kb.name != INSTR_KEYBIND_NAME + ] + self._instr_texts = ( + [f"> [{self.instr_key_str}]: show keyboard instructions"], + [f"< [{self.instr_key_str}]: hide keyboard instructions"] + kb_texts + ) + + def _toggle_instructions(self): + self._collapse_instructions = not self._collapse_instructions + self._update_instr_texts() + 
+ def set_message_text(self, text: str): + self._message_text = text + self._message_opac = 1.0 + self._ticks_till_fade + + @override + def on_draw(self): + if self._message_text is not None: + self.viewer._renderer.render_text( + self._message_text, + self.viewport_size[0] - TEXT_PADDING, + TEXT_PADDING, + font_pt=self.options.font_size, + color=np.array([0.1, 0.7, 0.2, np.clip(self._message_opac, 0.0, 1.0)]), + align=TextAlign.BOTTOM_RIGHT, + ) + + if self._message_opac > 1.0: + self._message_opac -= 1.0 + else: + self._message_opac *= 0.90 + + if self._message_opac < 0.05: + self._message_opac = 1.0 + self._ticks_till_fade + self._message_text = None + + + if self._collapse_instructions: + self.viewer._renderer.render_texts( + self._instr_texts[0], + TEXT_PADDING, + self.viewer.viewport_size[1] - TEXT_PADDING, + font_pt=self.options.font_size, + color=np.array([1.0, 1.0, 1.0, 0.85]), + ) + else: + self.viewer._renderer.render_texts( + self._instr_texts[1], + TEXT_PADDING, + self.viewer.viewport_size[1] - TEXT_PADDING, + font_pt=self.options.font_size, + color=np.array([1.0, 1.0, 1.0, 0.85]), + ) + \ No newline at end of file diff --git a/genesis/ext/pyrender/interaction/plugins/mesh_point_selector.py b/genesis/ext/pyrender/interaction/plugins/mesh_point_selector.py new file mode 100644 index 0000000000..55bc4b1bb8 --- /dev/null +++ b/genesis/ext/pyrender/interaction/plugins/mesh_point_selector.py @@ -0,0 +1,205 @@ +import csv +from typing import TYPE_CHECKING, NamedTuple + +from typing_extensions import override + +import genesis as gs +from genesis.options.viewer_plugins import MeshPointSelectorPlugin as MeshPointSelectorPluginOptions + +from ..utils import Pose, Ray, Vec3, ViewerRaycaster +from ..viewer_plugin import EVENT_HANDLE_STATE, EVENT_HANDLED, register_viewer_plugin +from .help_text import HelpTextPlugin + +if TYPE_CHECKING: + from genesis.engine.entities.rigid_entity import RigidLink + from genesis.engine.scene import Scene + from 
genesis.ext.pyrender.node import Node + + +class SelectedPoint(NamedTuple): + """ + Represents a selected point on a rigid mesh surface. + + Attributes + ---------- + link : RigidLink + The rigid link that the point belongs to. + local_position : Vec3 + The position of the point in the link's local coordinate frame. + local_normal : Vec3 + The surface normal at the point in the link's local coordinate frame. + """ + link: "RigidLink" + local_position: Vec3 + local_normal: Vec3 + + +@register_viewer_plugin(MeshPointSelectorPluginOptions) +class MeshPointSelectorPlugin(HelpTextPlugin): + """ + Interactive viewer plugin that enables using mouse clicks to select points on rigid meshes. + Selected points are stored in local coordinates relative to their link's frame. + """ + + def __init__( + self, + viewer, + options: MeshPointSelectorPluginOptions, + camera: "Node", + scene: "Scene", + viewport_size: tuple[int, int], + ) -> None: + super().__init__(viewer, options, camera, scene, viewport_size) + self.prev_mouse_pos: tuple[int, int] = (viewport_size[0] // 2, viewport_size[1] // 2) + + # List of selected points with link, local position, and local normal + self.selected_points: list[SelectedPoint] = [] + + self.raycaster: ViewerRaycaster = ViewerRaycaster(self.scene) + + def _snap_to_grid(self, position: Vec3) -> Vec3: + """ + Snap a position to the grid based on grid_snap settings. + + Parameters + ---------- + position : Vec3 + The position to snap. + + Returns + ------- + Vec3 + The snapped position. 
+ """ + snap_x, snap_y, snap_z = self.options.grid_snap + + # Snap each axis if the snap value is non-negative + x = round(position.x / snap_x) * snap_x if snap_x >= 0 else position.x + y = round(position.y / snap_y) * snap_y if snap_y >= 0 else position.y + z = round(position.z / snap_z) * snap_z if snap_z >= 0 else position.z + + return Vec3.from_xyz(x, y, z) + + @override + def on_mouse_motion(self, x: int, y: int, dx: int, dy: int) -> EVENT_HANDLE_STATE: + super().on_mouse_motion(x, y, dx, dy) + self.prev_mouse_pos = (x, y) + return None + + @override + def on_mouse_press(self, x: int, y: int, button: int, modifiers: int) -> EVENT_HANDLE_STATE: + super().on_mouse_press(x, y, button, modifiers) + if button == 1: # left mouse button + ray = self._screen_position_to_ray(x, y) + ray_hit = self.raycaster.cast_ray(ray.origin.v, ray.direction.v) + + if ray_hit.is_hit and ray_hit.geom: + link = ray_hit.geom.link + world_pos = ray_hit.position + world_normal = ray_hit.normal + + pose: Pose = Pose.from_link(link) + local_pos = pose.inverse_transform_point(world_pos) + local_normal = pose.inverse_transform_direction(world_normal) + + # Apply grid snapping to local position + local_pos = self._snap_to_grid(local_pos) + + selected_point = SelectedPoint( + link=link, + local_position=local_pos, + local_normal=local_normal + ) + self.selected_points.append(selected_point) + + return EVENT_HANDLED + return None + + @override + def update_on_sim_step(self) -> None: + self.raycaster.update_bvh() + + @override + def on_draw(self) -> None: + super().on_draw() + if self.scene._visualizer is not None and self.scene._visualizer.is_built: + self.scene.clear_debug_objects() + mouse_ray: Ray = self._screen_position_to_ray(*self.prev_mouse_pos) + + closest_hit = self.raycaster.cast_ray(mouse_ray.origin.v, mouse_ray.direction.v) + if closest_hit.is_hit: + snap_pos = self._snap_to_grid(closest_hit.position) + # Draw hover preview + self.scene.draw_debug_sphere( + snap_pos.v, + 
self.options.sphere_radius, + self.options.hover_color, + ) + self.scene.draw_debug_arrow( + snap_pos.v, + closest_hit.normal.v * 0.1, + self.options.sphere_radius / 2, + self.options.hover_color, + ) + + if self.selected_points: + world_positions = [] + for point in self.selected_points: + pose = Pose.from_link(point.link) + current_world_pos = pose.transform_point(point.local_position) + world_positions.append(current_world_pos.v) + + if len(world_positions) == 1: + self.scene.draw_debug_sphere( + world_positions[0], + self.options.sphere_radius, + self.options.sphere_color, + ) + else: + import numpy as np + + positions_array = np.array(world_positions) + self.scene.draw_debug_spheres( + positions_array, self.options.sphere_radius, self.options.sphere_color + ) + + @override + def on_close(self) -> None: + super().on_close() + + if not self.selected_points: + print("[MeshPointSelectorPlugin] No points selected.") + return + + output_file = self.options.output_file + try: + with open(output_file, 'w', newline='') as csvfile: + writer = csv.writer(csvfile) + + writer.writerow([ + 'point_idx', + 'link_idx', + 'local_pos_x', + 'local_pos_y', + 'local_pos_z', + 'local_normal_x', + 'local_normal_y', + 'local_normal_z' + ]) + + for i, point in enumerate(self.selected_points, 1): + writer.writerow([ + i, + point.link.idx, + point.local_position.x, + point.local_position.y, + point.local_position.z, + point.local_normal.x, + point.local_normal.y, + point.local_normal.z, + ]) + + gs.logger.info(f"[MeshPointSelectorPlugin] Wrote {len(self.selected_points)} selected points to '{output_file}'") + + except Exception as e: + gs.logger.error(f"[MeshPointSelectorPlugin] Error writing to '{output_file}': {e}") diff --git a/genesis/ext/pyrender/interaction/viewer_interaction.py b/genesis/ext/pyrender/interaction/plugins/mouse_spring.py similarity index 53% rename from genesis/ext/pyrender/interaction/viewer_interaction.py rename to 
genesis/ext/pyrender/interaction/plugins/mouse_spring.py index 46d24244d1..cad7e84f35 100644 --- a/genesis/ext/pyrender/interaction/viewer_interaction.py +++ b/genesis/ext/pyrender/interaction/plugins/mouse_spring.py @@ -1,44 +1,121 @@ -from typing import TYPE_CHECKING, cast -from typing_extensions import override # Made it into standard lib from Python 3.12 from threading import Lock as threading_Lock +from typing import TYPE_CHECKING -import numpy as np +import torch +from typing_extensions import override # Made it into standard lib from Python 3.12 import genesis as gs +from genesis.options.viewer_plugins import MouseSpringPlugin as MouseSpringPluginOptions -from .aabb import AABB, OBB -from .mouse_spring import MouseSpring -from .ray import Plane, Ray, RayHit -from .vec3 import Pose, Quat, Vec3, Color -from .viewer_interaction_base import ViewerInteractionBase, EVENT_HANDLE_STATE, EVENT_HANDLED +from ..utils import AABB, OBB, Color, Plane, Pose, Quat, Ray, RayHit, Vec3, ViewerRaycaster +from ..viewer_plugin import EVENT_HANDLE_STATE, EVENT_HANDLED, ViewerPlugin, register_viewer_plugin if TYPE_CHECKING: - from genesis.engine.entities.rigid_entity import RigidGeom, RigidLink, RigidEntity + from genesis.engine.entities.rigid_entity import RigidEntity, RigidGeom, RigidLink from genesis.engine.scene import Scene from genesis.ext.pyrender.node import Node -class ViewerInteraction(ViewerInteractionBase): - """Functionalities to be implemented: - - mouse picking - - mouse dragging +MOUSE_SPRING_POSITION_CORRECTION_FACTOR = 1.0 +MOUSE_SPRING_VELOCITY_CORRECTION_FACTOR = 1.0 + +class MouseSpring: + def __init__(self) -> None: + self.held_link: "RigidLink | None" = None + self.held_point_in_local: Vec3 | None = None + self.prev_control_point: Vec3 | None = None + + def attach(self, picked_link: "RigidLink", control_point: Vec3) -> None: + # for now, we just pick the first geometry + self.held_link = picked_link + pose: Pose = Pose.from_link(self.held_link) + 
self.held_point_in_local = pose.inverse_transform_point(control_point) + self.prev_control_point = control_point + + def detach(self) -> None: + self.held_link = None + + def apply_force(self, control_point: Vec3, delta_time: float) -> None: + # note when threaded: apply_force is called before attach! + # note2: that was before we added a lock to ViewerInteraction; this might be fixed now + if not self.held_link: + return + + self.prev_control_point = control_point + + # do simple force on COM only: + link: "RigidLink" = self.held_link + lin_vel: Vec3 = Vec3.from_tensor(link.get_vel()) + ang_vel: Vec3 = Vec3.from_tensor(link.get_ang()) + link_pose: Pose = Pose.from_link(link) + held_point_in_world: Vec3 = link_pose.transform_point(self.held_point_in_local) + + # note: we should assert earlier that link inertial_pos/quat are not None + # todo: verify inertial_pos/quat are stored in local frame + link_T_principal: Pose = Pose(Vec3.from_arraylike(link.inertial_pos), Quat.from_arraylike(link.inertial_quat)) + world_T_principal: Pose = link_pose * link_T_principal + + arm_in_principal: Vec3 = link_T_principal.inverse_transform_point(self.held_point_in_local) # for non-spherical inertia + arm_in_world: Vec3 = world_T_principal.rot * arm_in_principal # for spherical inertia + + pos_err_v: Vec3 = control_point - held_point_in_world + inv_mass: float = float(1.0 / link.get_mass() if link.get_mass() > 0.0 else 0.0) + inv_spherical_inertia: float = float(1.0 / link.inertial_i[0, 0] if link.inertial_i[0, 0] > 0.0 else 0.0) + + inv_dt: float = 1.0 / delta_time + tau: float = MOUSE_SPRING_POSITION_CORRECTION_FACTOR + damp: float = MOUSE_SPRING_VELOCITY_CORRECTION_FACTOR + + total_impulse: Vec3 = Vec3.zero() + total_torque_impulse: Vec3 = Vec3.zero() + + for i in range(3*4): + body_point_vel: Vec3 = lin_vel + ang_vel.cross(arm_in_world) + vel_err_v: Vec3 = Vec3.zero() - body_point_vel + + dir: Vec3 = Vec3.zero() + dir.v[i % 3] = 1.0 + pos_err: float = dir.dot(pos_err_v) +
vel_err: float = dir.dot(vel_err_v) + error: float = tau * pos_err * inv_dt + damp * vel_err + + arm_x_dir: Vec3 = arm_in_world.cross(dir) + virtual_mass: float = 1.0 / (inv_mass + arm_x_dir.sqr_magnitude() * inv_spherical_inertia + 1e-24) + impulse: float = error * virtual_mass + + lin_vel += impulse * inv_mass * dir + ang_vel += impulse * inv_spherical_inertia * arm_x_dir + total_impulse.v[i % 3] += impulse + total_torque_impulse += impulse * arm_x_dir + + # Apply the new force + total_force = total_impulse * inv_dt + total_torque = total_torque_impulse * inv_dt + force_tensor: torch.Tensor = total_force.as_tensor()[None] + torque_tensor: torch.Tensor = total_torque.as_tensor()[None] + link.solver.apply_links_external_force(force_tensor, (link.idx,), ref='link_com', local=False) + link.solver.apply_links_external_torque(torque_tensor, (link.idx,), ref='link_com', local=False) + + @property + def is_attached(self) -> bool: + return self.held_link is not None + + +@register_viewer_plugin(MouseSpringPluginOptions) +class MouseSpringPlugin(ViewerPlugin): + """ + Basic interactive viewer plugin that enables using mouse to apply spring force on rigid entities. 
""" - def __init__(self, - camera: 'Node', - scene: 'Scene', + def __init__( + self, + viewer, + options: MouseSpringPluginOptions, + camera: "Node", + scene: "Scene", viewport_size: tuple[int, int], - camera_yfov: float, - log_events: bool = False, - camera_fov: float = 60.0, ) -> None: - super().__init__(log_events) - self.camera: 'Node' = camera - self.scene: 'Scene' = scene - self.viewport_size: tuple[int, int] = viewport_size - self.camera_yfov: float = camera_yfov - - self.tan_half_fov: float = np.tan(0.5 * self.camera_yfov) + super().__init__(viewer, options, camera, scene, viewport_size) self.prev_mouse_pos: tuple[int, int] = (viewport_size[0] // 2, viewport_size[1] // 2) self.picked_link: RigidLink | None = None @@ -49,6 +126,8 @@ def __init__(self, self.mouse_spring: MouseSpring = MouseSpring() self.lock = threading_Lock() + self.raycaster: ViewerRaycaster = ViewerRaycaster(self.scene) + @override def on_mouse_motion(self, x: int, y: int, dx: int, dy: int) -> EVENT_HANDLE_STATE: super().on_mouse_motion(x, y, dx, dy) @@ -67,13 +146,14 @@ def on_mouse_drag(self, x: int, y: int, dx: int, dy: int, buttons: int, modifier def on_mouse_press(self, x: int, y: int, button: int, modifiers: int) -> EVENT_HANDLE_STATE: super().on_mouse_press(x, y, button, modifiers) if button == 1: # left mouse button - ray_hit = self.raycast_against_entities(self.screen_position_to_ray(x, y)) + + ray_hit = self.raycaster.cast_ray(self._screen_position_to_ray(x, y).origin.v, self._screen_position_to_ray(x, y).direction.v) with self.lock: if ray_hit.geom: self.picked_link = ray_hit.geom.link assert self.picked_link is not None - temp_fwd = self.get_camera_forward() + temp_fwd = self._get_camera_forward() temp_back = -temp_fwd self.mouse_drag_plane = Plane(temp_back, ray_hit.position) @@ -96,17 +176,13 @@ def on_mouse_release(self, x: int, y: int, button: int, modifiers: int) -> EVENT self.mouse_spring.detach() - @override - def on_resize(self, width: int, height: int) -> 
EVENT_HANDLE_STATE: - super().on_resize(width, height) - self.viewport_size = (width, height) - self.tan_half_fov = np.tan(0.5 * self.camera_yfov) - @override def update_on_sim_step(self) -> None: + self.raycaster.update_bvh() + with self.lock: if self.picked_link: - mouse_ray: Ray = self.screen_position_to_ray(*self.prev_mouse_pos) + mouse_ray: Ray = self._screen_position_to_ray(*self.prev_mouse_pos) ray_hit: RayHit = self.mouse_drag_plane.raycast(mouse_ray) assert ray_hit.is_hit if ray_hit.is_hit: @@ -119,7 +195,7 @@ def update_on_sim_step(self) -> None: # apply force self.mouse_spring.apply_force(new_mouse_3d_pos, self.scene.sim.dt) else: - #apply displacement + # apply displacement pos = Vec3.from_tensor(self.picked_link.entity.get_pos()) pos += delta_3d_pos self.picked_link.entity.set_pos(pos.as_tensor()) @@ -129,11 +205,8 @@ def on_draw(self) -> None: super().on_draw() if self.scene._visualizer is not None and self.scene._visualizer.is_built: self.scene.clear_debug_objects() - mouse_ray: Ray = self.screen_position_to_ray(*self.prev_mouse_pos) - - closest_hit = self.raycast_against_entities(mouse_ray) - if not closest_hit.is_hit: - closest_hit = self._raycast_against_ground_plane(mouse_ray) + mouse_ray: Ray = self._screen_position_to_ray(*self.prev_mouse_pos) + closest_hit = self.raycaster.cast_ray(mouse_ray.origin.v, mouse_ray.direction.v) with self.lock: if self.picked_link: @@ -157,76 +230,6 @@ def on_draw(self) -> None: self._draw_entity_unrotated_obb(closest_hit.geom) - def screen_position_to_ray(self, x: float, y: float) -> Ray: - # convert screen position to ray - if True: - x = x - 0.5 * self.viewport_size[0] - y = y - 0.5 * self.viewport_size[1] - x = 2.0 * x / self.viewport_size[1] * self.tan_half_fov - y = 2.0 * y / self.viewport_size[1] * self.tan_half_fov - else: - # alternative way - projection_matrix = self.camera.camera.get_projection_matrix(*self.viewport_size) - x = x - 0.5 * self.viewport_size[0] - y = y - 0.5 * self.viewport_size[1] - x = 
2.0 * x / self.viewport_size[0] / projection_matrix[0, 0] - y = 2.0 * y / self.viewport_size[1] / projection_matrix[1, 1] - - # Note: ignoring pixel aspect ratio - - mtx = self.camera.matrix - position = Vec3.from_array(mtx[:3, 3]) - forward = Vec3.from_array(-mtx[:3, 2]) - right = Vec3.from_array(mtx[:3, 0]) - up = Vec3.from_array(mtx[:3, 1]) - - direction = forward + right * x + up * y - return Ray(position, direction) - - def get_camera_forward(self) -> Vec3: - mtx = self.camera.matrix - return Vec3.from_array(-mtx[:3, 2]) - - def get_camera_ray(self) -> Ray: - mtx = self.camera.matrix - position = Vec3.from_array(mtx[:3, 3]) - forward = Vec3.from_array(-mtx[:3, 2]) - return Ray(position, forward) - - def _raycast_against_ground_plane(self, ray: Ray) -> RayHit: - ground_plane = Plane(Vec3.from_xyz(0, 0, 1), Vec3.zero()) - return ground_plane.raycast(ray) - - def raycast_against_entity_obb(self, entity: "RigidEntity", ray: Ray) -> RayHit: - if isinstance(entity.morph, gs.morphs.Box): - obb: OBB = self._get_box_obb(entity) - ray_hit = obb.raycast(ray) - if ray_hit.is_hit: - ray_hit.geom = entity.geoms[0] - return ray_hit - elif isinstance(entity.morph, gs.morphs.Plane): - # ignore plane - return RayHit.no_hit() - else: - closest_hit = RayHit.no_hit() - for link in entity.links: - if not link.is_fixed: - for geom in link.geoms: - obb: OBB = self._get_geom_placeholder_obb(geom) - ray_hit = obb.raycast(ray) - if ray_hit.distance < closest_hit.distance: - ray_hit.geom = geom - closest_hit = ray_hit - return closest_hit - - def raycast_against_entities(self, ray: Ray) -> RayHit: - closest_hit = RayHit.no_hit() - for entity in self.scene.sim.rigid_solver.entities: - rigid_entity: "RigidEntity" = cast("RigidEntity", entity) - ray_hit = self.raycast_against_entity_obb(rigid_entity, ray) - if ray_hit.distance < closest_hit.distance: - closest_hit = ray_hit - return closest_hit def _get_box_obb(self, box_entity: "RigidEntity") -> OBB: box: gs.morphs.Box = box_entity.morph 
diff --git a/genesis/ext/pyrender/interaction/utils/__init__.py b/genesis/ext/pyrender/interaction/utils/__init__.py new file mode 100644 index 0000000000..7182f96982 --- /dev/null +++ b/genesis/ext/pyrender/interaction/utils/__init__.py @@ -0,0 +1,4 @@ +from .aabb import AABB, OBB +from .ray import Plane, Ray, RayHit +from .raycaster import ViewerRaycaster +from .vec3 import Color, Pose, Quat, Vec3 diff --git a/genesis/ext/pyrender/interaction/aabb.py b/genesis/ext/pyrender/interaction/utils/aabb.py similarity index 100% rename from genesis/ext/pyrender/interaction/aabb.py rename to genesis/ext/pyrender/interaction/utils/aabb.py diff --git a/genesis/ext/pyrender/interaction/ray.py b/genesis/ext/pyrender/interaction/utils/ray.py similarity index 100% rename from genesis/ext/pyrender/interaction/ray.py rename to genesis/ext/pyrender/interaction/utils/ray.py diff --git a/genesis/ext/pyrender/interaction/utils/raycaster.py b/genesis/ext/pyrender/interaction/utils/raycaster.py new file mode 100644 index 0000000000..6c8d31bb1f --- /dev/null +++ b/genesis/ext/pyrender/interaction/utils/raycaster.py @@ -0,0 +1,282 @@ +from typing import TYPE_CHECKING + +import gstaichi as ti +import numpy as np + +import genesis as gs +from genesis.engine.bvh import AABB, LBVH, STACK_SIZE +from genesis.utils.raycast import kernel_update_aabbs, ray_aabb_intersection, ray_triangle_intersection + +from .ray import RayHit +from .vec3 import Vec3 + +if TYPE_CHECKING: + from genesis.engine.scene import Scene + + +# Constant to indicate no hit occurred +NO_HIT_DISTANCE = -1.0 + + + +@ti.kernel +def kernel_cast_single_ray_for_viewer( + fixed_verts_state: ti.template(), + free_verts_state: ti.template(), + verts_info: ti.template(), + faces_info: ti.template(), + bvh_nodes: ti.template(), + bvh_morton_codes: ti.template(), + ray_start: ti.types.ndarray(ndim=1), # [3] + ray_direction: ti.types.ndarray(ndim=1), # [3] + max_range: ti.f32, + envs_idx: ti.types.ndarray(ndim=1), # [n_envs] + result: 
ti.types.ndarray(ndim=1), # [9]: [distance, geom_idx, hit_x, hit_y, hit_z, normal_x, normal_y, normal_z, env_idx] +): + """ + Taichi kernel for casting a single ray for viewer interaction. + + This loops over all environments in envs_idx and returns the closest hit. + + Returns: + result[0]: distance to hit point (NO_HIT_DISTANCE if no hit) + result[1]: geom_idx of hit geometry + result[2]: hit_point x coordinate + result[3]: hit_point y coordinate + result[4]: hit_point z coordinate + result[5]: normal x coordinate + result[6]: normal y coordinate + result[7]: normal z coordinate + result[8]: env_idx of hit environment + """ + n_triangles = faces_info.verts_idx.shape[0] + + # Setup ray + ray_start_world = ti.math.vec3(ray_start[0], ray_start[1], ray_start[2]) + ray_direction_world = ti.math.vec3(ray_direction[0], ray_direction[1], ray_direction[2]) + + # Initialize result with no hit + result[0] = -1.0 # NO_HIT_DISTANCE + result[1] = -1.0 # no geom + result[2] = 0.0 # hit_point x + result[3] = 0.0 # hit_point y + result[4] = 0.0 # hit_point z + result[5] = 0.0 # normal x + result[6] = 0.0 # normal y + result[7] = 0.0 # normal z + result[8] = -1.0 # no env + + global_closest_distance = max_range + global_hit_face = -1 + global_hit_env_idx = -1 + global_hit_normal = ti.math.vec3(0.0, 0.0, 0.0) + + # Loop over all environments in envs_idx + for i_b in range(envs_idx.shape[0]): + rendered_env_idx = ti.cast(envs_idx[i_b], ti.i32) + + hit_face = -1 + closest_distance = global_closest_distance + hit_normal = ti.math.vec3(0.0, 0.0, 0.0) + + # Stack for non-recursive BVH traversal + node_stack = ti.Vector.zero(ti.i32, STACK_SIZE) + node_stack[0] = 0 # Start at root node + stack_idx = 1 + + while stack_idx > 0: + stack_idx -= 1 + node_idx = node_stack[stack_idx] + + node = bvh_nodes[i_b, node_idx] + + # Check if ray hits the node's bounding box + aabb_t = ray_aabb_intersection(ray_start_world, ray_direction_world, node.bound.min, node.bound.max) + + if aabb_t >= 0.0 and 
aabb_t < closest_distance: + if node.left == -1: # Leaf node + # Get original triangle/face index + sorted_leaf_idx = node_idx - (n_triangles - 1) + i_f = ti.cast(bvh_morton_codes[0, sorted_leaf_idx][1], ti.i32) + + # Get triangle vertices + tri_vertices = ti.Matrix.zero(gs.ti_float, 3, 3) + for i in ti.static(range(3)): + i_v = faces_info.verts_idx[i_f][i] + i_fv = verts_info.verts_state_idx[i_v] + if verts_info.is_fixed[i_v]: + tri_vertices[:, i] = fixed_verts_state.pos[i_fv] + else: + tri_vertices[:, i] = free_verts_state.pos[i_fv, rendered_env_idx] + v0, v1, v2 = tri_vertices[:, 0], tri_vertices[:, 1], tri_vertices[:, 2] + + # Perform ray-triangle intersection + hit_result = ray_triangle_intersection(ray_start_world, ray_direction_world, v0, v1, v2) + + if hit_result.w > 0.0 and hit_result.x < closest_distance and hit_result.x >= 0.0: + closest_distance = hit_result.x + hit_face = i_f + # Compute triangle normal + edge1 = v1 - v0 + edge2 = v2 - v0 + hit_normal = edge1.cross(edge2).normalized() + else: # Internal node + # Push children onto stack + if stack_idx < ti.static(STACK_SIZE - 2): + node_stack[stack_idx] = node.left + node_stack[stack_idx + 1] = node.right + stack_idx += 2 + + # Update global closest if this environment had a closer hit + if hit_face >= 0 and closest_distance < global_closest_distance: + global_closest_distance = closest_distance + global_hit_face = hit_face + global_hit_env_idx = rendered_env_idx + global_hit_normal = hit_normal + + # Store result + if global_hit_face >= 0: + result[0] = global_closest_distance # distance (positive value indicates hit) + # Find which geom this face belongs to + i_g = faces_info.geom_idx[global_hit_face] + result[1] = gs.ti_float(i_g) + # Compute hit point + hit_point = ray_start_world + global_closest_distance * ray_direction_world + result[2] = hit_point.x + result[3] = hit_point.y + result[4] = hit_point.z + # Store normal + result[5] = global_hit_normal.x + result[6] = global_hit_normal.y + 
result[7] = global_hit_normal.z + result[8] = gs.ti_float(global_hit_env_idx) + + + +class ViewerRaycaster: + """ + BVH-accelerated raycaster for viewer interaction plugins. + + This class manages a BVH structure built from the scene's rigid geometry + and provides efficient single-ray casting for interactive applications. + Only considers environments specified in rendered_envs_idx. + """ + + def __init__(self, scene: "Scene"): + """ + Initialize the ViewerRaycaster. + + Parameters + ---------- + scene : Scene + The scene to build the raycaster for. + """ + self.scene = scene + self.solver = scene.sim.rigid_solver + + # Store rendered_envs_idx as numpy array for Taichi kernel + + # self.rendered_envs_idx = np.asarray(scene.vis_options.rendered_envs_idx or [0], dtype=gs.np_int) + self.rendered_envs_idx = np.asarray([0], dtype=gs.np_int) + + # Build the BVH structure for rendered environments. + n_faces = self.solver.faces_info.geom_idx.shape[0] + + if n_faces == 0: + gs.logger.warning("No faces found in scene, viewer raycasting will not work.") + self.aabb = None + self.bvh = None + return + + self.aabb = AABB(n_batches=len(self.rendered_envs_idx), n_aabbs=n_faces) + self.bvh = LBVH( + self.aabb, + max_n_query_result_per_aabb=0, # Not used for ray queries + n_radix_sort_groups=min(64, n_faces), + ) + + self.update_bvh() + + def update_bvh(self): + """Update the BVH structure with current geometry state.""" + if self.bvh is None: + return + + # Update vertex positions + from genesis.engine.solvers.rigid.rigid_solver_decomp import kernel_update_all_verts + + kernel_update_all_verts( + geoms_info=self.solver.geoms_info, + geoms_state=self.solver.geoms_state, + verts_info=self.solver.verts_info, + free_verts_state=self.solver.free_verts_state, + fixed_verts_state=self.solver.fixed_verts_state, + static_rigid_sim_config=self.solver._static_rigid_sim_config, + ) + + # Update AABBs for each rendered environment + kernel_update_aabbs( + 
free_verts_state=self.solver.free_verts_state, + fixed_verts_state=self.solver.fixed_verts_state, + verts_info=self.solver.verts_info, + faces_info=self.solver.faces_info, + aabb_state=self.aabb, + ) + + # Rebuild BVH + self.bvh.build() + + def cast_ray( + self, + ray_origin: np.ndarray, + ray_direction: np.ndarray, + max_range: float = 1000.0, + ) -> RayHit: + """ + Cast a single ray against all rendered environments and return the closest hit. + + Parameters + ---------- + ray_origin : np.ndarray, shape (3,) + The origin point of the ray in world coordinates. + ray_direction : np.ndarray, shape (3,) + The direction vector of the ray (will be normalized). + max_range : float, optional + Maximum distance to check for intersections. Default is 1000.0. + + Returns + ------- + RayHit + A RayHit object containing distance, position, normal, and geom. + If no hit, returns RayHit.no_hit(). + """ + ray_direction = ray_direction / (np.linalg.norm(ray_direction) + gs.EPS) + + ray_start_np = np.asarray(ray_origin, dtype=gs.np_float) + ray_dir_np = np.asarray(ray_direction, dtype=gs.np_float) + result_np = np.zeros(9, dtype=gs.np_float) + + kernel_cast_single_ray_for_viewer( + fixed_verts_state=self.solver.fixed_verts_state, + free_verts_state=self.solver.free_verts_state, + verts_info=self.solver.verts_info, + faces_info=self.solver.faces_info, + bvh_nodes=self.bvh.nodes, + bvh_morton_codes=self.bvh.morton_codes, + ray_start=ray_start_np, + ray_direction=ray_dir_np, + max_range=max_range, + envs_idx=self.rendered_envs_idx, + result=result_np, + ) + + distance = float(result_np[0]) + if distance < NO_HIT_DISTANCE + gs.EPS: # NO_HIT_DISTANCE + return RayHit.no_hit() + + geom_idx = int(result_np[1]) + position = Vec3(result_np[2:5]) + normal = Vec3(result_np[5:8]) + geom = self.solver.geoms[geom_idx] + + return RayHit(distance=distance, position=position, normal=normal, geom=geom) diff --git a/genesis/ext/pyrender/interaction/vec3.py 
b/genesis/ext/pyrender/interaction/utils/vec3.py similarity index 100% rename from genesis/ext/pyrender/interaction/vec3.py rename to genesis/ext/pyrender/interaction/utils/vec3.py diff --git a/genesis/ext/pyrender/interaction/viewer_interaction_base.py b/genesis/ext/pyrender/interaction/viewer_interaction_base.py deleted file mode 100644 index ede759dc0c..0000000000 --- a/genesis/ext/pyrender/interaction/viewer_interaction_base.py +++ /dev/null @@ -1,52 +0,0 @@ -from typing import Union, Literal - -import genesis as gs - - -EVENT_HANDLE_STATE = Union[Literal[True], None] -EVENT_HANDLED: Literal[True] = True - -# Note: Viewer window is based on pyglet.window.Window, mouse events are defined in pyglet.window.BaseWindow - -class ViewerInteractionBase(): - """Base class for handling pyglet.window.Window events. - """ - - log_events: bool - - def __init__(self, log_events: bool = False): - self.log_events = log_events - - def on_mouse_motion(self, x: int, y: int, dx: int, dy: int) -> EVENT_HANDLE_STATE: - if self.log_events: - gs.logger.info(f"Mouse moved to {x}, {y}") - - def on_mouse_drag(self, x: int, y: int, dx: int, dy: int, buttons: int, modifiers: int) -> EVENT_HANDLE_STATE: - if self.log_events: - gs.logger.info(f"Mouse dragged to {x}, {y}") - - def on_mouse_press(self, x: int, y: int, button: int, modifiers: int) -> EVENT_HANDLE_STATE: - if self.log_events: - gs.logger.info(f"Mouse buttons {button} pressed at {x}, {y}") - - def on_mouse_release(self, x: int, y: int, button: int, modifiers: int) -> EVENT_HANDLE_STATE: - if self.log_events: - gs.logger.info(f"Mouse buttons {button} released at {x}, {y}") - - def on_key_press(self, symbol: int, modifiers: int) -> EVENT_HANDLE_STATE: - if self.log_events: - gs.logger.info(f"Key pressed: {chr(symbol)}") - - def on_key_release(self, symbol: int, modifiers: int) -> EVENT_HANDLE_STATE: - if self.log_events: - gs.logger.info(f"Key released: {chr(symbol)}") - - def on_resize(self, width: int, height: int) -> 
EVENT_HANDLE_STATE: - if self.log_events: - gs.logger.info(f"Window resized to {width}x{height}") - - def update_on_sim_step(self) -> None: - pass - - def on_draw(self) -> None: - pass diff --git a/genesis/ext/pyrender/interaction/viewer_plugin.py b/genesis/ext/pyrender/interaction/viewer_plugin.py new file mode 100644 index 0000000000..e6708e2dfd --- /dev/null +++ b/genesis/ext/pyrender/interaction/viewer_plugin.py @@ -0,0 +1,126 @@ +from typing import TYPE_CHECKING, Literal, Type + +import numpy as np + +from .utils import Ray, Vec3 + +if TYPE_CHECKING: + from genesis.engine.scene import Scene + from genesis.ext.pyrender.node import Node + from genesis.options.viewer_plugins import ViewerPlugin as ViewerPluginOptions + + +EVENT_HANDLE_STATE = Literal[True] | None +EVENT_HANDLED: Literal[True] = True + +# Global map from options class to viewer plugin class +VIEWER_PLUGIN_MAP: dict[Type["ViewerPluginOptions"], Type["ViewerPlugin"]] = {} + + +def register_viewer_plugin(options_cls: Type["ViewerPluginOptions"]): + """ + Decorator to register a viewer plugin class with its corresponding options class. + + Parameters + ---------- + options_cls : Type[ViewerPluginOptions] + The options class that configures this viewer plugin. + + Returns + ------- + Callable + The decorator function that registers the plugin class. + + Example + ------- + @register_viewer_plugin(MouseSpringPluginOptions) + class MouseSpringPlugin(ViewerPlugin): + ... + """ + def _impl(plugin_cls: Type["ViewerPlugin"]): + VIEWER_PLUGIN_MAP[options_cls] = plugin_cls + return plugin_cls + return _impl + +# Note: Viewer window is based on pyglet.window.Window, mouse events are defined in pyglet.window.BaseWindow + +class ViewerPlugin(): + """ + Base class for handling pyglet.window.Window events.
+ """ + + def __init__( + self, + viewer, + options: "ViewerPluginOptions", + camera: "Node", + scene: "Scene", + viewport_size: tuple[int, int], + ): + self.viewer = viewer + self.options: "ViewerPluginOptions" = options + self.camera: 'Node' = camera + self.scene: 'Scene' = scene + self.viewport_size: tuple[int, int] = viewport_size + + self.camera_yfov: float = camera.camera.yfov + self.tan_half_fov: float = np.tan(0.5 * self.camera_yfov) + + def on_mouse_motion(self, x: int, y: int, dx: int, dy: int) -> EVENT_HANDLE_STATE: + pass + + def on_mouse_drag(self, x: int, y: int, dx: int, dy: int, buttons: int, modifiers: int) -> EVENT_HANDLE_STATE: + pass + + def on_mouse_press(self, x: int, y: int, button: int, modifiers: int) -> EVENT_HANDLE_STATE: + pass + + def on_mouse_release(self, x: int, y: int, button: int, modifiers: int) -> EVENT_HANDLE_STATE: + pass + + def on_key_press(self, symbol: int, modifiers: int) -> EVENT_HANDLE_STATE: + pass + + def on_key_release(self, symbol: int, modifiers: int) -> EVENT_HANDLE_STATE: + pass + + def on_resize(self, width: int, height: int) -> EVENT_HANDLE_STATE: + self.viewport_size = (width, height) + self.tan_half_fov = np.tan(0.5 * self.camera_yfov) + + def update_on_sim_step(self) -> None: + pass + + def on_draw(self) -> None: + pass + + def on_close(self) -> None: + pass + + def _screen_position_to_ray(self, x: float, y: float) -> Ray: + # convert screen position to ray + x = x - 0.5 * self.viewport_size[0] + y = y - 0.5 * self.viewport_size[1] + x = 2.0 * x / self.viewport_size[1] * self.tan_half_fov + y = 2.0 * y / self.viewport_size[1] * self.tan_half_fov + + # Note: ignoring pixel aspect ratio + + mtx = self.camera.matrix + position = Vec3.from_array(mtx[:3, 3]) + forward = Vec3.from_array(-mtx[:3, 2]) + right = Vec3.from_array(mtx[:3, 0]) + up = Vec3.from_array(mtx[:3, 1]) + + direction = forward + right * x + up * y + return Ray(position, direction) + + def _get_camera_forward(self) -> Vec3: + mtx = 
self.camera.matrix + return Vec3.from_array(-mtx[:3, 2]) + + def _get_camera_ray(self) -> Ray: + mtx = self.camera.matrix + position = Vec3.from_array(mtx[:3, 3]) + forward = Vec3.from_array(-mtx[:3, 2]) + return Ray(position, forward) \ No newline at end of file diff --git a/genesis/ext/pyrender/viewer.py b/genesis/ext/pyrender/viewer.py index 569273f614..ad1beb9d2c 100644 --- a/genesis/ext/pyrender/viewer.py +++ b/genesis/ext/pyrender/viewer.py @@ -1,14 +1,14 @@ """A pyglet-based interactive 3D scene viewer.""" import copy -from contextlib import nullcontext import os import shutil import sys -import time import threading +import time +from contextlib import nullcontext from threading import Event, RLock, Semaphore, Thread -from typing import Optional, TYPE_CHECKING +from typing import TYPE_CHECKING, Optional import numpy as np import OpenGL @@ -16,6 +16,8 @@ import genesis as gs +from .interaction.keybindings import KeyAction, Keybind, Keybindings + # Importing tkinter and creating a first context before importing pyglet is necessary to avoid later segfault on MacOS. # Note that destroying the window will cause segfault at exit. root = None @@ -44,12 +46,11 @@ RenderFlags, TextAlign, ) -from .interaction.viewer_interaction import ViewerInteraction -from .interaction.viewer_interaction_base import ViewerInteractionBase, EVENT_HANDLE_STATE, EVENT_HANDLED +from .interaction import EVENT_HANDLE_STATE, EVENT_HANDLED, VIEWER_PLUGIN_MAP from .light import DirectionalLight from .node import Node from .renderer import Renderer -from .shader_program import ShaderProgram, ShaderProgramCache +from .shader_program import ShaderProgram from .trackball import Trackball if TYPE_CHECKING: @@ -80,17 +81,7 @@ class Viewer(pyglet.window.Window): viewer_flags : dict A set of flags for controlling the viewer's behavior. Described in the note below. 
- registered_keys : dict - A map from ASCII key characters to tuples containing: - - - A function to be called whenever the key is pressed, - whose first argument will be the viewer itself. - - (Optionally) A list of additional positional arguments - to be passed to the function. - - (Optionally) A dict of keyword arguments to be passed - to the function. - - kwargs : dict + **kwargs : dict Any keyword arguments left over will be interpreted as belonging to either the :attr:`.Viewer.render_flags` or :attr:`.Viewer.viewer_flags` dictionaries. Those flag sets will be updated appropriately. @@ -199,14 +190,12 @@ def __init__( viewport_size=None, render_flags=None, viewer_flags=None, - registered_keys=None, run_in_thread=False, auto_start=True, shadow=False, plane_reflection=False, env_separate_rigid=False, - enable_interaction=False, - disable_keyboard_shortcuts=False, + plugin_options=None, **kwargs, ): ####################################################################### @@ -231,7 +220,6 @@ def __init__( self._offscreen_semaphore = Semaphore(0) self._offscreen_result = None - self._video_saver = None self._video_recorder = None self._default_render_flags = { @@ -282,42 +270,12 @@ def __init__( elif key in self.viewer_flags: self._viewer_flags[key] = kwargs[key] - self._registered_keys = {} - if registered_keys is not None: - self._registered_keys = {ord(k.lower()): registered_keys[k] for k in registered_keys} - - self._disable_keyboard_shortcuts = disable_keyboard_shortcuts - + self._keybindings: Keybindings = Keybindings() + self._held_keys: dict[tuple[int, int], bool] = {} ####################################################################### # Save internal settings ####################################################################### - # Set up caption stuff - self._message_text = None - self._ticks_till_fade = 2.0 / 3.0 * self.viewer_flags["refresh_rate"] - self._message_opac = 1.0 + self._ticks_till_fade - - self._display_instr = False - - 
self._instr_texts = [ - ["> [i]: show keyboard instructions"], - [ - "< [i]: hide keyboard instructions", - " [r]: record video", - " [s]: save image", - " [z]: reset camera", - " [a]: camera rotation", - " [h]: shadow", - " [f]: face normal", - " [v]: vertex normal", - " [w]: world frame", - " [l]: link frame", - " [d]: wireframe", - " [c]: camera & frustrum", - " [F11]: full-screen mode", - ], - ] - # Set up raymond lights and direct lights self._raymond_lights = self._create_raymond_lights() self._direct_light = self._create_direct_light() @@ -379,14 +337,21 @@ def __init__( self.scene.main_camera_node = self._camera_node self._reset_view() - # Setup mouse interaction - # Note: context.scene is genesis.engine.scene.Scene # Note: context._scene is genesis.ext.pyrender.scene.Scene - self.viewer_interaction = ( - ViewerInteraction(self._camera_node, context.scene, viewport_size, camera.yfov) - if enable_interaction - else ViewerInteractionBase() + + # Setup viewer interaction + if plugin_options is None: + plugin_options = gs.options.viewer_plugins.ViewerDefaultControls() + + plugin_cls = VIEWER_PLUGIN_MAP.get(type(plugin_options)) + if plugin_cls is None: + gs.raise_exception( + f"Viewer plugin type {type(plugin_options).__name__} is not registered. 
" + f"Available plugins: {list(VIEWER_PLUGIN_MAP.keys())}" + ) + self.interaction_plugin = plugin_cls( + self, plugin_options, self._camera_node, context.scene, viewport_size ) ####################################################################### @@ -410,7 +375,7 @@ def __init__( self._initialized_event.wait() if not self._is_active: if self._exception: - raise RuntimeError(f"Unable to initialize an OpenGL 3+ context.") from self._exception + raise RuntimeError("Unable to initialize an OpenGL 3+ context.") from self._exception raise OpenGL.error.Error("Invalid OpenGL context.") else: if self.auto_start: @@ -520,26 +485,19 @@ def viewer_flags(self): @viewer_flags.setter def viewer_flags(self, value): self._viewer_flags = value - - @property - def registered_keys(self): - """dict : Map from ASCII key character to a handler function. - - This is a map from ASCII key characters to tuples containing: - - - A function to be called whenever the key is pressed, - whose first argument will be the viewer itself. - - (Optionally) A list of additional positional arguments - to be passed to the function. - - (Optionally) A dict of keyword arguments to be passed - to the function. - + + def register_keybinds(self, keybinds: tuple[Keybind]) -> None: """ - return self._registered_keys + Add a key handler to call a function when the given key is pressed. - @registered_keys.setter - def registered_keys(self, value): - self._registered_keys = value + Parameters + ---------- + keybinds : tuple[Keybind] + A tuple of Keybind objects to register. + """ + for keybind in keybinds: + self._keybindings.register(keybind) + def close(self): """Close the viewer. 
@@ -590,6 +548,8 @@ def on_close(self): # Do not consider the viewer as active anymore self._is_active = False + self.interaction_plugin.on_close() + # Remove our camera and restore the prior one try: if self._camera_node is not None: @@ -733,49 +693,21 @@ def on_draw(self): self.clear() self._render() - self.viewer_interaction.on_draw() - - if not self._disable_keyboard_shortcuts: - if self._display_instr: - self._renderer.render_texts( - self._instr_texts[1], - TEXT_PADDING, - self.viewport_size[1] - TEXT_PADDING, - font_pt=26, - color=np.array([1.0, 1.0, 1.0, 0.85]), - ) - else: - self._renderer.render_texts( - self._instr_texts[0], - TEXT_PADDING, - self.viewport_size[1] - TEXT_PADDING, - font_pt=26, - color=np.array([1.0, 1.0, 1.0, 0.85]), - ) - - if self._message_text is not None: - self._renderer.render_text( - self._message_text, - self.viewport_size[0] - TEXT_PADDING, - TEXT_PADDING, - font_pt=20, - color=np.array([0.1, 0.7, 0.2, np.clip(self._message_opac, 0.0, 1.0)]), - align=TextAlign.BOTTOM_RIGHT, - ) - - if self.viewer_flags["caption"] is not None: - for caption in self.viewer_flags["caption"]: - xpos, ypos = self._location_to_x_y(caption["location"]) - self._renderer.render_text( - caption["text"], - xpos, - ypos, - font_name=caption["font_name"], - font_pt=caption["font_pt"], - color=caption["color"], - scale=caption["scale"], - align=caption["location"], - ) + self.interaction_plugin.on_draw() + + if self.viewer_flags["caption"] is not None: + for caption in self.viewer_flags["caption"]: + xpos, ypos = self._location_to_x_y(caption["location"]) + self._renderer.render_text( + caption["text"], + xpos, + ypos, + font_name=caption["font_name"], + font_pt=caption["font_pt"], + color=caption["color"], + scale=caption["scale"], + align=caption["location"], + ) def on_resize(self, width: int, height: int) -> EVENT_HANDLE_STATE: """Resize the camera and trackball when the window is resized.""" @@ -789,15 +721,22 @@ def on_resize(self, width: int, height: 
int) -> EVENT_HANDLE_STATE: self._trackball.resize(self._viewport_size) self._renderer.viewport_width = self._viewport_size[0] self._renderer.viewport_height = self._viewport_size[1] - self.viewer_interaction.on_resize(width, height) + self.interaction_plugin.on_resize(width, height) self.on_draw() def on_mouse_motion(self, x: int, y: int, dx: int, dy: int) -> EVENT_HANDLE_STATE: """The mouse was moved with no buttons held down.""" - return self.viewer_interaction.on_mouse_motion(x, y, dx, dy) + return self.interaction_plugin.on_mouse_motion(x, y, dx, dy) def on_mouse_press(self, x: int, y: int, button: int, modifiers: int) -> EVENT_HANDLE_STATE: """Record an initial mouse press.""" + # Stop animating while using the mouse + self.viewer_flags["mouse_pressed"] = True + + result = self.interaction_plugin.on_mouse_press(x, y, button, modifiers) + if result is EVENT_HANDLED: + return result + self._trackball.set_state(Trackball.STATE_ROTATE) if button == pyglet.window.mouse.LEFT: ctrl = modifiers & pyglet.window.key.MOD_CTRL @@ -814,24 +753,27 @@ def on_mouse_press(self, x: int, y: int, button: int, modifiers: int) -> EVENT_H self._trackball.down(np.array([x, y])) - # Stop animating while using the mouse - self.viewer_flags["mouse_pressed"] = True - return self.viewer_interaction.on_mouse_press(x, y, button, modifiers) + return EVENT_HANDLED def on_mouse_drag(self, x: int, y: int, dx: int, dy: int, buttons: int, modifiers: int) -> EVENT_HANDLE_STATE: """The mouse was moved with one or more buttons held down.""" - result = self.viewer_interaction.on_mouse_drag(x, y, dx, dy, buttons, modifiers) - if result is not EVENT_HANDLED: - result = self._trackball.drag(np.array([x, y])) - return result + result = self.interaction_plugin.on_mouse_drag(x, y, dx, dy, buttons, modifiers) + if result is EVENT_HANDLED: + return result + + result = self._trackball.drag(np.array([x, y])) + return EVENT_HANDLED def on_mouse_release(self, x: int, y: int, button: int, modifiers: int) -> 
EVENT_HANDLE_STATE: """Record a mouse release.""" self.viewer_flags["mouse_pressed"] = False - return self.viewer_interaction.on_mouse_release(x, y, button, modifiers) + return self.interaction_plugin.on_mouse_release(x, y, button, modifiers) - def on_mouse_scroll(self, x, y, dx, dy): + def on_mouse_scroll(self, x, y, dx, dy) -> EVENT_HANDLE_STATE: """Record a mouse scroll.""" + if self.interaction_plugin.on_mouse_scroll(x, y, dx, dy) == EVENT_HANDLED: + return EVENT_HANDLED + if self.viewer_flags["use_perspective_cam"]: self._trackball.scroll(dy) else: @@ -848,177 +790,32 @@ def on_mouse_scroll(self, x, y, dx, dy): ymag = max(c.ymag * sf, 1e-8 * c.ymag / c.xmag) c.xmag = xmag c.ymag = ymag + + return EVENT_HANDLED + + def _call_keybind_callback(self, symbol: int, modifiers: int, action: KeyAction) -> None: + """Call registered keybind callbacks for the given key event.""" + keybind: Keybind = self._keybindings.get(symbol, modifiers, action) + if keybind is not None and keybind.callback_func is not None: + keybind.callback_func(*keybind.args, **keybind.kwargs) def on_key_press(self, symbol: int, modifiers: int) -> EVENT_HANDLE_STATE: """Record a key press.""" - # First, check for registered key callbacks - if symbol in self.registered_keys: - tup = self.registered_keys[symbol] - callback = None - args = [] - kwargs = {} - if not isinstance(tup, (list, tuple, np.ndarray)): - callback = tup - else: - callback = tup[0] - if len(tup) == 2: - args = tup[1] - if len(tup) == 3: - kwargs = tup[2] - callback(self, *args, **kwargs) - return self.viewer_interaction.on_key_press(symbol, modifiers) - - # If keyboard shortcuts are disabled, skip default key functions - if self._disable_keyboard_shortcuts: - return self.viewer_interaction.on_key_press(symbol, modifiers) - - # Otherwise, use default key functions - - # A causes the frame to rotate - self._message_text = None - if symbol == pyglet.window.key.A: - self.viewer_flags["rotate"] = not self.viewer_flags["rotate"] - if 
self.viewer_flags["rotate"]: - self._message_text = "Rotation On" - else: - self._message_text = "Rotation Off" - - # F11 toggles face normals - elif symbol == pyglet.window.key.F11: - self.viewer_flags["fullscreen"] = not self.viewer_flags["fullscreen"] - self.set_fullscreen(self.viewer_flags["fullscreen"]) - self.activate() - if self.viewer_flags["fullscreen"]: - self._message_text = "Fullscreen On" - else: - self._message_text = "Fullscreen Off" - - # H toggles shadows - elif symbol == pyglet.window.key.H: - self.render_flags["shadows"] = not self.render_flags["shadows"] - if self.render_flags["shadows"]: - self._message_text = "Shadows On" - else: - self._message_text = "Shadows Off" + self._held_keys[(symbol, modifiers)] = True - # W toggles world frame - elif symbol == pyglet.window.key.W: - if not self.gs_context.world_frame_shown: - self.gs_context.on_world_frame() - self._message_text = "World Frame On" - else: - self.gs_context.off_world_frame() - self._message_text = "World Frame Off" - - # L toggles link frame - elif symbol == pyglet.window.key.L: - if not self.gs_context.link_frame_shown: - self.gs_context.on_link_frame() - self._message_text = "Link Frame On" - else: - self.gs_context.off_link_frame() - self._message_text = "Link Frame Off" - - # C toggles camera frustum - elif symbol == pyglet.window.key.C: - if not self.gs_context.camera_frustum_shown: - self.gs_context.on_camera_frustum() - self._message_text = "Camera Frustrum On" - else: - self.gs_context.off_camera_frustum() - self._message_text = "Camera Frustrum Off" - - # F toggles face normals - elif symbol == pyglet.window.key.F: - self.render_flags["face_normals"] = not self.render_flags["face_normals"] - if self.render_flags["face_normals"]: - self._message_text = "Face Normals On" - else: - self._message_text = "Face Normals Off" - - # V toggles vertex normals - elif symbol == pyglet.window.key.V: - self.render_flags["vertex_normals"] = not self.render_flags["vertex_normals"] - if 
self.render_flags["vertex_normals"]: - self._message_text = "Vert Normals On" - else: - self._message_text = "Vert Normals Off" - - # R starts recording frames - elif symbol == pyglet.window.key.R: - if self.viewer_flags["record"]: - self.save_video() - self.set_caption(self.viewer_flags["window_title"]) - else: - # Importing moviepy is very slow and not used very often. Let's delay import. - from moviepy.video.io.ffmpeg_writer import FFMPEG_VideoWriter - - self._video_recorder = FFMPEG_VideoWriter( - filename=os.path.join(gs.utils.misc.get_cache_dir(), "tmp_video.mp4"), - fps=self.viewer_flags["refresh_rate"], - size=self.viewport_size, - ) - self.set_caption("{} (RECORDING)".format(self.viewer_flags["window_title"])) - self.viewer_flags["record"] = not self.viewer_flags["record"] - - # S saves the current frame as an image - elif symbol == pyglet.window.key.S: - self._save_image() - - # T toggles through geom types - # elif symbol == pyglet.window.key.T: - # if self.gs_context.rigid_shown == 'visual': - # self.gs_context.on_rigid('collision') - # self._message_text = "Geom Type: 'collision'" - # elif self.gs_context.rigid_shown == 'collision': - # self.gs_context.on_rigid('sdf') - # self._message_text = "Geom Type: 'sdf'" - # else: - # self.gs_context.on_rigid('visual') - # self._message_text = "Geom Type: 'visual'" - - # D toggles through wireframe modes - elif symbol == pyglet.window.key.D: - if self.render_flags["flip_wireframe"]: - self.render_flags["flip_wireframe"] = False - self.render_flags["all_wireframe"] = True - self.render_flags["all_solid"] = False - self._message_text = "All Wireframe" - elif self.render_flags["all_wireframe"]: - self.render_flags["flip_wireframe"] = False - self.render_flags["all_wireframe"] = False - self.render_flags["all_solid"] = True - self._message_text = "All Solid" - elif self.render_flags["all_solid"]: - self.render_flags["flip_wireframe"] = False - self.render_flags["all_wireframe"] = False - 
self.render_flags["all_solid"] = False - self._message_text = "Default Wireframe" - else: - self.render_flags["flip_wireframe"] = True - self.render_flags["all_wireframe"] = False - self.render_flags["all_solid"] = False - self._message_text = "Flip Wireframe" - - # Z resets the camera viewpoint - elif symbol == pyglet.window.key.Z: - self._reset_view() - - # i toggles instruction display - elif symbol == pyglet.window.key.I: - self._display_instr = not self._display_instr - - elif symbol == pyglet.window.key.P: - self._renderer.reload_program() - - if self._message_text is not None: - self._message_opac = 1.0 + self._ticks_till_fade - - return self.viewer_interaction.on_key_press(symbol, modifiers) + self._call_keybind_callback(symbol, modifiers, KeyAction.PRESS) + return self.interaction_plugin.on_key_press(symbol, modifiers) def on_key_release(self, symbol: int, modifiers: int) -> EVENT_HANDLE_STATE: """Record a key release.""" - return self.viewer_interaction.on_key_release(symbol, modifiers) + self._held_keys.pop((symbol, modifiers), None) + + self._call_keybind_callback(symbol, modifiers, KeyAction.RELEASE) + return self.interaction_plugin.on_key_release(symbol, modifiers) + + def on_deactivate(self) -> EVENT_HANDLE_STATE: + """Clear held keys when window loses focus.""" + self._held_keys.clear() @staticmethod def _time_event(dt, self): @@ -1032,26 +829,6 @@ def _time_event(dt, self): if self.viewer_flags["rotate"] and not self.viewer_flags["mouse_pressed"]: self._rotate() - # Manage message opacity - if self._message_text is not None: - if self._message_opac > 1.0: - self._message_opac -= 1.0 - else: - self._message_opac *= 0.90 - if self._message_opac < 0.05: - self._message_opac = 1.0 + self._ticks_till_fade - self._message_text = None - - # video saving warning - if self._video_saver is not None: - if self._video_saver.is_alive(): - self._message_text = "Saving video... Please don't exit." 
- self._message_opac = 1.0 - else: - self._message_text = f"Video saved to {self._video_file_name}" - self._message_opac = self.viewer_flags["refresh_rate"] * 2 - self._video_saver = None - self.on_draw() def _reset_view(self): @@ -1089,8 +866,7 @@ def _get_save_filename(self, file_exts): try: # Importing tkinter is very slow and not used very often. Let's delay import. - from tkinter import Tk - from tkinter import filedialog + from tkinter import Tk, filedialog if root is None: root = Tk() @@ -1232,7 +1008,8 @@ def get_program(self, vertex_shader, fragment_shader, geometry_shader=None, defi def start(self, auto_refresh=True): import pyglet # For some reason, this is necessary if 'pyglet.window.xlib' fails to import... try: - import pyglet.window.xlib, pyglet.display.xlib + import pyglet.display.xlib + import pyglet.window.xlib xlib_exceptions = (pyglet.window.xlib.XlibException, pyglet.display.xlib.NoSuchDisplayException) except ImportError: xlib_exceptions = () @@ -1303,7 +1080,7 @@ def start(self, auto_refresh=True): self._exception = e return else: - raise RuntimeError(f"Unable to initialize an OpenGL 3+ context.") from e + raise RuntimeError("Unable to initialize an OpenGL 3+ context.") from e pyglet.window.xlib._have_utf8 = False confs.insert(0, conf) except (pyglet.window.NoSuchConfigException, pyglet.gl.ContextException) as e: @@ -1313,7 +1090,7 @@ def start(self, auto_refresh=True): self._exception = e return else: - raise RuntimeError(f"Unable to initialize an OpenGL 3+ context.") from e + raise RuntimeError("Unable to initialize an OpenGL 3+ context.") from e if self._run_in_thread: pyglet.clock.schedule_interval(Viewer._time_event, 1.0 / self.viewer_flags["refresh_rate"], self) @@ -1348,7 +1125,7 @@ def start(self, auto_refresh=True): # The viewer can be considered as fully initialized at this point if not self._initialized_event.is_set(): self._initialized_event.set() - + if auto_refresh: while self._is_active: try: @@ -1400,7 +1177,10 @@ def 
refresh(self): self.flip() def update_on_sim_step(self): - self.viewer_interaction.update_on_sim_step() + # Call HOLD callbacks for all currently held keys + for (symbol, modifiers) in list(self._held_keys.keys()): + self._call_keybind_callback(symbol, modifiers, KeyAction.HOLD) + self.interaction_plugin.update_on_sim_step() def _compute_initial_camera_pose(self): centroid = self.scene.centroid diff --git a/genesis/options/recorders.py b/genesis/options/recorders.py index ddeda88549..79a099af88 100644 --- a/genesis/options/recorders.py +++ b/genesis/options/recorders.py @@ -6,7 +6,6 @@ from .options import Options - IS_PYAV_AVAILABLE = False try: import av @@ -297,3 +296,43 @@ class MPLImagePlot(BasePlotterOptions): """ pass + + +class MPLDepthScatterPlot(BasePlotterOptions): + """ + Live visualization of depth sensor data as a 3D scatter plot using matplotlib. + + The data should be a tuple of (positions, distances) where: + - positions: array-like with shape (N, 3) for (x, y, z) coordinates + - distances: array-like with shape (N,) for depth/distance values + + Parameters + ---------- + title: str + The title of the plot. + window_size: tuple[int, int] + The size of the window in pixels. + save_to_filename: str | None + If provided, the animation will be saved to a file with the given filename. + show_window: bool | None + Whether to show the window. If not provided, it will be set to True if a display is connected, False otherwise. + cmap: str + The colormap to use for depth visualization. Defaults to 'viridis'. + point_size: float + The size of each point in the scatter plot. Defaults to 50. + x_label: str + Label for the x-axis. Defaults to 'X'. + y_label: str + Label for the y-axis. Defaults to 'Y'. + z_label: str + Label for the z-axis. Defaults to 'Z'. + colorbar_label: str + Label for the colorbar. Defaults to 'Distance'. 
+ """ + + cmap: str = "viridis" + point_size: float = 50.0 + x_label: str = "X" + y_label: str = "Y" + z_label: str = "Z" + colorbar_label: str = "Distance" diff --git a/genesis/options/sensors/raycaster.py b/genesis/options/sensors/raycaster.py index b21d41eaf1..0a42e07d99 100644 --- a/genesis/options/sensors/raycaster.py +++ b/genesis/options/sensors/raycaster.py @@ -55,6 +55,21 @@ def ray_starts(self) -> torch.Tensor: # ============================== Generic Patterns ============================== +def _sanitize_rays_to_tensor(rays: Sequence[float]) -> torch.Tensor: + tensor = torch.tensor(rays, dtype=gs.tc_float, device=gs.device) + if tensor.ndim < 2 or tensor.shape[-1] != 3: + gs.raise_exception(f"Rays should have shape (..., 3). Got: {tensor.shape}") + return tensor + + +class RaycastCustomPattern(RaycastPattern): + + def __init__(self, ray_dirs: Sequence[float], ray_starts: Sequence[float]): + self._ray_dirs = _sanitize_rays_to_tensor(ray_dirs) + self._ray_starts = _sanitize_rays_to_tensor(ray_starts) + self._return_shape: tuple[int, ...] = ray_dirs.shape[:-1] + + class GridPattern(RaycastPattern): """ Configuration for grid-based ray casting. diff --git a/genesis/options/viewer_plugins.py b/genesis/options/viewer_plugins.py new file mode 100644 index 0000000000..ab8a10c0c1 --- /dev/null +++ b/genesis/options/viewer_plugins.py @@ -0,0 +1,54 @@ +from .options import Options + + +class ViewerPlugin(Options): + """ + Base class for viewer interaction options. + + All viewer interaction option classes should inherit from this base class. + """ + + +class HelpTextPlugin(ViewerPlugin): + """ + Displays keyboard instructions in the viewer. + """ + + display_instructions: bool = True + font_size: int = 26 + + +class DefaultControlsPlugin(HelpTextPlugin): + """ + Default viewer interaction controls with keyboard shortcuts for recording, changing render modes, etc. 
+ """ + + +class MouseSpringPlugin(HelpTextPlugin): + """ + Options for the interactive viewer plugin that allows mouse-based object manipulation. + """ + + +class MeshPointSelectorPlugin(HelpTextPlugin): + """ + Options for the mesh point selector plugin that allows selecting points on a mesh. + + Parameters + ---------- + sphere_radius : float + The radius of the sphere used to visualize selected points. + sphere_color : tuple[float, float, float, float] + The color of the sphere used to visualize selected points. + hover_color : tuple[float, float, float, float] + The color of the sphere used to visualize the point and normal when hovering over a mesh. + grid_snap : tuple[float, float, float] + Grid snap spacing for each axis (x, y, z). Any negative value disables snapping for that axis. + Default is (-1.0, -1.0, -1.0) which means no snapping. + """ + + sphere_radius: float = 0.005 + sphere_color: tuple = (0.1, 0.3, 1.0, 1.0) + hover_color: tuple = (0.3, 0.5, 1.0, 1.0) + grid_snap: tuple[float, float, float] = (-1.0, -1.0, -1.0) + output_file: str = "selected_points.csv" diff --git a/genesis/options/vis.py b/genesis/options/vis.py index fce8ed0d22..952213d723 100644 --- a/genesis/options/vis.py +++ b/genesis/options/vis.py @@ -3,6 +3,7 @@ import genesis as gs from .options import Options +from .viewer_plugins import DefaultControlsPlugin, ViewerPlugin class ViewerOptions(Options): @@ -33,20 +34,19 @@ class ViewerOptions(Options): The up vector of the camera's extrinsic pose. camera_fov : float The field of view (in degrees) of the camera. - disable_keyboard_shortcuts : bool - Whether to disable all keyboard shortcuts in the viewer. Defaults to False. + viewer_plugin : ViewerPlugin + Viewer plugin that adds interactive functionality to the viewer. 
""" - res: Optional[tuple] = None - run_in_thread: Optional[bool] = None + res: tuple | None = None + run_in_thread: bool | None = None refresh_rate: int = 60 - max_FPS: Optional[int] = 60 + max_FPS: int | None = 60 camera_pos: tuple = (3.5, 0.5, 2.5) camera_lookat: tuple = (0.0, 0.0, 0.5) camera_up: tuple = (0.0, 0.0, 1.0) camera_fov: float = 40 - enable_interaction: bool = False - disable_keyboard_shortcuts: bool = False + viewer_plugin: ViewerPlugin = DefaultControlsPlugin() class VisOptions(Options): @@ -142,7 +142,7 @@ def __init__(self, **data): f"Unsupported `render_particle_as`: {self.render_particle_as}, must be one of ['sphere', 'tet']" ) - if not self.n_rendered_envs is None: + if self.n_rendered_envs is not None: gs.logger.warning( "Viewer option 'n_rendered_envs' is deprecated and will be removed in future release. Please use " "'rendered_envs_idx' instead." diff --git a/genesis/recorders/plotters.py b/genesis/recorders/plotters.py index 2443a23eac..d7e2946a31 100644 --- a/genesis/recorders/plotters.py +++ b/genesis/recorders/plotters.py @@ -16,10 +16,19 @@ from genesis.options.recorders import ( BasePlotterOptions, LinePlotterMixinOptions, - PyQtLinePlot as PyQtLinePlotterOptions, - MPLLinePlot as MPLLinePlotterOptions, +) +from genesis.options.recorders import ( + MPLDepthScatterPlot as MPLDepthScatterPlotterOptions, +) +from genesis.options.recorders import ( MPLImagePlot as MPLImagePlotterOptions, ) +from genesis.options.recorders import ( + MPLLinePlot as MPLLinePlotterOptions, +) +from genesis.options.recorders import ( + PyQtLinePlot as PyQtLinePlotterOptions, +) from genesis.utils import has_display, tensor_to_array from .base_recorder import Recorder @@ -634,3 +643,138 @@ def cleanup(self): self.ax = None self.image_plot = None self.background = None + + +@register_recording(MPLDepthScatterPlotterOptions) +class MPLDepthScatterPlotter(BaseMPLPlotter): + """ + Live depth sensor visualization using matplotlib 3D scatter plot. 
+ + The data should be a tuple of (positions, distances) where: + - positions: array-like with shape (N, 3) for (x, y, z) coordinates + - distances: array-like with shape (N,) for depth/distance values + """ + + def build(self): + super().build() + + import matplotlib.pyplot as plt + + self.scatter = None + self.colorbar = None + self.positions = None + self.distances = None + + # Create 3D figure + self.fig = plt.figure(figsize=self.figsize) + self.ax = self.fig.add_subplot(111, projection="3d") + self.fig.suptitle(self._options.title) + self.ax.set_xlabel(self._options.x_label) + self.ax.set_ylabel(self._options.y_label) + self.ax.set_zlabel(self._options.z_label) + self.ax.grid(True, alpha=0.3) + + # Initialize with empty scatter (will be set on first data) + self.scatter = self.ax.scatter( + [], [], [], c=[], s=self._options.point_size, cmap=self._options.cmap, edgecolors="none" + ) + self.colorbar = self.fig.colorbar( + self.scatter, ax=self.ax, label=self._options.colorbar_label, shrink=0.5, aspect=5 + ) + + self._show_fig() + + def process(self, data, cur_time): + """Process new position and distance data.""" + # Expect data to be a tuple of (positions, distances) + if not isinstance(data, (tuple, list)) or len(data) != 2: + gs.logger.warning(f"[{type(self).__name__}] Data must be a tuple (positions, distances). 
Got: {type(data)}") + return + + positions, distances = data + + # Convert to numpy arrays + if isinstance(positions, torch.Tensor): + positions = tensor_to_array(positions) + else: + positions = np.asarray(positions) + + if isinstance(distances, torch.Tensor): + distances = tensor_to_array(distances) + else: + distances = np.asarray(distances) + + # Flatten if needed + if positions.ndim > 2: + positions = positions.reshape(-1, positions.shape[-1]) + distances = distances.flatten() + + # Validate shapes + if positions.ndim != 2 or positions.shape[1] < 3: + gs.logger.warning( + f"[{type(self).__name__}] Positions must have shape (N, 3) or (N, D) where D >= 3. " + f"Got shape {positions.shape}" + ) + return + + if len(positions) != len(distances): + gs.logger.warning( + f"[{type(self).__name__}] Number of positions ({len(positions)}) doesn't match " + f"number of distances ({len(distances)})" + ) + return + + self.positions = positions[:, :3] # use (x, y, z) + self.distances = distances + + super().process(data, cur_time) + + def _update_plot(self): + """Update the 3D scatter plot with new positions and distances.""" + if self.positions is None or self.distances is None: + return + + # Update scatter plot data + self.scatter._offsets3d = (self.positions[:, 0], self.positions[:, 1], self.positions[:, 2]) + self.scatter.set_array(self.distances) + + # Update color limits + vmin, vmax = np.min(self.distances), np.max(self.distances) + current_vmin, current_vmax = self.scatter.get_clim() + if abs(vmin - current_vmin) > 1e-6 or abs(vmax - current_vmax) > 1e-6: + self.scatter.set_clim(vmin, vmax) + + # Update axis limits if this is the first update or limits have changed significantly + if not hasattr(self, "_limits_set") or not self._limits_set: + x_min, x_max = self.positions[:, 0].min(), self.positions[:, 0].max() + y_min, y_max = self.positions[:, 1].min(), self.positions[:, 1].max() + z_min, z_max = self.positions[:, 2].min(), self.positions[:, 2].max() + + # 
Calculate uniform range for all axes + max_range = max(x_max - x_min, y_max - y_min, z_max - z_min) + padding = max_range * 0.1 or 0.1 + + # Center each axis range + x_center = (x_max + x_min) / 2 + y_center = (y_max + y_min) / 2 + z_center = (z_max + z_min) / 2 + + half_range = (max_range + 2 * padding) / 2 + + self.ax.set_xlim(x_center - half_range, x_center + half_range) + self.ax.set_ylim(y_center - half_range, y_center + half_range) + self.ax.set_zlim(z_center - half_range, z_center + half_range) + self._limits_set = True + + self._lock.acquire() + self.fig.canvas.draw() + self.fig.canvas.flush_events() + self._lock.release() + + def cleanup(self): + super().cleanup() + + self.scatter = None + self.colorbar = None + self.positions = None + self.distances = None diff --git a/genesis/utils/raycast.py b/genesis/utils/raycast.py new file mode 100644 index 0000000000..4779622309 --- /dev/null +++ b/genesis/utils/raycast.py @@ -0,0 +1,121 @@ +import gstaichi as ti + +import genesis as gs + + +@ti.func +def ray_triangle_intersection(ray_start, ray_dir, v0, v1, v2): + """ + Moller-Trumbore ray-triangle intersection. 
+ + Returns: vec4(t, u, v, hit) where hit=1.0 if intersection found, 0.0 otherwise + """ + result = ti.Vector.zero(gs.ti_float, 4) + + edge1 = v1 - v0 + edge2 = v2 - v0 + + # Begin calculating determinant - also used to calculate u parameter + h = ray_dir.cross(edge2) + a = edge1.dot(h) + + # Check all conditions in sequence without early returns + valid = True + + t = gs.ti_float(0.0) + u = gs.ti_float(0.0) + v = gs.ti_float(0.0) + f = gs.ti_float(0.0) + s = ti.Vector.zero(gs.ti_float, 3) + q = ti.Vector.zero(gs.ti_float, 3) + + # If determinant is near zero, ray lies in plane of triangle + if ti.abs(a) < gs.EPS: + valid = False + + if valid: + f = 1.0 / a + s = ray_start - v0 + u = f * s.dot(h) + + if u < 0.0 or u > 1.0: + valid = False + + if valid: + q = s.cross(edge1) + v = f * ray_dir.dot(q) + + if v < 0.0 or u + v > 1.0: + valid = False + + if valid: + # At this stage we can compute t to find out where the intersection point is on the line + t = f * edge2.dot(q) + + # Ray intersection + if t <= gs.EPS: + valid = False + + if valid: + result = ti.math.vec4(t, u, v, 1.0) + + return result + + +@ti.func +def ray_aabb_intersection(ray_start, ray_dir, aabb_min, aabb_max): + """ + Fast ray-AABB intersection test. + Returns the t value of intersection, or -1.0 if no intersection. 
+ """ + result = -1.0 + + # Use the slab method for ray-AABB intersection + sign = ti.select(ray_dir >= 0.0, 1.0, -1.0) + ray_dir = sign * ti.max(ti.abs(ray_dir), gs.EPS) + inv_dir = 1.0 / ray_dir + + t1 = (aabb_min - ray_start) * inv_dir + t2 = (aabb_max - ray_start) * inv_dir + + tmin = ti.min(t1, t2) + tmax = ti.max(t1, t2) + + t_near = ti.max(tmin.x, tmin.y, tmin.z, 0.0) + t_far = ti.min(tmax.x, tmax.y, tmax.z) + + # Check if ray intersects AABB + if t_near <= t_far: + result = t_near + + return result + + +@ti.kernel +def kernel_update_aabbs( + free_verts_state: ti.template(), + fixed_verts_state: ti.template(), + verts_info: ti.template(), + faces_info: ti.template(), + # FIXME: can't import array_class since it is before gs.init + # free_verts_state: array_class.VertsState, + # fixed_verts_state: array_class.VertsState, + # verts_info: array_class.VertsInfo, + # faces_info: array_class.FacesInfo, + aabb_state: ti.template(), +): + for i_b, i_f in ti.ndrange(free_verts_state.pos.shape[1], faces_info.verts_idx.shape[0]): + aabb_state.aabbs[i_b, i_f].min.fill(ti.math.inf) + aabb_state.aabbs[i_b, i_f].max.fill(-ti.math.inf) + + for i in ti.static(range(3)): + i_v = faces_info.verts_idx[i_f][i] + i_fv = verts_info.verts_state_idx[i_v] + if verts_info.is_fixed[i_v]: + pos_v = fixed_verts_state.pos[i_fv] + aabb_state.aabbs[i_b, i_f].min = ti.min(aabb_state.aabbs[i_b, i_f].min, pos_v) + aabb_state.aabbs[i_b, i_f].max = ti.max(aabb_state.aabbs[i_b, i_f].max, pos_v) + else: + pos_v = free_verts_state.pos[i_fv, i_b] + aabb_state.aabbs[i_b, i_f].min = ti.min(aabb_state.aabbs[i_b, i_f].min, pos_v) + aabb_state.aabbs[i_b, i_f].max = ti.max(aabb_state.aabbs[i_b, i_f].max, pos_v) diff --git a/genesis/vis/viewer.py b/genesis/vis/viewer.py index 3a0c24d4c8..bf3444d0d7 100644 --- a/genesis/vis/viewer.py +++ b/genesis/vis/viewer.py @@ -1,6 +1,6 @@ +import importlib import os import threading -import importlib from typing import TYPE_CHECKING import numpy as np @@ -9,11 +9,11 @@ 
def register_keybinds(self, keybinds: tuple[Keybind, ...]) -> None:
    """
    Register keybinds with the underlying pyrender viewer.

    The keybinds are forwarded unchanged to the pyrender viewer's own `register_keybinds` method.

    Parameters
    ----------
    keybinds : tuple[gs.ext.pyrender.interaction.keybindings.Keybind, ...]
        The Keybind objects to register. See Keybind documentation for usage.

    Raises
    ------
    Exception raised by `gs.raise_exception`
        If the pyrender viewer has not been created yet, i.e. the viewer has not been built.
    """
    # `_pyrender_viewer` is initialised to None in `__init__`; fail with an explicit message instead of an
    # opaque AttributeError on NoneType.
    if self._pyrender_viewer is None:
        gs.raise_exception("Viewer must be built before keybinds can be registered.")
    self._pyrender_viewer.register_keybinds(keybinds)


def remap_keybind(
    self,
    keybind_name: str,
    new_key_code: int,
    new_modifiers: int | None = None,
    new_key_action: KeyAction = KeyAction.PRESS,
) -> None:
    """
    Remap an existing keybind to a new key combination.

    All arguments are forwarded unchanged to the `rebind` method of the pyrender viewer's keybinding registry.
    Use `None` for any parameter you do not wish to change. Note that `new_key_action` defaults to
    `KeyAction.PRESS` rather than `None`, so omitting it passes `KeyAction.PRESS` to `rebind`.

    NOTE(review): confirm that `rebind` accepts `None` for `new_key_code` and `new_key_action`; the annotations
    here do not currently advertise it.

    Parameters
    ----------
    keybind_name : str
        The name of the keybind to remap.
    new_key_code : int
        The new key code from pyglet.
    new_modifiers : int | None, optional
        The new modifier keys pressed.
    new_key_action : KeyAction, optional
        The new type of key action (press, hold, release).

    Raises
    ------
    Exception raised by `gs.raise_exception`
        If the pyrender viewer has not been created yet, i.e. the viewer has not been built.
    """
    # Same precondition as `register_keybinds`: the pyrender viewer only exists once the viewer is built.
    if self._pyrender_viewer is None:
        gs.raise_exception("Viewer must be built before keybinds can be remapped.")
    self._pyrender_viewer._keybindings.rebind(
        keybind_name,
        new_key_code,
        new_modifiers,
        new_key_action,
    )
can be disabled in the interactive viewer.""" - - # Test with keyboard shortcuts DISABLED - scene = gs.Scene( - viewer_options=gs.options.ViewerOptions( - disable_keyboard_shortcuts=True, - ), - show_viewer=True, - ) - scene.build() - pyrender_viewer = scene.visualizer.viewer._pyrender_viewer - assert pyrender_viewer.is_active - - # Verify the flag is set correctly - assert pyrender_viewer._disable_keyboard_shortcuts is True - - @pytest.mark.required @pytest.mark.parametrize("renderer_type", [RENDERER_TYPE.RASTERIZER]) def test_camera_gimbal_lock_singularity(renderer, show_viewer):