From 424c23a74fedfec3c7bc082a4883b64009e3b82f Mon Sep 17 00:00:00 2001 From: Trinity Chung Date: Fri, 2 Jan 2026 16:00:55 -0500 Subject: [PATCH 1/5] clean rebase, initial viewer plugin --- examples/interactive_viewer/mouse_spring.py | 43 +++ examples/keyboard_teleop.py | 309 ++++++++++-------- genesis/engine/sensors/raycaster.py | 121 +------ genesis/ext/pyrender/interaction/__init__.py | 10 + .../pyrender/interaction/base_interaction.py | 126 +++++++ .../ext/pyrender/interaction/keybindings.py | 28 ++ .../ext/pyrender/interaction/mouse_spring.py | 95 ------ .../pyrender/interaction/plugins/__init__.py | 9 + .../interaction/plugins/default_keyboard.py | 213 ++++++++++++ .../interaction/plugins/mesh_selector.py | 206 ++++++++++++ .../mouse_interaction.py} | 219 +++++++------ .../interaction/plugins/viewer_controls.py | 219 +++++++++++++ .../pyrender/interaction/utils/__init__.py | 4 + .../pyrender/interaction/{ => utils}/aabb.py | 0 .../pyrender/interaction/{ => utils}/ray.py | 0 .../pyrender/interaction/utils/raycaster.py | 282 ++++++++++++++++ .../pyrender/interaction/{ => utils}/vec3.py | 0 .../interaction/viewer_interaction_base.py | 52 --- genesis/ext/pyrender/viewer.py | 281 +++------------- genesis/options/recorders.py | 41 ++- genesis/options/sensors/raycaster.py | 15 + genesis/options/viewer_interactions.py | 56 ++++ genesis/options/vis.py | 14 +- genesis/recorders/plotters.py | 148 ++++++++- genesis/utils/raycast.py | 121 +++++++ genesis/vis/viewer.py | 14 +- tests/test_render.py | 23 -- 27 files changed, 1863 insertions(+), 786 deletions(-) create mode 100644 examples/interactive_viewer/mouse_spring.py create mode 100644 genesis/ext/pyrender/interaction/__init__.py create mode 100644 genesis/ext/pyrender/interaction/base_interaction.py create mode 100644 genesis/ext/pyrender/interaction/keybindings.py delete mode 100644 genesis/ext/pyrender/interaction/mouse_spring.py create mode 100644 genesis/ext/pyrender/interaction/plugins/__init__.py create mode 100644 genesis/ext/pyrender/interaction/plugins/default_keyboard.py create mode 100644 genesis/ext/pyrender/interaction/plugins/mesh_selector.py rename genesis/ext/pyrender/interaction/{viewer_interaction.py => plugins/mouse_interaction.py} (54%) create mode 100644 genesis/ext/pyrender/interaction/plugins/viewer_controls.py create mode 100644 genesis/ext/pyrender/interaction/utils/__init__.py rename genesis/ext/pyrender/interaction/{ => utils}/aabb.py (100%) rename genesis/ext/pyrender/interaction/{ => utils}/ray.py (100%) create mode 100644 genesis/ext/pyrender/interaction/utils/raycaster.py rename genesis/ext/pyrender/interaction/{ => utils}/vec3.py (100%) delete mode 100644 genesis/ext/pyrender/interaction/viewer_interaction_base.py create mode 100644 genesis/options/viewer_interactions.py create mode 100644 genesis/utils/raycast.py diff --git a/examples/interactive_viewer/mouse_spring.py b/examples/interactive_viewer/mouse_spring.py new file mode 100644 index 0000000000..1421bbcaab --- /dev/null +++ b/examples/interactive_viewer/mouse_spring.py @@ -0,0 +1,43 @@ +import math + +import genesis as gs + +if __name__ == "__main__": + + gs.init(backend=gs.gpu) + + scene = gs.Scene( + viewer_options=gs.options.ViewerOptions( + camera_pos=(3.5, 0.0, 2.5), + camera_lookat=(0.0, 0.0, 0.5), + camera_fov=40, + viewer_plugin=gs.options.viewer_interactions.MouseSpringViewerPlugin(), + ), + profiling_options=gs.options.ProfilingOptions( + show_FPS=False, + ), + show_viewer=True, + ) + + scene.add_entity(gs.morphs.Plane()) + + sphere = scene.add_entity( + morph=gs.morphs.Sphere( + pos=(-0.3, -0.3, 0), + radius=0.1, + ), + ) + for i in range(20): + angle = i * (2 * math.pi / 20) + radius = 0.5 + i * 0.1 + cube = scene.add_entity( + morph=gs.morphs.Box( + pos=(radius * math.cos(angle), radius * math.sin(angle), 0.1 + i * 0.1), + size=(0.2, 0.2, 0.2), + ), + ) + + scene.build() + + while True: + scene.step() diff --git a/examples/keyboard_teleop.py b/examples/keyboard_teleop.py index d741b1f0b0..4af79140ea 100644 --- a/examples/keyboard_teleop.py +++ b/examples/keyboard_teleop.py @@ -11,48 +11,176 @@ u - Reset Scene space - Press to close gripper, release to open gripper esc - Quit + +Plus all default viewer controls (press 'i' to see them) """ -import os import random -import threading -import genesis as gs import numpy as np -from pynput import keyboard +import pyglet from scipy.spatial.transform import Rotation as R +from typing_extensions import override - -class KeyboardDevice: - def __init__(self): - self.pressed_keys = set() - self.lock = threading.Lock() - self.listener = keyboard.Listener(on_press=self.on_press, on_release=self.on_release) - - def start(self): - self.listener.start() - - def stop(self): - try: - self.listener.stop() - except NotImplementedError: - # Dummy backend does not implement stop - pass - self.listener.join() - - def on_press(self, key: keyboard.Key): - with self.lock: - self.pressed_keys.add(key) - - def on_release(self, key: keyboard.Key): - with self.lock: - self.pressed_keys.discard(key) - - def get_cmd(self): - return self.pressed_keys +import genesis as gs +from genesis.ext.pyrender.interaction import register_viewer_plugin +from genesis.ext.pyrender.interaction.plugins.viewer_controls import ViewerDefaultControls +from genesis.options.viewer_interactions import ViewerDefaultControls as ViewerDefaultControlsOptions + + +class FrankaTeleopOptions(ViewerDefaultControlsOptions): + """Options for Franka teleoperation plugin.""" + + pass + + +@register_viewer_plugin(FrankaTeleopOptions) +class FrankaTeleopPlugin(ViewerDefaultControls): + """ + Viewer plugin for teleoperating Franka robot with keyboard. + Extends ViewerDefaultControls to add robot-specific controls. + """ + + def __init__(self, viewer, options=None, camera=None, scene=None, viewport_size=None): + super().__init__(viewer, options, camera, scene, viewport_size) + + # Robot control state + self.robot = None + self.target_entity = None + self.cube_entity = None + self.target_pos = None + self.target_R = None + self.robot_init_pos = np.array([0.5, 0, 0.55]) + self.robot_init_R = R.from_euler("y", np.pi) + + # Control parameters + self.dpos = 0.002 + self.drot = 0.01 + self.is_close_gripper = False + + # keybindings + self.keybindings.extend( + dict( + move_forward=pyglet.window.key.UP, + move_backward=pyglet.window.key.DOWN, + move_left=pyglet.window.key.LEFT, + move_right=pyglet.window.key.RIGHT, + move_up=pyglet.window.key.N, + move_down=pyglet.window.key.M, + rotate_ccw=pyglet.window.key.J, + rotate_cw=pyglet.window.key.K, + reset_scene=pyglet.window.key.U, + close_gripper=pyglet.window.key.SPACE, + ) + ) + self._instr_texts = ( + ["> [i]: show keyboard instructions"], + ["< [i]: hide keyboard instructions"] + + self.keybindings.as_instruction_texts(padding=3, exclude=("toggle_keyboard_instructions")), + ) + + def set_entities(self, robot, target_entity, cube_entity): + """Set references to scene entities.""" + self.robot = robot + self.target_entity = target_entity + self.cube_entity = cube_entity + + # Initialize target pose + self.target_pos = self.robot_init_pos.copy() + self.target_R = self.robot_init_R + + # Get DOF indices + n_dofs = robot.n_dofs + self.motors_dof = np.arange(n_dofs - 2) + self.fingers_dof = np.arange(n_dofs - 2, n_dofs) + self.ee_link = robot.get_link("hand") + + # Reset to initial pose + self._reset_robot() + + def _reset_robot(self): + """Reset robot and cube to initial positions.""" + if self.robot is None: + return + + self.target_pos = self.robot_init_pos.copy() + self.target_R = self.robot_init_R + target_quat = self.target_R.as_quat(scalar_first=True) + self.target_entity.set_qpos(np.concatenate([self.target_pos, target_quat])) + q = self.robot.inverse_kinematics(link=self.ee_link, pos=self.target_pos, quat=target_quat) + self.robot.set_qpos(q[:-2], self.motors_dof) + + # Randomize cube position + self.cube_entity.set_pos((random.uniform(0.2, 0.4), random.uniform(-0.2, 0.2), 0.05)) + self.cube_entity.set_quat(R.from_euler("z", random.uniform(0, np.pi * 2)).as_quat(scalar_first=True)) + + @override + def on_key_press(self, symbol: int, modifiers: int): + # First handle default viewer controls + result = super().on_key_press(symbol, modifiers) + + if self.robot is None: + return result + + # Handle teleoperation controls + if symbol == pyglet.window.key.UP: + self.target_pos[0] -= self.dpos + elif symbol == pyglet.window.key.DOWN: + self.target_pos[0] += self.dpos + elif symbol == pyglet.window.key.RIGHT: + self.target_pos[1] += self.dpos + elif symbol == pyglet.window.key.LEFT: + self.target_pos[1] -= self.dpos + elif symbol == pyglet.window.key.N: + self.target_pos[2] += self.dpos + elif symbol == pyglet.window.key.M: + self.target_pos[2] -= self.dpos + elif symbol == pyglet.window.key.J: + self.target_R = R.from_euler("z", self.drot) * self.target_R + elif symbol == pyglet.window.key.K: + self.target_R = R.from_euler("z", -self.drot) * self.target_R + elif symbol == pyglet.window.key.U: + self._reset_robot() + elif symbol == pyglet.window.key.SPACE: + self.is_close_gripper = True + + return result + + @override + def on_key_release(self, symbol: int, modifiers: int): + result = super().on_key_release(symbol, modifiers) + + if symbol == pyglet.window.key.SPACE: + self.is_close_gripper = False + + return result + + @override + def update_on_sim_step(self): + """Update robot control every simulation step.""" + super().update_on_sim_step() + + if self.robot is None: + return + + # Update target entity visualization + target_quat = self.target_R.as_quat(scalar_first=True) + self.target_entity.set_qpos(np.concatenate([self.target_pos, target_quat])) + + # Control arm with inverse kinematics + q, err = self.robot.inverse_kinematics( + link=self.ee_link, pos=self.target_pos, quat=target_quat, return_error=True + ) + self.robot.control_dofs_position(q[:-2], self.motors_dof) + + # Control gripper + if self.is_close_gripper: + self.robot.control_dofs_force(np.array([-1.0, -1.0]), self.fingers_dof) + else: + self.robot.control_dofs_force(np.array([1.0, 1.0]), self.fingers_dof) -def build_scene(): +if __name__ == "__main__": ########################## init ########################## gs.init(precision="32", logging_level="info", backend=gs.cpu) np.set_printoptions(precision=7, suppress=True) @@ -74,25 +202,26 @@ def build_scene(): camera_lookat=(0.2, 0.0, 0.1), camera_fov=50, max_FPS=60, + viewer_plugin=FrankaTeleopOptions(), ), show_viewer=True, show_FPS=False, ) ########################## entities ########################## - entities = dict() - entities["plane"] = scene.add_entity( + plane = scene.add_entity( gs.morphs.Plane(), ) - entities["robot"] = scene.add_entity( + robot = scene.add_entity( material=gs.materials.Rigid(gravity_compensation=1), morph=gs.morphs.MJCF( file="xml/franka_emika_panda/panda.xml", euler=(0, 0, 0), ), ) - entities["cube"] = scene.add_entity( + + cube = scene.add_entity( material=gs.materials.Rigid(rho=300), morph=gs.morphs.Box( pos=(0.5, 0.0, 0.07), @@ -101,7 +230,7 @@ def build_scene(): surface=gs.surfaces.Default(color=(0.5, 1, 0.5)), ) - entities["target"] = scene.add_entity( + target = scene.add_entity( gs.morphs.Mesh( file="meshes/axis.obj", scale=0.15, @@ -113,34 +242,9 @@ def build_scene(): ########################## build ########################## scene.build() - return scene, entities - - -def run_sim(scene, entities, clients): - robot = entities["robot"] - target_entity = entities["target"] - - robot_init_pos = np.array([0.5, 0, 0.55]) - robot_init_R = R.from_euler("y", np.pi) - target_pos = robot_init_pos.copy() - target_R = robot_init_R - - n_dofs = robot.n_dofs - motors_dof = np.arange(n_dofs - 2) - fingers_dof = np.arange(n_dofs - 2, n_dofs) - ee_link = robot.get_link("hand") - - def reset_scene(): - nonlocal target_pos, target_R - target_pos = robot_init_pos.copy() - target_R = robot_init_R - target_quat = target_R.as_quat(scalar_first=True) - target_entity.set_qpos(np.concatenate([target_pos, target_quat])) - q = robot.inverse_kinematics(link=ee_link, pos=target_pos, quat=target_quat) - robot.set_qpos(q[:-2], motors_dof) - - entities["cube"].set_pos((random.uniform(0.2, 0.4), random.uniform(-0.2, 0.2), 0.05)) - entities["cube"].set_quat(R.from_euler("z", random.uniform(0, np.pi * 2)).as_quat(scalar_first=True)) + # Set up the teleoperation plugin with entity references + teleop_plugin = scene.viewer.viewer_interaction + teleop_plugin.set_entities(robot, target, cube) print("\nKeyboard Controls:") print("↑\t- Move Forward (North)") @@ -153,75 +257,8 @@ def reset_scene(): print("k\t- Rotate Clockwise") print("u\t- Reset Scene") print("space\t- Press to close gripper, release to open gripper") - print("esc\t- Quit") - - # reset scen before starting teleoperation - reset_scene() - - # start teleoperation - stop = False - while not stop: - pressed_keys = clients["keyboard"].pressed_keys.copy() - - # reset scene: - reset_flag = False - reset_flag |= keyboard.KeyCode.from_char("u") in pressed_keys - if reset_flag: - reset_scene() - - # stop teleoperation - stop = keyboard.Key.esc in pressed_keys - - # get ee target pose - is_close_gripper = False - dpos = 0.002 - drot = 0.01 - for key in pressed_keys: - if key == keyboard.Key.up: - target_pos[0] -= dpos - elif key == keyboard.Key.down: - target_pos[0] += dpos - elif key == keyboard.Key.right: - target_pos[1] += dpos - elif key == keyboard.Key.left: - target_pos[1] -= dpos - elif key == keyboard.KeyCode.from_char("n"): - target_pos[2] += dpos - elif key == keyboard.KeyCode.from_char("m"): - target_pos[2] -= dpos - elif key == keyboard.KeyCode.from_char("j"): - target_R = R.from_euler("z", drot) * target_R - elif key == keyboard.KeyCode.from_char("k"): - target_R = R.from_euler("z", -drot) * target_R - elif key == keyboard.Key.space: - is_close_gripper = True - - # control arm - target_quat = target_R.as_quat(scalar_first=True) - target_entity.set_qpos(np.concatenate([target_pos, target_quat])) - q, _err = robot.inverse_kinematics(link=ee_link, pos=target_pos, quat=target_quat, return_error=True) - robot.control_dofs_position(q[:-2], motors_dof) - - # control gripper - if is_close_gripper: - robot.control_dofs_force(np.array([-1.0, -1.0]), fingers_dof) - else: - robot.control_dofs_force(np.array([1.0, 1.0]), fingers_dof) + print("\nPress 'i' in the viewer to see all keyboard controls") + ########################## run simulation ########################## + while True: scene.step() - - if "PYTEST_VERSION" in os.environ: - break - - -def main(): - clients = dict() - clients["keyboard"] = KeyboardDevice() - clients["keyboard"].start() - - scene, entities = build_scene() - run_sim(scene, entities, clients) - - -if __name__ == "__main__": - main() diff --git a/genesis/engine/sensors/raycaster.py b/genesis/engine/sensors/raycaster.py index e112b9998e..7922e93357 100644 --- a/genesis/engine/sensors/raycaster.py +++ b/genesis/engine/sensors/raycaster.py @@ -8,11 +8,13 @@ import genesis as gs import genesis.utils.array_class as array_class +from genesis.engine.bvh import AABB, LBVH, STACK_SIZE from genesis.options.sensors import ( Raycaster as RaycasterOptions, +) +from genesis.options.sensors import ( RaycastPattern, ) -from genesis.engine.bvh import AABB, LBVH, STACK_SIZE from genesis.utils.geom import ( ti_normalize, ti_transform_by_quat, @@ -21,6 +23,7 @@ transform_by_trans_quat, ) from genesis.utils.misc import concat_with_tensor, make_tensor_field +from genesis.utils.raycast import kernel_update_aabbs, ray_aabb_intersection, ray_triangle_intersection from genesis.vis.rasterizer_context import RasterizerContext from .base_sensor import ( @@ -36,119 +39,6 @@ from genesis.utils.ring_buffer import TensorRingBuffer -@ti.func -def ray_triangle_intersection(ray_start, ray_dir, v0, v1, v2): - """ - Moller-Trumbore ray-triangle intersection. - - Returns: vec4(t, u, v, hit) where hit=1.0 if intersection found, 0.0 otherwise - """ - result = ti.Vector.zero(gs.ti_float, 4) - - edge1 = v1 - v0 - edge2 = v2 - v0 - - # Begin calculating determinant - also used to calculate u parameter - h = ray_dir.cross(edge2) - a = edge1.dot(h) - - # Check all conditions in sequence without early returns - valid = True - - t = gs.ti_float(0.0) - u = gs.ti_float(0.0) - v = gs.ti_float(0.0) - f = gs.ti_float(0.0) - s = ti.Vector.zero(gs.ti_float, 3) - q = ti.Vector.zero(gs.ti_float, 3) - - # If determinant is near zero, ray lies in plane of triangle - if ti.abs(a) < gs.EPS: - valid = False - - if valid: - f = 1.0 / a - s = ray_start - v0 - u = f * s.dot(h) - - if u < 0.0 or u > 1.0: - valid = False - - if valid: - q = s.cross(edge1) - v = f * ray_dir.dot(q) - - if v < 0.0 or u + v > 1.0: - valid = False - - if valid: - # At this stage we can compute t to find out where the intersection point is on the line - t = f * edge2.dot(q) - - # Ray intersection - if t <= gs.EPS: - valid = False - - if valid: - result = ti.math.vec4(t, u, v, 1.0) - - return result - - -@ti.func -def ray_aabb_intersection(ray_start, ray_dir, aabb_min, aabb_max): - """ - Fast ray-AABB intersection test. - Returns the t value of intersection, or -1.0 if no intersection. - """ - result = -1.0 - - # Use the slab method for ray-AABB intersection - sign = ti.select(ray_dir >= 0.0, 1.0, -1.0) - ray_dir = sign * ti.max(ti.abs(ray_dir), gs.EPS) - inv_dir = 1.0 / ray_dir - - t1 = (aabb_min - ray_start) * inv_dir - t2 = (aabb_max - ray_start) * inv_dir - - tmin = ti.min(t1, t2) - tmax = ti.max(t1, t2) - - t_near = ti.max(tmin.x, tmin.y, tmin.z, 0.0) - t_far = ti.min(tmax.x, tmax.y, tmax.z) - - # Check if ray intersects AABB - if t_near <= t_far: - result = t_near - - return result - - -@ti.kernel -def kernel_update_aabbs( - free_verts_state: array_class.VertsState, - fixed_verts_state: array_class.VertsState, - verts_info: array_class.VertsInfo, - faces_info: array_class.FacesInfo, - aabb_state: ti.template(), -): - for i_b, i_f in ti.ndrange(free_verts_state.pos.shape[1], faces_info.verts_idx.shape[0]): - aabb_state.aabbs[i_b, i_f].min.fill(ti.math.inf) - aabb_state.aabbs[i_b, i_f].max.fill(-ti.math.inf) - - for i in ti.static(range(3)): - i_v = faces_info.verts_idx[i_f][i] - i_fv = verts_info.verts_state_idx[i_v] - if verts_info.is_fixed[i_v]: - pos_v = fixed_verts_state.pos[i_fv] - aabb_state.aabbs[i_b, i_f].min = ti.min(aabb_state.aabbs[i_b, i_f].min, pos_v) - aabb_state.aabbs[i_b, i_f].max = ti.max(aabb_state.aabbs[i_b, i_f].max, pos_v) - else: - pos_v = free_verts_state.pos[i_fv, i_b] - aabb_state.aabbs[i_b, i_f].min = ti.min(aabb_state.aabbs[i_b, i_f].min, pos_v) - aabb_state.aabbs[i_b, i_f].max = ti.max(aabb_state.aabbs[i_b, i_f].max, pos_v) - - @ti.kernel def kernel_cast_rays( fixed_verts_state: array_class.VertsState, @@ -195,8 +85,7 @@ def kernel_cast_rays( ray_dir_local = ti.math.vec3(ray_directions[i_p, 0], ray_directions[i_p, 1], ray_directions[i_p, 2]) ray_direction_world = ti_normalize(ti_transform_by_quat(ray_dir_local, link_quat), gs.EPS) - # --- 2. BVH Traversal --- - # FIXME: this duplicates the logic in LBVH.query() which also does traversal + # --- 2. BVH Traversal for ray intersection --- max_range = max_ranges[i_s] hit_face = -1 diff --git a/genesis/ext/pyrender/interaction/__init__.py b/genesis/ext/pyrender/interaction/__init__.py new file mode 100644 index 0000000000..2bb9b27154 --- /dev/null +++ b/genesis/ext/pyrender/interaction/__init__.py @@ -0,0 +1,10 @@ +from .base_interaction import ( + EVENT_HANDLE_STATE, + EVENT_HANDLED, + VIEWER_PLUGIN_MAP, + BaseViewerInteraction, + register_viewer_plugin, +) +from .plugins.mesh_selector import MeshPointSelectorPlugin +from .plugins.mouse_interaction import MouseSpringViewerPlugin +from .plugins.viewer_controls import ViewerDefaultControls diff --git a/genesis/ext/pyrender/interaction/base_interaction.py b/genesis/ext/pyrender/interaction/base_interaction.py new file mode 100644 index 0000000000..15082e2004 --- /dev/null +++ b/genesis/ext/pyrender/interaction/base_interaction.py @@ -0,0 +1,126 @@ +from typing import TYPE_CHECKING, Literal, Type + +import numpy as np + +from .utils import Ray, Vec3 + +if TYPE_CHECKING: + from genesis.engine.scene import Scene + from genesis.ext.pyrender.node import Node + from genesis.options.viewer_interactions import ViewerInteraction as ViewerPluginOptions + + +EVENT_HANDLE_STATE = Literal[True] | None +EVENT_HANDLED: Literal[True] = True + +# Global map from options class to viewer plugin class +VIEWER_PLUGIN_MAP: dict[Type["ViewerPluginOptions"], Type["BaseViewerInteraction"]] = {} + + +def register_viewer_plugin(options_cls: Type["ViewerPluginOptions"]): + """ + Decorator to register a viewer plugin class with its corresponding options class. + + Parameters + ---------- + options_cls : Type[ViewerPluginOptions] + The options class that configures this viewer plugin. + + Returns + ------- + Callable + The decorator function that registers the plugin class. + + Example + ------- + @register_viewer_plugin(ViewerInteractionOptions) + class ViewerInteraction(ViewerInteractionBase): + ... + """ + def _impl(plugin_cls: Type["BaseViewerInteraction"]): + VIEWER_PLUGIN_MAP[options_cls] = plugin_cls + return plugin_cls + return _impl + +# Note: Viewer window is based on pyglet.window.Window, mouse events are defined in pyglet.window.BaseWindow + +class BaseViewerInteraction(): + """ + Base class for handling pyglet.window.Window events. + """ + + def __init__( + self, + viewer, + options: "ViewerPluginOptions", + camera: "Node", + scene: "Scene", + viewport_size: tuple[int, int], + ): + self.viewer = viewer + self.options: "ViewerPluginOptions" = options + self.camera: 'Node' = camera + self.scene: 'Scene' = scene + self.viewport_size: tuple[int, int] = viewport_size + + self.camera_yfov: float = camera.camera.yfov + self.tan_half_fov: float = np.tan(0.5 * self.camera_yfov) + + def on_mouse_motion(self, x: int, y: int, dx: int, dy: int) -> EVENT_HANDLE_STATE: + pass + + def on_mouse_drag(self, x: int, y: int, dx: int, dy: int, buttons: int, modifiers: int) -> EVENT_HANDLE_STATE: + pass + + def on_mouse_press(self, x: int, y: int, button: int, modifiers: int) -> EVENT_HANDLE_STATE: + pass + + def on_mouse_release(self, x: int, y: int, button: int, modifiers: int) -> EVENT_HANDLE_STATE: + pass + + def on_key_press(self, symbol: int, modifiers: int) -> EVENT_HANDLE_STATE: + pass + + def on_key_release(self, symbol: int, modifiers: int) -> EVENT_HANDLE_STATE: + pass + + def on_resize(self, width: int, height: int) -> EVENT_HANDLE_STATE: + self.viewport_size = (width, height) + self.tan_half_fov = np.tan(0.5 * self.camera_yfov) + + def update_on_sim_step(self) -> None: + pass + + def on_draw(self) -> None: + pass + + def on_close(self) -> None: + pass + + def _screen_position_to_ray(self, x: float, y: float) -> Ray: + # convert screen position to ray + x = x - 0.5 * self.viewport_size[0] + y = y - 0.5 * self.viewport_size[1] + x = 2.0 * x / self.viewport_size[1] * self.tan_half_fov + y = 2.0 * y / self.viewport_size[1] * self.tan_half_fov + + # Note: ignoring pixel aspect ratio + + mtx = self.camera.matrix + position = Vec3.from_array(mtx[:3, 3]) + forward = Vec3.from_array(-mtx[:3, 2]) + right = Vec3.from_array(mtx[:3, 0]) + up = Vec3.from_array(mtx[:3, 1]) + + direction = forward + right * x + up * y + return Ray(position, direction) + + def _get_camera_forward(self) -> Vec3: + mtx = self.camera.matrix + return Vec3.from_array(-mtx[:3, 2]) + + def _get_camera_ray(self) -> Ray: + mtx = self.camera.matrix + position = Vec3.from_array(mtx[:3, 3]) + forward = Vec3.from_array(-mtx[:3, 2]) + return Ray(position, forward) \ No newline at end of file diff --git a/genesis/ext/pyrender/interaction/keybindings.py b/genesis/ext/pyrender/interaction/keybindings.py new file mode 100644 index 0000000000..3cd7ba7990 --- /dev/null +++ b/genesis/ext/pyrender/interaction/keybindings.py @@ -0,0 +1,28 @@ +class Keybindings: + def __init__(self, map: dict[str, int] = {}, **kwargs: dict[str, int]): + self._map: dict[str, int] = {**map, **kwargs} + + def __getattr__(self, name: str) -> int: + if name in self._map: + return self._map[name] + raise AttributeError(f"Action '{name}' not found in keybindings.") + + def as_instruction_texts(self, padding, exclude: tuple[str]) -> list[str]: + from pyglet.window.key import symbol_string + + width = 4 + padding + return [ + f"{'[' + symbol_string(self._map[action]).lower():>{width}}]: " + + action.replace('_', ' ') for action in self._map.keys() if action not in exclude + ] + + def extend(self, mapping: dict[str, int], replace_only: bool = False) -> None: + from pyglet.window.key import symbol_string + + current_keys = self._map.keys() + for action, key in mapping.items(): + if replace_only and action not in self._map: + raise KeyError(f"Action '{action}' not found. Available actions: {list(self._map.keys())}") + if key in current_keys: + raise ValueError(f"Key '{symbol_string(key)}' is already assigned to another action.") + self._map[action] = key \ No newline at end of file diff --git a/genesis/ext/pyrender/interaction/mouse_spring.py b/genesis/ext/pyrender/interaction/mouse_spring.py deleted file mode 100644 index f7b53a2d33..0000000000 --- a/genesis/ext/pyrender/interaction/mouse_spring.py +++ /dev/null @@ -1,95 +0,0 @@ -from typing import TYPE_CHECKING - -import torch - -from .vec3 import Pose, Quat, Vec3 - -if TYPE_CHECKING: - from genesis.engine.entities.rigid_entity.rigid_link import RigidLink - - -MOUSE_SPRING_POSITION_CORRECTION_FACTOR = 1.0 -MOUSE_SPRING_VELOCITY_CORRECTION_FACTOR = 1.0 - - -class MouseSpring: - def __init__(self) -> None: - self.held_link: "RigidLink | None" = None - self.held_point_in_local: Vec3 | None = None - self.prev_control_point: Vec3 | None = None - - def attach(self, picked_link: "RigidLink", control_point: Vec3) -> None: - # for now, we just pick the first geometry - self.held_link = picked_link - pose: Pose = Pose.from_link(self.held_link) - self.held_point_in_local = pose.inverse_transform_point(control_point) - self.prev_control_point = control_point - - def detach(self) -> None: - self.held_link = None - - def apply_force(self, control_point: Vec3, delta_time: float) -> None: - # note when threaded: apply_force is called before attach! - # note2: that was before we added a lock to ViewerInteraction; this migth be fixed now - if not self.held_link: - return - - self.prev_control_point = control_point - - # do simple force on COM only: - link: "RigidLink" = self.held_link - lin_vel: Vec3 = Vec3.from_tensor(link.get_vel()) - ang_vel: Vec3 = Vec3.from_tensor(link.get_ang()) - link_pose: Pose = Pose.from_link(link) - held_point_in_world: Vec3 = link_pose.transform_point(self.held_point_in_local) - - # note: we should assert earlier that link inertial_pos/quat are not None - # todo: verify inertial_pos/quat are stored in local frame - link_T_principal: Pose = Pose(Vec3.from_arraylike(link.inertial_pos), Quat.from_arraylike(link.inertial_quat)) - world_T_principal: Pose = link_pose * link_T_principal - - # for non-spherical inertia - arm_in_principal: Vec3 = link_T_principal.inverse_transform_point(self.held_point_in_local) - arm_in_world: Vec3 = world_T_principal.rot * arm_in_principal # for spherical inertia - - pos_err_v: Vec3 = control_point - held_point_in_world - inv_mass: float = float(1.0 / link.get_mass() if link.get_mass() > 0.0 else 0.0) - inv_spherical_inertia: float = float(1.0 / link.inertial_i[0, 0] if link.inertial_i[0, 0] > 0.0 else 0.0) - - inv_dt: float = 1.0 / delta_time - tau: float = MOUSE_SPRING_POSITION_CORRECTION_FACTOR - damp: float = MOUSE_SPRING_VELOCITY_CORRECTION_FACTOR - - total_impulse: Vec3 = Vec3.zero() - total_torque_impulse: Vec3 = Vec3.zero() - - for i in range(3 * 4): - body_point_vel: Vec3 = lin_vel + ang_vel.cross(arm_in_world) - vel_err_v: Vec3 = Vec3.zero() - body_point_vel - - dir: Vec3 = Vec3.zero() - dir.v[i % 3] = 1.0 - pos_err: float = dir.dot(pos_err_v) - vel_err: float = dir.dot(vel_err_v) - error: float = tau * pos_err * inv_dt + damp * vel_err - - arm_x_dir: Vec3 = arm_in_world.cross(dir) - virtual_mass: float = 1.0 / (inv_mass + arm_x_dir.sqr_magnitude() * inv_spherical_inertia + 1e-24) - impulse: float = error * virtual_mass - - lin_vel += impulse * inv_mass * dir - ang_vel += impulse * inv_spherical_inertia * arm_x_dir - total_impulse.v[i % 3] += impulse - total_torque_impulse += impulse * arm_x_dir - - # Apply the new force - total_force = total_impulse * inv_dt - total_torque = total_torque_impulse * inv_dt - force_tensor: torch.Tensor = total_force.as_tensor()[None] - torque_tensor: torch.Tensor = total_torque.as_tensor()[None] - link.solver.apply_links_external_force(force_tensor, (link.idx,), ref="link_com", local=False) - link.solver.apply_links_external_torque(torque_tensor, (link.idx,), ref="link_com", local=False) - - @property - def is_attached(self) -> bool: - return self.held_link is not None diff --git a/genesis/ext/pyrender/interaction/plugins/__init__.py b/genesis/ext/pyrender/interaction/plugins/__init__.py new file mode 100644 index 0000000000..4d9ca6c59c --- /dev/null +++ b/genesis/ext/pyrender/interaction/plugins/__init__.py @@ -0,0 +1,9 @@ +from .mesh_selector import MeshPointSelectorPlugin +from .mouse_interaction import MouseSpringViewerPlugin +from .viewer_controls import ViewerDefaultControls + +__all__ = [ + "ViewerDefaultControls", + "MeshPointSelectorPlugin", + "MouseSpringViewerPlugin", +] diff --git a/genesis/ext/pyrender/interaction/plugins/default_keyboard.py b/genesis/ext/pyrender/interaction/plugins/default_keyboard.py new file mode 100644 index 0000000000..9c211e84e3 --- /dev/null +++ b/genesis/ext/pyrender/interaction/plugins/default_keyboard.py @@ -0,0 +1,213 @@ +import os +from typing import TYPE_CHECKING + +import numpy as np +import pyglet +from typing_extensions import override + +import genesis as gs + +from ...constants import TEXT_PADDING +from ..base_interaction import EVENT_HANDLE_STATE, BaseViewerInteraction + +if TYPE_CHECKING: + from genesis.engine.scene import Scene + from genesis.ext.pyrender.node import Node + + +class ViewerControls(BaseViewerInteraction): + """ + Default keyboard controls for the Genesis viewer. + + This plugin handles the standard viewer keyboard shortcuts for recording, changing render modes, etc. + """ + + def __init__( + self, + viewer, + options=None, + camera: "Node" = None, + scene: "Scene" = None, + viewport_size: tuple[int, int] = None, + ): + super().__init__(viewer, options, camera, scene, viewport_size) + + # Instruction display state + self._display_instr = False + self._instr_texts = [ + ["> [i]: show keyboard instructions"], + [ + "< [i]: hide keyboard instructions", + " [r]: record video", + " [s]: save image", + " [z]: reset camera", + " [a]: camera rotation", + " [h]: shadow", + " [f]: face normal", + " [v]: vertex normal", + " [w]: world frame", + " [l]: link frame", + " [d]: wireframe", + " [c]: camera & frustrum", + " [F11]: full-screen mode", + ], + ] + + @override + def on_key_press(self, symbol: int, modifiers: int) -> EVENT_HANDLE_STATE: + if self.viewer is None: + return None + + # A causes the frame to rotate + self.viewer._message_text = None + if symbol == pyglet.window.key.A: + self.viewer.viewer_flags["rotate"] = not self.viewer.viewer_flags["rotate"] + if self.viewer.viewer_flags["rotate"]: + self.viewer._message_text = "Rotation On" + else: + self.viewer._message_text = "Rotation Off" + + # F11 toggles fullscreen + elif symbol == pyglet.window.key.F11: + self.viewer.viewer_flags["fullscreen"] = not self.viewer.viewer_flags["fullscreen"] + self.viewer.set_fullscreen(self.viewer.viewer_flags["fullscreen"]) + self.viewer.activate() + if self.viewer.viewer_flags["fullscreen"]: + self.viewer._message_text = "Fullscreen On" + else: + self.viewer._message_text = "Fullscreen Off" + + # H toggles shadows + elif symbol == pyglet.window.key.H: + self.viewer.render_flags["shadows"] = not self.viewer.render_flags["shadows"] + if self.viewer.render_flags["shadows"]: + self.viewer._message_text = "Shadows On" + else: + self.viewer._message_text = "Shadows Off" + + # W toggles world frame + elif symbol == pyglet.window.key.W: + if not self.viewer.gs_context.world_frame_shown: + self.viewer.gs_context.on_world_frame() + self.viewer._message_text = "World Frame On" + else: + self.viewer.gs_context.off_world_frame() + self.viewer._message_text = "World Frame Off" + + # L toggles link frame + elif symbol == pyglet.window.key.L: + if not self.viewer.gs_context.link_frame_shown: + self.viewer.gs_context.on_link_frame() + self.viewer._message_text = "Link Frame On" + else: + self.viewer.gs_context.off_link_frame() + self.viewer._message_text = "Link Frame Off" + + # C toggles camera frustum + elif symbol == pyglet.window.key.C: + if not self.viewer.gs_context.camera_frustum_shown: + self.viewer.gs_context.on_camera_frustum() + self.viewer._message_text = "Camera Frustrum On" + else: + self.viewer.gs_context.off_camera_frustum() + self.viewer._message_text = "Camera Frustrum Off" + + # F toggles face normals + elif symbol == pyglet.window.key.F: + self.viewer.render_flags["face_normals"] = not self.viewer.render_flags["face_normals"] + if self.viewer.render_flags["face_normals"]: + self.viewer._message_text = "Face Normals On" + else: + self.viewer._message_text = "Face Normals Off" + + # V toggles vertex normals + elif symbol == pyglet.window.key.V: + self.viewer.render_flags["vertex_normals"] = not self.viewer.render_flags["vertex_normals"] + if self.viewer.render_flags["vertex_normals"]: + self.viewer._message_text = "Vert Normals On" + else: + self.viewer._message_text = "Vert Normals Off" + + # R starts recording frames + elif symbol == pyglet.window.key.R: + if self.viewer.viewer_flags["record"]: + self.viewer.save_video() + self.viewer.set_caption(self.viewer.viewer_flags["window_title"]) + else: + # Importing moviepy is very slow and not used very often. Let's delay import. + from moviepy.video.io.ffmpeg_writer import FFMPEG_VideoWriter + + self.viewer._video_recorder = FFMPEG_VideoWriter( + filename=os.path.join(gs.utils.misc.get_cache_dir(), "tmp_video.mp4"), + fps=self.viewer.viewer_flags["refresh_rate"], + size=self.viewer.viewport_size, + ) + self.viewer.set_caption("{} (RECORDING)".format(self.viewer.viewer_flags["window_title"])) + self.viewer.viewer_flags["record"] = not self.viewer.viewer_flags["record"] + + # S saves the current frame as an image + elif symbol == pyglet.window.key.S: + self.viewer._save_image() + + # D toggles through wireframe modes + elif symbol == pyglet.window.key.D: + if self.viewer.render_flags["flip_wireframe"]: + self.viewer.render_flags["flip_wireframe"] = False + self.viewer.render_flags["all_wireframe"] = True + self.viewer.render_flags["all_solid"] = False + self.viewer._message_text = "All Wireframe" + elif self.viewer.render_flags["all_wireframe"]: + self.viewer.render_flags["flip_wireframe"] = False + self.viewer.render_flags["all_wireframe"] = False + self.viewer.render_flags["all_solid"] = True + self.viewer._message_text = "All Solid" + elif self.viewer.render_flags["all_solid"]: + self.viewer.render_flags["flip_wireframe"] = False + self.viewer.render_flags["all_wireframe"] = False + self.viewer.render_flags["all_solid"] = False + self.viewer._message_text = "Default Wireframe" + else: + self.viewer.render_flags["flip_wireframe"] = True + self.viewer.render_flags["all_wireframe"] = False + self.viewer.render_flags["all_solid"] = False + self.viewer._message_text = "Flip Wireframe" + + # Z resets the camera viewpoint + elif symbol == pyglet.window.key.Z: + self.viewer._reset_view() + + # I toggles instruction display + elif symbol == pyglet.window.key.I: + self._display_instr = not self._display_instr + + # P reloads shader program + elif symbol == pyglet.window.key.P: + self.viewer._renderer.reload_program() + + if self.viewer._message_text is not None: + self.viewer._message_opac = 1.0 + self.viewer._ticks_till_fade + + return None + + @override + def on_draw(self): + """Render keyboard instructions.""" + if self.viewer is None: + return + + if self._display_instr: + self.viewer._renderer.render_texts( + self._instr_texts[1], + TEXT_PADDING, + self.viewer.viewport_size[1] - TEXT_PADDING, + font_pt=26, + color=np.array([1.0, 1.0, 1.0, 0.85]), + ) + else: + self.viewer._renderer.render_texts( + self._instr_texts[0], + TEXT_PADDING, + self.viewer.viewport_size[1] - TEXT_PADDING, + font_pt=26, + color=np.array([1.0, 1.0, 1.0, 0.85]), + ) diff --git a/genesis/ext/pyrender/interaction/plugins/mesh_selector.py b/genesis/ext/pyrender/interaction/plugins/mesh_selector.py new file mode 100644 index 0000000000..c2399a5e21 --- /dev/null +++ b/genesis/ext/pyrender/interaction/plugins/mesh_selector.py @@ -0,0 +1,206 @@ +import csv +from typing import TYPE_CHECKING, NamedTuple + +from typing_extensions import override + +import genesis as gs +from genesis.options.viewer_interactions import MeshPointSelectorPlugin as MeshPointSelectorPluginOptions + +from ..base_interaction import EVENT_HANDLE_STATE, EVENT_HANDLED, register_viewer_plugin +from ..utils import Pose, Ray, Vec3, ViewerRaycaster +from .viewer_controls import ViewerDefaultControls + +if TYPE_CHECKING: + from genesis.engine.entities.rigid_entity import RigidLink + from genesis.engine.scene import Scene + from genesis.ext.pyrender.node import Node + + +class SelectedPoint(NamedTuple): + """ + Represents a selected point on a rigid mesh surface. + + Attributes + ---------- + link : RigidLink + The rigid link that the point belongs to. + local_position : Vec3 + The position of the point in the link's local coordinate frame. + local_normal : Vec3 + The surface normal at the point in the link's local coordinate frame. + """ + link: "RigidLink" + local_position: Vec3 + local_normal: Vec3 + + + +@register_viewer_plugin(MeshPointSelectorPluginOptions) +class MeshPointSelectorPlugin(ViewerDefaultControls): + """ + Interactive viewer plugin that enables using mouse clicks to select points on rigid meshes. + Selected points are stored in local coordinates relative to their link's frame. + """ + + def __init__( + self, + viewer, + options: MeshPointSelectorPluginOptions, + camera: "Node", + scene: "Scene", + viewport_size: tuple[int, int], + ) -> None: + super().__init__(viewer, options, camera, scene, viewport_size) + self.prev_mouse_pos: tuple[int, int] = (viewport_size[0] // 2, viewport_size[1] // 2) + + # List of selected points with link, local position, and local normal + self.selected_points: list[SelectedPoint] = [] + + self.raycaster: ViewerRaycaster = ViewerRaycaster(self.scene) + + def _snap_to_grid(self, position: Vec3) -> Vec3: + """ + Snap a position to the grid based on grid_snap settings. + + Parameters + ---------- + position : Vec3 + The position to snap. + + Returns + ------- + Vec3 + The snapped position. + """ + snap_x, snap_y, snap_z = self.options.grid_snap + + # Snap each axis if the snap value is non-negative + x = round(position.x / snap_x) * snap_x if snap_x >= 0 else position.x + y = round(position.y / snap_y) * snap_y if snap_y >= 0 else position.y + z = round(position.z / snap_z) * snap_z if snap_z >= 0 else position.z + + return Vec3.from_xyz(x, y, z) + + @override + def on_mouse_motion(self, x: int, y: int, dx: int, dy: int) -> EVENT_HANDLE_STATE: + super().on_mouse_motion(x, y, dx, dy) + self.prev_mouse_pos = (x, y) + return None + + @override + def on_mouse_press(self, x: int, y: int, button: int, modifiers: int) -> EVENT_HANDLE_STATE: + super().on_mouse_press(x, y, button, modifiers) + if button == 1: # left mouse button + ray = self._screen_position_to_ray(x, y) + ray_hit = self.raycaster.cast_ray(ray.origin.v, ray.direction.v) + + if ray_hit.is_hit and ray_hit.geom: + link = ray_hit.geom.link + world_pos = ray_hit.position + world_normal = ray_hit.normal + + pose: Pose = Pose.from_link(link) + local_pos = pose.inverse_transform_point(world_pos) + local_normal = pose.inverse_transform_direction(world_normal) + + # Apply grid snapping to local position + local_pos = self._snap_to_grid(local_pos) + + selected_point = SelectedPoint( + link=link, + local_position=local_pos, + local_normal=local_normal + ) + self.selected_points.append(selected_point) + + return EVENT_HANDLED + return None + + @override + def update_on_sim_step(self) -> None: + self.raycaster.update_bvh() + + @override + def on_draw(self) -> None: + super().on_draw() + if self.scene._visualizer is not None and self.scene._visualizer.is_built: + self.scene.clear_debug_objects() + mouse_ray: Ray = self._screen_position_to_ray(*self.prev_mouse_pos) + + closest_hit = self.raycaster.cast_ray(mouse_ray.origin.v, mouse_ray.direction.v) + if closest_hit.is_hit: + snap_pos = self._snap_to_grid(closest_hit.position) + # Draw hover preview + self.scene.draw_debug_sphere( + snap_pos.v, + self.options.sphere_radius, + self.options.hover_color, + ) + self.scene.draw_debug_arrow( + snap_pos.v, + closest_hit.normal.v * 0.1, + self.options.sphere_radius / 2, + self.options.hover_color, + ) + + if self.selected_points: + world_positions = [] + for point in self.selected_points: + pose = Pose.from_link(point.link) + current_world_pos = pose.transform_point(point.local_position) + world_positions.append(current_world_pos.v) + + if len(world_positions) == 1: + self.scene.draw_debug_sphere( + world_positions[0], + self.options.sphere_radius, + self.options.sphere_color, + ) + else: + import numpy as np + + positions_array = np.array(world_positions) + self.scene.draw_debug_spheres( + positions_array, self.options.sphere_radius, self.options.sphere_color + ) + + @override + def on_close(self) -> None: + super().on_close() + + if not self.selected_points: + print("[MeshPointSelectorPlugin] No points selected.") + return + + output_file = self.options.output_file + try: + with open(output_file, 'w', newline='') as csvfile: + writer = csv.writer(csvfile) + + writer.writerow([ + 'point_idx', + 'link_idx', + 'local_pos_x', + 'local_pos_y', + 'local_pos_z', + 'local_normal_x', + 'local_normal_y', + 'local_normal_z' + ]) + + for i, point in enumerate(self.selected_points, 1): + writer.writerow([ + i, + point.link.idx, + point.local_position.x, + point.local_position.y, + point.local_position.z, + point.local_normal.x, + point.local_normal.y, + point.local_normal.z, + ]) + + gs.logger.info(f"[MeshPointSelectorPlugin] Wrote {len(self.selected_points)} selected points to '{output_file}'") + + except Exception as e: + gs.logger.error(f"[MeshPointSelectorPlugin] Error writing to '{output_file}': {e}") diff --git a/genesis/ext/pyrender/interaction/viewer_interaction.py b/genesis/ext/pyrender/interaction/plugins/mouse_interaction.py similarity index 54% rename from genesis/ext/pyrender/interaction/viewer_interaction.py rename to genesis/ext/pyrender/interaction/plugins/mouse_interaction.py index c801baea08..f88b0a9db0 100644 --- a/genesis/ext/pyrender/interaction/viewer_interaction.py +++ b/genesis/ext/pyrender/interaction/plugins/mouse_interaction.py @@ -1,45 +1,122 @@ -from typing import TYPE_CHECKING, cast -from typing_extensions import override # Made it into standard lib from Python 3.12 from threading import Lock as threading_Lock +from typing import TYPE_CHECKING -import numpy as np +import torch +from typing_extensions import override # Made it into standard lib from Python 3.12 import genesis as gs +from genesis.options.viewer_interactions import MouseSpringViewerPlugin as MouseSpringViewerPluginOptions -from .aabb import AABB, OBB -from .mouse_spring import MouseSpring -from .ray import Plane, Ray, RayHit -from .vec3 import Pose, Vec3, Color -from .viewer_interaction_base import ViewerInteractionBase, EVENT_HANDLE_STATE, EVENT_HANDLED +from ..base_interaction import EVENT_HANDLE_STATE, EVENT_HANDLED, register_viewer_plugin +from ..utils import AABB, OBB, Color, Plane, Pose, Quat, Ray, RayHit, Vec3, ViewerRaycaster +from .viewer_controls import ViewerDefaultControls if TYPE_CHECKING: - from genesis.engine.entities.rigid_entity import RigidGeom, RigidLink, RigidEntity + from genesis.engine.entities.rigid_entity import RigidEntity, RigidGeom, RigidLink from genesis.engine.scene import Scene from genesis.ext.pyrender.node import Node -class ViewerInteraction(ViewerInteractionBase): - """Functionalities to be implemented: - - mouse picking - - mouse dragging +MOUSE_SPRING_POSITION_CORRECTION_FACTOR = 1.0 +MOUSE_SPRING_VELOCITY_CORRECTION_FACTOR = 1.0 + +class MouseSpring: + def __init__(self) -> None: + self.held_link: "RigidLink | None" = None + self.held_point_in_local: Vec3 | None = None + self.prev_control_point: Vec3 | None = None + + def attach(self, picked_link: "RigidLink", control_point: Vec3) -> None: + # for now, we just pick the first geometry + self.held_link = picked_link + pose: Pose = Pose.from_link(self.held_link) + self.held_point_in_local = pose.inverse_transform_point(control_point) + self.prev_control_point = control_point + + def detach(self) -> None: + self.held_link = None + + def apply_force(self, control_point: Vec3, delta_time: float) -> None: + # note when threaded: apply_force is called before attach! + # note2: that was before we added a lock to ViewerInteraction; this migth be fixed now + if not self.held_link: + return + + self.prev_control_point = control_point + + # do simple force on COM only: + link: "RigidLink" = self.held_link + lin_vel: Vec3 = Vec3.from_tensor(link.get_vel()) + ang_vel: Vec3 = Vec3.from_tensor(link.get_ang()) + link_pose: Pose = Pose.from_link(link) + held_point_in_world: Vec3 = link_pose.transform_point(self.held_point_in_local) + + # note: we should assert earlier that link inertial_pos/quat are not None + # todo: verify inertial_pos/quat are stored in local frame + link_T_principal: Pose = Pose(Vec3.from_arraylike(link.inertial_pos), Quat.from_arraylike(link.inertial_quat)) + world_T_principal: Pose = link_pose * link_T_principal + + arm_in_principal: Vec3 = link_T_principal.inverse_transform_point(self.held_point_in_local) # for non-spherical inertia + arm_in_world: Vec3 = world_T_principal.rot * arm_in_principal # for spherical inertia + + pos_err_v: Vec3 = control_point - held_point_in_world + inv_mass: float = float(1.0 / link.get_mass() if link.get_mass() > 0.0 else 0.0) + inv_spherical_inertia: float = float(1.0 / link.inertial_i[0, 0] if link.inertial_i[0, 0] > 0.0 else 0.0) + + inv_dt: float = 1.0 / delta_time + tau: float = MOUSE_SPRING_POSITION_CORRECTION_FACTOR + damp: float = MOUSE_SPRING_VELOCITY_CORRECTION_FACTOR + + total_impulse: Vec3 = Vec3.zero() + total_torque_impulse: Vec3 = Vec3.zero() + + for i in range(3*4): + body_point_vel: Vec3 = lin_vel + ang_vel.cross(arm_in_world) + vel_err_v: Vec3 = Vec3.zero() - body_point_vel + + dir: Vec3 = Vec3.zero() + dir.v[i % 3] = 1.0 + pos_err: float = dir.dot(pos_err_v) + vel_err: float = dir.dot(vel_err_v) + error: float = tau * pos_err * inv_dt + damp * vel_err + + arm_x_dir: Vec3 = arm_in_world.cross(dir) + virtual_mass: float = 1.0 / (inv_mass + arm_x_dir.sqr_magnitude() * inv_spherical_inertia + 1e-24) + impulse: float = error * virtual_mass + + lin_vel += impulse * inv_mass * dir + ang_vel += impulse * inv_spherical_inertia * arm_x_dir + total_impulse.v[i % 3] += impulse + total_torque_impulse += impulse * arm_x_dir + + # Apply the new force + total_force = total_impulse * inv_dt + total_torque = total_torque_impulse * inv_dt + force_tensor: torch.Tensor = total_force.as_tensor()[None] + torque_tensor: torch.Tensor = total_torque.as_tensor()[None] + link.solver.apply_links_external_force(force_tensor, (link.idx,), ref='link_com', local=False) + link.solver.apply_links_external_torque(torque_tensor, (link.idx,), ref='link_com', local=False) + + @property + def is_attached(self) -> bool: + return self.held_link is not None + + +@register_viewer_plugin(MouseSpringViewerPluginOptions) +class MouseSpringViewerPlugin(ViewerDefaultControls): + """ + Basic interactive viewer plugin that enables using mouse to apply spring force on rigid entities. """ def __init__( self, + viewer, + options: MouseSpringViewerPluginOptions, camera: "Node", scene: "Scene", viewport_size: tuple[int, int], - camera_yfov: float, - log_events: bool = False, - camera_fov: float = 60.0, ) -> None: - super().__init__(log_events) - self.camera: "Node" = camera - self.scene: "Scene" = scene - self.viewport_size: tuple[int, int] = viewport_size - self.camera_yfov: float = camera_yfov - - self.tan_half_fov: float = np.tan(0.5 * self.camera_yfov) + super().__init__(viewer, options, camera, scene, viewport_size) self.prev_mouse_pos: tuple[int, int] = (viewport_size[0] // 2, viewport_size[1] // 2) self.picked_link: RigidLink | None = None @@ -50,6 +127,8 @@ def __init__( self.mouse_spring: MouseSpring = MouseSpring() self.lock = threading_Lock() + self.raycaster: ViewerRaycaster = ViewerRaycaster(self.scene) + @override def on_mouse_motion(self, x: int, y: int, dx: int, dy: int) -> EVENT_HANDLE_STATE: super().on_mouse_motion(x, y, dx, dy) @@ -67,14 +146,15 @@ def on_mouse_drag(self, x: int, y: int, dx: int, dy: int, buttons: int, modifier @override def on_mouse_press(self, x: int, y: int, button: int, modifiers: int) -> EVENT_HANDLE_STATE: super().on_mouse_press(x, y, button, modifiers) - if button == 1: # left mouse button - ray_hit = self.raycast_against_entities(self.screen_position_to_ray(x, y)) + if button == 1: # left mouse button + + ray_hit = self.raycaster.cast_ray(self._screen_position_to_ray(x, y).origin.v, self._screen_position_to_ray(x, y).direction.v) with self.lock: if ray_hit.geom: self.picked_link = ray_hit.geom.link assert self.picked_link is not None - temp_fwd = self.get_camera_forward() + temp_fwd = self._get_camera_forward() temp_back = -temp_fwd self.mouse_drag_plane = Plane(temp_back, ray_hit.position) @@ -97,17 +177,13 @@ def on_mouse_release(self, x: int, y: int, button: int, modifiers: int) -> EVENT self.mouse_spring.detach() - @override - def on_resize(self, width: int, height: int) -> EVENT_HANDLE_STATE: - super().on_resize(width, height) - self.viewport_size = (width, height) - self.tan_half_fov = np.tan(0.5 * self.camera_yfov) - @override def update_on_sim_step(self) -> None: + self.raycaster.update_bvh() + with self.lock: if self.picked_link: - mouse_ray: Ray = self.screen_position_to_ray(*self.prev_mouse_pos) + mouse_ray: Ray = self._screen_position_to_ray(*self.prev_mouse_pos) ray_hit: RayHit = self.mouse_drag_plane.raycast(mouse_ray) assert ray_hit.is_hit if ray_hit.is_hit: @@ -130,11 +206,8 @@ def on_draw(self) -> None: super().on_draw() if self.scene._visualizer is not None and self.scene._visualizer.is_built: self.scene.clear_debug_objects() - mouse_ray: Ray = self.screen_position_to_ray(*self.prev_mouse_pos) - - closest_hit = self.raycast_against_entities(mouse_ray) - if not closest_hit.is_hit: - closest_hit = self._raycast_against_ground_plane(mouse_ray) + mouse_ray: Ray = self._screen_position_to_ray(*self.prev_mouse_pos) + closest_hit = self.raycaster.cast_ray(mouse_ray.origin.v, mouse_ray.direction.v) with self.lock: if self.picked_link: @@ -157,76 +230,6 @@ def on_draw(self) -> None: if closest_hit.geom: self._draw_entity_unrotated_obb(closest_hit.geom) - def screen_position_to_ray(self, x: float, y: float) -> Ray: - # convert screen position to ray - if True: - x = x - 0.5 * self.viewport_size[0] - y = y - 0.5 * self.viewport_size[1] - x = 2.0 * x / self.viewport_size[1] * self.tan_half_fov - y = 2.0 * y / self.viewport_size[1] * self.tan_half_fov - else: - # alternative way - projection_matrix = self.camera.camera.get_projection_matrix(*self.viewport_size) - x = x - 0.5 * self.viewport_size[0] - y = y - 0.5 * self.viewport_size[1] - x = 2.0 * x / self.viewport_size[0] / projection_matrix[0, 0] - y = 2.0 * y / self.viewport_size[1] / projection_matrix[1, 1] - - # Note: ignoring pixel aspect ratio - - mtx = self.camera.matrix - position = Vec3.from_array(mtx[:3, 3]) - forward = Vec3.from_array(-mtx[:3, 2]) - right = Vec3.from_array(mtx[:3, 0]) - up = Vec3.from_array(mtx[:3, 1]) - - direction = forward + right * x + up * y - return Ray(position, direction) - - def get_camera_forward(self) -> Vec3: - mtx = self.camera.matrix - return Vec3.from_array(-mtx[:3, 2]) - - def get_camera_ray(self) -> Ray: - mtx = self.camera.matrix - position = Vec3.from_array(mtx[:3, 3]) - forward = Vec3.from_array(-mtx[:3, 2]) - return Ray(position, forward) - - def _raycast_against_ground_plane(self, ray: Ray) -> RayHit: - ground_plane = Plane(Vec3.from_xyz(0, 0, 1), Vec3.zero()) - return ground_plane.raycast(ray) - - def raycast_against_entity_obb(self, entity: "RigidEntity", ray: Ray) -> RayHit: - if isinstance(entity.morph, gs.morphs.Box): - obb: OBB = self._get_box_obb(entity) - ray_hit = obb.raycast(ray) - if ray_hit.is_hit: - ray_hit.geom = entity.geoms[0] - return ray_hit - elif isinstance(entity.morph, gs.morphs.Plane): - # ignore plane - return RayHit.no_hit() - else: - closest_hit = RayHit.no_hit() - for link in entity.links: - if not link.is_fixed: - for geom in link.geoms: - obb: OBB = self._get_geom_placeholder_obb(geom) - ray_hit = obb.raycast(ray) - if ray_hit.distance < closest_hit.distance: - ray_hit.geom = geom - closest_hit = ray_hit - return closest_hit - - def raycast_against_entities(self, ray: Ray) -> RayHit: - closest_hit = RayHit.no_hit() - for entity in self.scene.sim.rigid_solver.entities: - rigid_entity: "RigidEntity" = cast("RigidEntity", entity) - ray_hit = self.raycast_against_entity_obb(rigid_entity, ray) - if ray_hit.distance < closest_hit.distance: - closest_hit = ray_hit - return closest_hit def _get_box_obb(self, box_entity: "RigidEntity") -> OBB: box: gs.morphs.Box = box_entity.morph diff --git a/genesis/ext/pyrender/interaction/plugins/viewer_controls.py b/genesis/ext/pyrender/interaction/plugins/viewer_controls.py new file mode 100644 index 0000000000..7857d122ce --- /dev/null +++ b/genesis/ext/pyrender/interaction/plugins/viewer_controls.py @@ -0,0 +1,219 @@ +import os +from typing import TYPE_CHECKING + +import numpy as np +import pyglet +from genesis.options.viewer_interactions import ViewerDefaultControls as ViewerDefaultControlsOptions +from typing_extensions import override + +import genesis as gs + +from ...constants import TEXT_PADDING +from ..base_interaction import EVENT_HANDLE_STATE, BaseViewerInteraction, register_viewer_plugin +from ..keybindings import Keybindings + +if TYPE_CHECKING: + from genesis.engine.scene import Scene + from genesis.ext.pyrender.node import Node + +@register_viewer_plugin(ViewerDefaultControlsOptions) +class ViewerDefaultControls(BaseViewerInteraction): + """ + Default keyboard controls for the Genesis viewer. + + This plugin handles the standard viewer keyboard shortcuts for recording, changing render modes, etc. + """ + + def __init__( + self, + viewer, + options=None, + camera: "Node" = None, + scene: "Scene" = None, + viewport_size: tuple[int, int] = None, + ): + super().__init__(viewer, options, camera, scene, viewport_size) + + self.keybindings: Keybindings = Keybindings( + toggle_keyboard_instructions=pyglet.window.key.I, + record_video=pyglet.window.key.R, + save_image=pyglet.window.key.S, + reset_camera=pyglet.window.key.Z, + camera_rotation=pyglet.window.key.A, + shadow=pyglet.window.key.H, + face_normals=pyglet.window.key.F, + vertex_normals=pyglet.window.key.V, + world_frame=pyglet.window.key.W, + link_frame=pyglet.window.key.L, + wireframe=pyglet.window.key.D, + camera_frustum=pyglet.window.key.C, + fullscreen_mode=pyglet.window.key.F11, + ) + if options and options.keybindings: + self.keybindings.apply_override_mapping(options.keybindings) + + self._display_instr = False + self._instr_texts = ( + ["> [i]: show keyboard instructions"], + ["< [i]: hide keyboard instructions"] + self.keybindings.as_instruction_texts( + padding=3, exclude=("toggle_keyboard_instructions")), + ) + + @override + def on_key_press(self, symbol: int, modifiers: int) -> EVENT_HANDLE_STATE: + if self.viewer is None: + return None + + # A causes the frame to rotate + self.viewer._message_text = None + if symbol == self.keybindings.camera_rotation: + self.viewer.viewer_flags["rotate"] = not self.viewer.viewer_flags["rotate"] + if self.viewer.viewer_flags["rotate"]: + self.viewer._message_text = "Rotation On" + else: + self.viewer._message_text = "Rotation Off" + + # F11 toggles fullscreen + elif symbol == self.keybindings.fullscreen_mode: + self.viewer.viewer_flags["fullscreen"] = not self.viewer.viewer_flags["fullscreen"] + self.viewer.set_fullscreen(self.viewer.viewer_flags["fullscreen"]) + self.viewer.activate() + if self.viewer.viewer_flags["fullscreen"]: + self.viewer._message_text = "Fullscreen On" + else: + self.viewer._message_text = "Fullscreen Off" + + # H toggles shadows + elif symbol == self.keybindings.shadow: + self.viewer.render_flags["shadows"] = not self.viewer.render_flags["shadows"] + if self.viewer.render_flags["shadows"]: + self.viewer._message_text = "Shadows On" + else: + self.viewer._message_text = "Shadows Off" + + # W toggles world frame + elif symbol == self.keybindings.world_frame: + if not self.viewer.gs_context.world_frame_shown: + self.viewer.gs_context.on_world_frame() + self.viewer._message_text = "World Frame On" + else: + self.viewer.gs_context.off_world_frame() + self.viewer._message_text = "World Frame Off" + + # L toggles link frame + elif symbol == self.keybindings.link_frame: + if not self.viewer.gs_context.link_frame_shown: + self.viewer.gs_context.on_link_frame() + self.viewer._message_text = "Link Frame On" + else: + self.viewer.gs_context.off_link_frame() + self.viewer._message_text = "Link Frame Off" + + # C toggles camera frustum + elif symbol == self.keybindings.camera_frustum: + if not self.viewer.gs_context.camera_frustum_shown: + self.viewer.gs_context.on_camera_frustum() + self.viewer._message_text = "Camera Frustrum On" + else: + self.viewer.gs_context.off_camera_frustum() + self.viewer._message_text = "Camera Frustrum Off" + + # F toggles face normals + elif symbol == self.keybindings.face_normals: + self.viewer.render_flags["face_normals"] = not self.viewer.render_flags["face_normals"] + if self.viewer.render_flags["face_normals"]: + self.viewer._message_text = "Face Normals On" + else: + self.viewer._message_text = "Face Normals Off" + + # V toggles vertex normals + elif symbol == self.keybindings.vertex_normals: + self.viewer.render_flags["vertex_normals"] = not self.viewer.render_flags["vertex_normals"] + if self.viewer.render_flags["vertex_normals"]: + self.viewer._message_text = "Vert Normals On" + else: + self.viewer._message_text = "Vert Normals Off" + + # R starts recording frames + elif symbol == self.keybindings.record_video: + if self.viewer.viewer_flags["record"]: + self.viewer.save_video() + self.viewer.set_caption(self.viewer.viewer_flags["window_title"]) + else: + # Importing moviepy is very slow and not used very often. Let's delay import. + from moviepy.video.io.ffmpeg_writer import FFMPEG_VideoWriter + + self.viewer._video_recorder = FFMPEG_VideoWriter( + filename=os.path.join(gs.utils.misc.get_cache_dir(), "tmp_video.mp4"), + fps=self.viewer.viewer_flags["refresh_rate"], + size=self.viewer.viewport_size, + ) + self.viewer.set_caption("{} (RECORDING)".format(self.viewer.viewer_flags["window_title"])) + self.viewer.viewer_flags["record"] = not self.viewer.viewer_flags["record"] + + # S saves the current frame as an image + elif symbol == self.keybindings.save_image: + self.viewer._save_image() + + # D toggles through wireframe modes + elif symbol == self.keybindings.wireframe: + if self.viewer.render_flags["flip_wireframe"]: + self.viewer.render_flags["flip_wireframe"] = False + self.viewer.render_flags["all_wireframe"] = True + self.viewer.render_flags["all_solid"] = False + self.viewer._message_text = "All Wireframe" + elif self.viewer.render_flags["all_wireframe"]: + self.viewer.render_flags["flip_wireframe"] = False + self.viewer.render_flags["all_wireframe"] = False + self.viewer.render_flags["all_solid"] = True + self.viewer._message_text = "All Solid" + elif self.viewer.render_flags["all_solid"]: + self.viewer.render_flags["flip_wireframe"] = False + self.viewer.render_flags["all_wireframe"] = False + self.viewer.render_flags["all_solid"] = False + self.viewer._message_text = "Default Wireframe" + else: + self.viewer.render_flags["flip_wireframe"] = True + self.viewer.render_flags["all_wireframe"] = False + self.viewer.render_flags["all_solid"] = False + self.viewer._message_text = "Flip Wireframe" + + # Z resets the camera viewpoint + elif symbol == self.keybindings.reset_camera: + self.viewer._reset_view() + + # I toggles instruction display + elif symbol == self.keybindings.toggle_keyboard_instructions: + self._display_instr = not self._display_instr + + # P reloads shader program + elif symbol == self.keybindings.reload_shader: + self.viewer._renderer.reload_program() + + if self.viewer._message_text is not None: + self.viewer._message_opac = 1.0 + self.viewer._ticks_till_fade + + return None + + @override + def on_draw(self): + """Render keyboard instructions.""" + if self.viewer is None: + return + + if self._display_instr: + self.viewer._renderer.render_texts( + self._instr_texts[1], + TEXT_PADDING, + self.viewer.viewport_size[1] - TEXT_PADDING, + font_pt=26, + color=np.array([1.0, 1.0, 1.0, 0.85]), + ) + else: + self.viewer._renderer.render_texts( + self._instr_texts[0], + TEXT_PADDING, + self.viewer.viewport_size[1] - TEXT_PADDING, + font_pt=26, + color=np.array([1.0, 1.0, 1.0, 0.85]), + ) diff --git a/genesis/ext/pyrender/interaction/utils/__init__.py b/genesis/ext/pyrender/interaction/utils/__init__.py new file mode 100644 index 0000000000..7182f96982 --- /dev/null +++ b/genesis/ext/pyrender/interaction/utils/__init__.py @@ -0,0 +1,4 @@ +from .aabb import AABB, OBB +from .ray import Plane, Ray, RayHit +from .raycaster import ViewerRaycaster +from .vec3 import Color, Pose, Quat, Vec3 diff --git a/genesis/ext/pyrender/interaction/aabb.py b/genesis/ext/pyrender/interaction/utils/aabb.py similarity index 100% rename from genesis/ext/pyrender/interaction/aabb.py rename to genesis/ext/pyrender/interaction/utils/aabb.py diff --git a/genesis/ext/pyrender/interaction/ray.py b/genesis/ext/pyrender/interaction/utils/ray.py similarity index 100% rename from genesis/ext/pyrender/interaction/ray.py rename to genesis/ext/pyrender/interaction/utils/ray.py diff --git a/genesis/ext/pyrender/interaction/utils/raycaster.py b/genesis/ext/pyrender/interaction/utils/raycaster.py new file mode 100644 index 0000000000..6c8d31bb1f --- /dev/null +++ b/genesis/ext/pyrender/interaction/utils/raycaster.py @@ -0,0 +1,282 @@ +from typing import TYPE_CHECKING + +import gstaichi as ti +import numpy as np + +import genesis as gs +from genesis.engine.bvh import AABB, LBVH, STACK_SIZE +from genesis.utils.raycast import kernel_update_aabbs, ray_aabb_intersection, ray_triangle_intersection + +from .ray import RayHit +from .vec3 import Vec3 + +if TYPE_CHECKING: + from genesis.engine.scene import Scene + + +# Constant to indicate no hit occurred +NO_HIT_DISTANCE = -1.0 + + + +@ti.kernel +def kernel_cast_single_ray_for_viewer( + fixed_verts_state: ti.template(), + free_verts_state: ti.template(), + verts_info: ti.template(), + faces_info: ti.template(), + bvh_nodes: ti.template(), + bvh_morton_codes: ti.template(), + ray_start: ti.types.ndarray(ndim=1), # [3] + ray_direction: ti.types.ndarray(ndim=1), # [3] + max_range: ti.f32, + envs_idx: ti.types.ndarray(ndim=1), # [n_envs] + result: ti.types.ndarray(ndim=1), # [9]: [distance, geom_idx, hit_x, hit_y, hit_z, normal_x, normal_y, normal_z, env_idx] +): + """ + Taichi kernel for casting a single ray for viewer interaction. + + This loops over all environments in envs_idx and returns the closest hit. + + Returns: + result[0]: distance to hit point (NO_HIT_DISTANCE if no hit) + result[1]: geom_idx of hit geometry + result[2]: hit_point x coordinate + result[3]: hit_point y coordinate + result[4]: hit_point z coordinate + result[5]: normal x coordinate + result[6]: normal y coordinate + result[7]: normal z coordinate + result[8]: env_idx of hit environment + """ + n_triangles = faces_info.verts_idx.shape[0] + + # Setup ray + ray_start_world = ti.math.vec3(ray_start[0], ray_start[1], ray_start[2]) + ray_direction_world = ti.math.vec3(ray_direction[0], ray_direction[1], ray_direction[2]) + + # Initialize result with no hit + result[0] = -1.0 # NO_HIT_DISTANCE + result[1] = -1.0 # no geom + result[2] = 0.0 # hit_point x + result[3] = 0.0 # hit_point y + result[4] = 0.0 # hit_point z + result[5] = 0.0 # normal x + result[6] = 0.0 # normal y + result[7] = 0.0 # normal z + result[8] = -1.0 # no env + + global_closest_distance = max_range + global_hit_face = -1 + global_hit_env_idx = -1 + global_hit_normal = ti.math.vec3(0.0, 0.0, 0.0) + + # Loop over all environments in envs_idx + for i_b in range(envs_idx.shape[0]): + rendered_env_idx = ti.cast(envs_idx[i_b], ti.i32) + + hit_face = -1 + closest_distance = global_closest_distance + hit_normal = ti.math.vec3(0.0, 0.0, 0.0) + + # Stack for non-recursive BVH traversal + node_stack = ti.Vector.zero(ti.i32, STACK_SIZE) + node_stack[0] = 0 # Start at root node + stack_idx = 1 + + while stack_idx > 0: + stack_idx -= 1 + node_idx = node_stack[stack_idx] + + node = bvh_nodes[i_b, node_idx] + + # Check if ray hits the node's bounding box + aabb_t = ray_aabb_intersection(ray_start_world, ray_direction_world, node.bound.min, node.bound.max) + + if aabb_t >= 0.0 and aabb_t < closest_distance: + if node.left == -1: # Leaf node + # Get original triangle/face index + sorted_leaf_idx = node_idx - (n_triangles - 1) + i_f = ti.cast(bvh_morton_codes[0, sorted_leaf_idx][1], ti.i32) + + # Get triangle vertices + tri_vertices = ti.Matrix.zero(gs.ti_float, 3, 3) + for i in ti.static(range(3)): + i_v = faces_info.verts_idx[i_f][i] + i_fv = verts_info.verts_state_idx[i_v] + if verts_info.is_fixed[i_v]: + tri_vertices[:, i] = fixed_verts_state.pos[i_fv] + else: + tri_vertices[:, i] = free_verts_state.pos[i_fv, rendered_env_idx] + v0, v1, v2 = tri_vertices[:, 0], tri_vertices[:, 1], tri_vertices[:, 2] + + # Perform ray-triangle intersection + hit_result = ray_triangle_intersection(ray_start_world, ray_direction_world, v0, v1, v2) + + if hit_result.w > 0.0 and hit_result.x < closest_distance and hit_result.x >= 0.0: + closest_distance = hit_result.x + hit_face = i_f + # Compute triangle normal + edge1 = v1 - v0 + edge2 = v2 - v0 + hit_normal = edge1.cross(edge2).normalized() + else: # Internal node + # Push children onto stack + if stack_idx < ti.static(STACK_SIZE - 2): + node_stack[stack_idx] = node.left + node_stack[stack_idx + 1] = node.right + stack_idx += 2 + + # Update global closest if this environment had a closer hit + if hit_face >= 0 and closest_distance < global_closest_distance: + global_closest_distance = closest_distance + global_hit_face = hit_face + global_hit_env_idx = rendered_env_idx + global_hit_normal = hit_normal + + # Store result + if global_hit_face >= 0: + result[0] = global_closest_distance # distance (positive value indicates hit) + # Find which geom this face belongs to + i_g = faces_info.geom_idx[global_hit_face] + result[1] = gs.ti_float(i_g) + # Compute hit point + hit_point = ray_start_world + global_closest_distance * ray_direction_world + result[2] = hit_point.x + result[3] = hit_point.y + result[4] = hit_point.z + # Store normal + result[5] = global_hit_normal.x + result[6] = global_hit_normal.y + result[7] = global_hit_normal.z + result[8] = gs.ti_float(global_hit_env_idx) + + + +class ViewerRaycaster: + """ + BVH-accelerated raycaster for viewer interaction plugins. + + This class manages a BVH structure built from the scene's rigid geometry + and provides efficient single-ray casting for interactive applications. + Only considers environments specified in rendered_envs_idx. + """ + + def __init__(self, scene: "Scene"): + """ + Initialize the ViewerRaycaster. + + Parameters + ---------- + scene : Scene + The scene to build the raycaster for. + """ + self.scene = scene + self.solver = scene.sim.rigid_solver + + # Store rendered_envs_idx as numpy array for Taichi kernel + + # self.rendered_envs_idx = np.asarray(scene.vis_options.rendered_envs_idx or [0], dtype=gs.np_int) + self.rendered_envs_idx = np.asarray([0], dtype=gs.np_int) + + # Build the BVH structure for rendered environments. + n_faces = self.solver.faces_info.geom_idx.shape[0] + + if n_faces == 0: + gs.logger.warning("No faces found in scene, viewer raycasting will not work.") + self.aabb = None + self.bvh = None + return + + self.aabb = AABB(n_batches=len(self.rendered_envs_idx), n_aabbs=n_faces) + self.bvh = LBVH( + self.aabb, + max_n_query_result_per_aabb=0, # Not used for ray queries + n_radix_sort_groups=min(64, n_faces), + ) + + self.update_bvh() + + def update_bvh(self): + """Update the BVH structure with current geometry state.""" + if self.bvh is None: + return + + # Update vertex positions + from genesis.engine.solvers.rigid.rigid_solver_decomp import kernel_update_all_verts + + kernel_update_all_verts( + geoms_info=self.solver.geoms_info, + geoms_state=self.solver.geoms_state, + verts_info=self.solver.verts_info, + free_verts_state=self.solver.free_verts_state, + fixed_verts_state=self.solver.fixed_verts_state, + static_rigid_sim_config=self.solver._static_rigid_sim_config, + ) + + # Update AABBs for each rendered environment + kernel_update_aabbs( + free_verts_state=self.solver.free_verts_state, + fixed_verts_state=self.solver.fixed_verts_state, + verts_info=self.solver.verts_info, + faces_info=self.solver.faces_info, + aabb_state=self.aabb, + ) + + # Rebuild BVH + self.bvh.build() + + def cast_ray( + self, + ray_origin: np.ndarray, + ray_direction: np.ndarray, + max_range: float = 1000.0, + ) -> RayHit: + """ + Cast a single ray against all rendered environments and return the closest hit. + + Parameters + ---------- + ray_origin : np.ndarray, shape (3,) + The origin point of the ray in world coordinates. + ray_direction : np.ndarray, shape (3,) + The direction vector of the ray (will be normalized). + max_range : float, optional + Maximum distance to check for intersections. Default is 1000.0. + + Returns + ------- + RayHit + A RayHit object containing distance, position, normal, and geom. + If no hit, returns RayHit.no_hit(). + """ + ray_direction = ray_direction / (np.linalg.norm(ray_direction) + gs.EPS) + + ray_start_np = np.asarray(ray_origin, dtype=gs.np_float) + ray_dir_np = np.asarray(ray_direction, dtype=gs.np_float) + result_np = np.zeros(9, dtype=gs.np_float) + + kernel_cast_single_ray_for_viewer( + fixed_verts_state=self.solver.fixed_verts_state, + free_verts_state=self.solver.free_verts_state, + verts_info=self.solver.verts_info, + faces_info=self.solver.faces_info, + bvh_nodes=self.bvh.nodes, + bvh_morton_codes=self.bvh.morton_codes, + ray_start=ray_start_np, + ray_direction=ray_dir_np, + max_range=max_range, + envs_idx=self.rendered_envs_idx, + result=result_np, + ) + + distance = float(result_np[0]) + if distance < NO_HIT_DISTANCE + gs.EPS: # NO_HIT_DISTANCE + return RayHit.no_hit() + + geom_idx = int(result_np[1]) + position = Vec3(result_np[2:5]) + normal = Vec3(result_np[5:8]) + geom = self.solver.geoms[geom_idx] + + return RayHit(distance=distance, position=position, normal=normal, geom=geom) diff --git a/genesis/ext/pyrender/interaction/vec3.py b/genesis/ext/pyrender/interaction/utils/vec3.py similarity index 100% rename from genesis/ext/pyrender/interaction/vec3.py rename to genesis/ext/pyrender/interaction/utils/vec3.py diff --git a/genesis/ext/pyrender/interaction/viewer_interaction_base.py b/genesis/ext/pyrender/interaction/viewer_interaction_base.py deleted file mode 100644 index 9be9f545f2..0000000000 --- a/genesis/ext/pyrender/interaction/viewer_interaction_base.py +++ /dev/null @@ -1,52 +0,0 @@ -from typing import Union, Literal - -import genesis as gs - - -EVENT_HANDLE_STATE = Union[Literal[True], None] -EVENT_HANDLED: Literal[True] = True - -# Note: Viewer window is based on pyglet.window.Window, mouse events are defined in pyglet.window.BaseWindow - - -class ViewerInteractionBase: - """Base class for handling pyglet.window.Window events.""" - - log_events: bool - - def __init__(self, log_events: bool = False): - self.log_events = log_events - - def on_mouse_motion(self, x: int, y: int, dx: int, dy: int) -> EVENT_HANDLE_STATE: - if self.log_events: - gs.logger.info(f"Mouse moved to {x}, {y}") - - def on_mouse_drag(self, x: int, y: int, dx: int, dy: int, buttons: int, modifiers: int) -> EVENT_HANDLE_STATE: - if self.log_events: - gs.logger.info(f"Mouse dragged to {x}, {y}") - - def on_mouse_press(self, x: int, y: int, button: int, modifiers: int) -> EVENT_HANDLE_STATE: - if self.log_events: - gs.logger.info(f"Mouse buttons {button} pressed at {x}, {y}") - - def on_mouse_release(self, x: int, y: int, button: int, modifiers: int) -> EVENT_HANDLE_STATE: - if self.log_events: - gs.logger.info(f"Mouse buttons {button} released at {x}, {y}") - - def on_key_press(self, symbol: int, modifiers: int) -> EVENT_HANDLE_STATE: - if self.log_events: - gs.logger.info(f"Key pressed: {chr(symbol)}") - - def on_key_release(self, symbol: int, modifiers: int) -> EVENT_HANDLE_STATE: - if self.log_events: - gs.logger.info(f"Key released: {chr(symbol)}") - - def on_resize(self, width: int, height: int) -> EVENT_HANDLE_STATE: - if self.log_events: - gs.logger.info(f"Window resized to {width}x{height}") - - def update_on_sim_step(self) -> None: - pass - - def on_draw(self) -> None: - pass diff --git a/genesis/ext/pyrender/viewer.py b/genesis/ext/pyrender/viewer.py index 7c82a94f18..65d3452734 100644 --- a/genesis/ext/pyrender/viewer.py +++ b/genesis/ext/pyrender/viewer.py @@ -1,14 +1,14 @@ """A pyglet-based interactive 3D scene viewer.""" import copy -from contextlib import nullcontext import os import shutil import sys -import time import threading +import time +from contextlib import nullcontext from threading import Event, RLock, Semaphore, Thread -from typing import Optional, TYPE_CHECKING +from typing import TYPE_CHECKING, Optional import numpy as np import OpenGL @@ -44,8 +44,7 @@ RenderFlags, TextAlign, ) -from .interaction.viewer_interaction import ViewerInteraction -from .interaction.viewer_interaction_base import ViewerInteractionBase, EVENT_HANDLE_STATE, EVENT_HANDLED +from .interaction import EVENT_HANDLE_STATE, EVENT_HANDLED, VIEWER_PLUGIN_MAP from .light import DirectionalLight from .node import Node from .renderer import Renderer @@ -205,8 +204,7 @@ def __init__( shadow=False, plane_reflection=False, env_separate_rigid=False, - enable_interaction=False, - disable_keyboard_shortcuts=False, + plugin_options=gs.options.viewer_interactions.ViewerDefaultControls(), **kwargs, ): ####################################################################### @@ -286,8 +284,6 @@ def __init__( if registered_keys is not None: self._registered_keys = {ord(k.lower()): registered_keys[k] for k in registered_keys} - self._disable_keyboard_shortcuts = disable_keyboard_shortcuts - ####################################################################### # Save internal settings ####################################################################### @@ -297,27 +293,6 @@ def __init__( self._ticks_till_fade = 2.0 / 3.0 * self.viewer_flags["refresh_rate"] self._message_opac = 1.0 + self._ticks_till_fade - self._display_instr = False - - self._instr_texts = [ - ["> [i]: show keyboard instructions"], - [ - "< [i]: hide keyboard instructions", - " [r]: record video", - " [s]: save image", - " [z]: reset camera", - " [a]: camera rotation", - " [h]: shadow", - " [f]: face normal", - " [v]: vertex normal", - " [w]: world frame", - " [l]: link frame", - " [d]: wireframe", - " [c]: camera & frustrum", - " [F11]: full-screen mode", - ], - ] - # Set up raymond lights and direct lights self._raymond_lights = self._create_raymond_lights() self._direct_light = self._create_direct_light() @@ -379,14 +354,18 @@ def __init__( self.scene.main_camera_node = self._camera_node self._reset_view() - # Setup mouse interaction - # Note: context.scene is genesis.engine.scene.Scene # Note: context._scene is genesis.ext.pyrender.scene.Scene - self.viewer_interaction = ( - ViewerInteraction(self._camera_node, context.scene, viewport_size, camera.yfov) - if enable_interaction - else ViewerInteractionBase() + + # Setup viewer plugin + plugin_cls = VIEWER_PLUGIN_MAP.get(type(plugin_options)) + if plugin_cls is None: + gs.raise_exception( + f"Viewer plugin type {type(plugin_options).__name__} is not registered. " + f"Available plugins: {list(VIEWER_PLUGIN_MAP.keys())}" + ) + self.interaction_plugin = plugin_cls( + self, plugin_options, self._camera_node, context.scene, viewport_size ) ####################################################################### @@ -590,6 +569,8 @@ def on_close(self): # Do not consider the viewer as active anymore self._is_active = False + self.interaction_plugin.on_close() + # Remove our camera and restore the prior one try: if self._camera_node is not None: @@ -735,47 +716,21 @@ def on_draw(self): self.viewer_interaction.on_draw() - if not self._disable_keyboard_shortcuts: - if self._display_instr: - self._renderer.render_texts( - self._instr_texts[1], - TEXT_PADDING, - self.viewport_size[1] - TEXT_PADDING, - font_pt=26, - color=np.array([1.0, 1.0, 1.0, 0.85]), - ) - else: - self._renderer.render_texts( - self._instr_texts[0], - TEXT_PADDING, - self.viewport_size[1] - TEXT_PADDING, - font_pt=26, - color=np.array([1.0, 1.0, 1.0, 0.85]), - ) - - if self._message_text is not None: - self._renderer.render_text( - self._message_text, - self.viewport_size[0] - TEXT_PADDING, - TEXT_PADDING, - font_pt=20, - color=np.array([0.1, 0.7, 0.2, np.clip(self._message_opac, 0.0, 1.0)]), - align=TextAlign.BOTTOM_RIGHT, - ) - - if self.viewer_flags["caption"] is not None: - for caption in self.viewer_flags["caption"]: - xpos, ypos = self._location_to_x_y(caption["location"]) - self._renderer.render_text( - caption["text"], - xpos, - ypos, - font_name=caption["font_name"], - font_pt=caption["font_pt"], - color=caption["color"], - scale=caption["scale"], - align=caption["location"], - ) + self.interaction_plugin.on_draw() + + if self.viewer_flags["caption"] is not None: + for caption in self.viewer_flags["caption"]: + xpos, ypos = self._location_to_x_y(caption["location"]) + self._renderer.render_text( + caption["text"], + xpos, + ypos, + font_name=caption["font_name"], + font_pt=caption["font_pt"], + color=caption["color"], + scale=caption["scale"], + align=caption["location"], + ) def on_resize(self, width: int, height: int) -> EVENT_HANDLE_STATE: """Resize the camera and trackball when the window is resized.""" @@ -789,12 +744,12 @@ def on_resize(self, width: int, height: int) -> EVENT_HANDLE_STATE: self._trackball.resize(self._viewport_size) self._renderer.viewport_width = self._viewport_size[0] self._renderer.viewport_height = self._viewport_size[1] - self.viewer_interaction.on_resize(width, height) + self.interaction_plugin.on_resize(width, height) self.on_draw() def on_mouse_motion(self, x: int, y: int, dx: int, dy: int) -> EVENT_HANDLE_STATE: """The mouse was moved with no buttons held down.""" - return self.viewer_interaction.on_mouse_motion(x, y, dx, dy) + return self.interaction_plugin.on_mouse_motion(x, y, dx, dy) def on_mouse_press(self, x: int, y: int, button: int, modifiers: int) -> EVENT_HANDLE_STATE: """Record an initial mouse press.""" @@ -816,11 +771,11 @@ def on_mouse_press(self, x: int, y: int, button: int, modifiers: int) -> EVENT_H # Stop animating while using the mouse self.viewer_flags["mouse_pressed"] = True - return self.viewer_interaction.on_mouse_press(x, y, button, modifiers) + return self.interaction_plugin.on_mouse_press(x, y, button, modifiers) def on_mouse_drag(self, x: int, y: int, dx: int, dy: int, buttons: int, modifiers: int) -> EVENT_HANDLE_STATE: """The mouse was moved with one or more buttons held down.""" - result = self.viewer_interaction.on_mouse_drag(x, y, dx, dy, buttons, modifiers) + result = self.interaction_plugin.on_mouse_drag(x, y, dx, dy, buttons, modifiers) if result is not EVENT_HANDLED: result = self._trackball.drag(np.array([x, y])) return result @@ -828,7 +783,7 @@ def on_mouse_drag(self, x: int, y: int, dx: int, dy: int, buttons: int, modifier def on_mouse_release(self, x: int, y: int, button: int, modifiers: int) -> EVENT_HANDLE_STATE: """Record a mouse release.""" self.viewer_flags["mouse_pressed"] = False - return self.viewer_interaction.on_mouse_release(x, y, button, modifiers) + return self.interaction_plugin.on_mouse_release(x, y, button, modifiers) def on_mouse_scroll(self, x, y, dx, dy): """Record a mouse scroll.""" @@ -866,159 +821,14 @@ def on_key_press(self, symbol: int, modifiers: int) -> EVENT_HANDLE_STATE: if len(tup) == 3: kwargs = tup[2] callback(self, *args, **kwargs) - return self.viewer_interaction.on_key_press(symbol, modifiers) - - # If keyboard shortcuts are disabled, skip default key functions - if self._disable_keyboard_shortcuts: - return self.viewer_interaction.on_key_press(symbol, modifiers) - - # Otherwise, use default key functions - - # A causes the frame to rotate - self._message_text = None - if symbol == pyglet.window.key.A: - self.viewer_flags["rotate"] = not self.viewer_flags["rotate"] - if self.viewer_flags["rotate"]: - self._message_text = "Rotation On" - else: - self._message_text = "Rotation Off" - - # F11 toggles face normals - elif symbol == pyglet.window.key.F11: - self.viewer_flags["fullscreen"] = not self.viewer_flags["fullscreen"] - self.set_fullscreen(self.viewer_flags["fullscreen"]) - self.activate() - if self.viewer_flags["fullscreen"]: - self._message_text = "Fullscreen On" - else: - self._message_text = "Fullscreen Off" - - # H toggles shadows - elif symbol == pyglet.window.key.H: - self.render_flags["shadows"] = not self.render_flags["shadows"] - if self.render_flags["shadows"]: - self._message_text = "Shadows On" - else: - self._message_text = "Shadows Off" - - # W toggles world frame - elif symbol == pyglet.window.key.W: - if not self.gs_context.world_frame_shown: - self.gs_context.on_world_frame() - self._message_text = "World Frame On" - else: - self.gs_context.off_world_frame() - self._message_text = "World Frame Off" - - # L toggles link frame - elif symbol == pyglet.window.key.L: - if not self.gs_context.link_frame_shown: - self.gs_context.on_link_frame() - self._message_text = "Link Frame On" - else: - self.gs_context.off_link_frame() - self._message_text = "Link Frame Off" - - # C toggles camera frustum - elif symbol == pyglet.window.key.C: - if not self.gs_context.camera_frustum_shown: - self.gs_context.on_camera_frustum() - self._message_text = "Camera Frustrum On" - else: - self.gs_context.off_camera_frustum() - self._message_text = "Camera Frustrum Off" - - # F toggles face normals - elif symbol == pyglet.window.key.F: - self.render_flags["face_normals"] = not self.render_flags["face_normals"] - if self.render_flags["face_normals"]: - self._message_text = "Face Normals On" - else: - self._message_text = "Face Normals Off" - - # V toggles vertex normals - elif symbol == pyglet.window.key.V: - self.render_flags["vertex_normals"] = not self.render_flags["vertex_normals"] - if self.render_flags["vertex_normals"]: - self._message_text = "Vert Normals On" - else: - self._message_text = "Vert Normals Off" - - # R starts recording frames - elif symbol == pyglet.window.key.R: - if self.viewer_flags["record"]: - self.save_video() - self.set_caption(self.viewer_flags["window_title"]) - else: - # Importing moviepy is very slow and not used very often. Let's delay import. - from moviepy.video.io.ffmpeg_writer import FFMPEG_VideoWriter - - self._video_recorder = FFMPEG_VideoWriter( - filename=os.path.join(gs.utils.misc.get_cache_dir(), "tmp_video.mp4"), - fps=self.viewer_flags["refresh_rate"], - size=self.viewport_size, - ) - self.set_caption("{} (RECORDING)".format(self.viewer_flags["window_title"])) - self.viewer_flags["record"] = not self.viewer_flags["record"] - - # S saves the current frame as an image - elif symbol == pyglet.window.key.S: - self._save_image() - - # T toggles through geom types - # elif symbol == pyglet.window.key.T: - # if self.gs_context.rigid_shown == 'visual': - # self.gs_context.on_rigid('collision') - # self._message_text = "Geom Type: 'collision'" - # elif self.gs_context.rigid_shown == 'collision': - # self.gs_context.on_rigid('sdf') - # self._message_text = "Geom Type: 'sdf'" - # else: - # self.gs_context.on_rigid('visual') - # self._message_text = "Geom Type: 'visual'" - - # D toggles through wireframe modes - elif symbol == pyglet.window.key.D: - if self.render_flags["flip_wireframe"]: - self.render_flags["flip_wireframe"] = False - self.render_flags["all_wireframe"] = True - self.render_flags["all_solid"] = False - self._message_text = "All Wireframe" - elif self.render_flags["all_wireframe"]: - self.render_flags["flip_wireframe"] = False - self.render_flags["all_wireframe"] = False - self.render_flags["all_solid"] = True - self._message_text = "All Solid" - elif self.render_flags["all_solid"]: - self.render_flags["flip_wireframe"] = False - self.render_flags["all_wireframe"] = False - self.render_flags["all_solid"] = False - self._message_text = "Default Wireframe" - else: - self.render_flags["flip_wireframe"] = True - self.render_flags["all_wireframe"] = False - self.render_flags["all_solid"] = False - self._message_text = "Flip Wireframe" - - # Z resets the camera viewpoint - elif symbol == pyglet.window.key.Z: - self._reset_view() - - # i toggles instruction display - elif symbol == pyglet.window.key.I: - self._display_instr = not self._display_instr - - elif symbol == pyglet.window.key.P: - self._renderer.reload_program() - - if self._message_text is not None: - self._message_opac = 1.0 + self._ticks_till_fade - - return self.viewer_interaction.on_key_press(symbol, modifiers) + # Continue to plugins after registered callback + + # Delegate to viewer plugin + return self.interaction_plugin.on_key_press(symbol, modifiers) def on_key_release(self, symbol: int, modifiers: int) -> EVENT_HANDLE_STATE: """Record a key release.""" - return self.viewer_interaction.on_key_release(symbol, modifiers) + return self.interaction_plugin.on_key_release(symbol, modifiers) @staticmethod def _time_event(dt, self): @@ -1089,8 +899,7 @@ def _get_save_filename(self, file_exts): try: # Importing tkinter is very slow and not used very often. Let's delay import. - from tkinter import Tk - from tkinter import filedialog + from tkinter import Tk, filedialog if root is None: root = Tk() @@ -1234,8 +1043,8 @@ def start(self, auto_refresh=True): import pyglet # For some reason, this is necessary if 'pyglet.window.xlib' fails to import... try: - import pyglet.window.xlib import pyglet.display.xlib + import pyglet.window.xlib xlib_exceptions = (pyglet.window.xlib.XlibException, pyglet.display.xlib.NoSuchDisplayException) except ImportError: @@ -1404,7 +1213,7 @@ def refresh(self): self.flip() def update_on_sim_step(self): - self.viewer_interaction.update_on_sim_step() + self.interaction_plugin.update_on_sim_step() def _compute_initial_camera_pose(self): centroid = self.scene.centroid diff --git a/genesis/options/recorders.py b/genesis/options/recorders.py index ddeda88549..79a099af88 100644 --- a/genesis/options/recorders.py +++ b/genesis/options/recorders.py @@ -6,7 +6,6 @@ from .options import Options - IS_PYAV_AVAILABLE = False try: import av @@ -297,3 +296,43 @@ class MPLImagePlot(BasePlotterOptions): """ pass + + +class MPLDepthScatterPlot(BasePlotterOptions): + """ + Live visualization of depth sensor data as a 3D scatter plot using matplotlib. + + The data should be a tuple of (positions, distances) where: + - positions: array-like with shape (N, 3) for (x, y, z) coordinates + - distances: array-like with shape (N,) for depth/distance values + + Parameters + ---------- + title: str + The title of the plot. + window_size: tuple[int, int] + The size of the window in pixels. + save_to_filename: str | None + If provided, the animation will be saved to a file with the given filename. + show_window: bool | None + Whether to show the window. If not provided, it will be set to True if a display is connected, False otherwise. + cmap: str + The colormap to use for depth visualization. Defaults to 'viridis'. + point_size: float + The size of each point in the scatter plot. Defaults to 50. + x_label: str + Label for the x-axis. Defaults to 'X'. + y_label: str + Label for the y-axis. Defaults to 'Y'. + z_label: str + Label for the z-axis. Defaults to 'Z'. + colorbar_label: str + Label for the colorbar. Defaults to 'Distance'. + """ + + cmap: str = "viridis" + point_size: float = 50.0 + x_label: str = "X" + y_label: str = "Y" + z_label: str = "Z" + colorbar_label: str = "Distance" diff --git a/genesis/options/sensors/raycaster.py b/genesis/options/sensors/raycaster.py index b21d41eaf1..0a42e07d99 100644 --- a/genesis/options/sensors/raycaster.py +++ b/genesis/options/sensors/raycaster.py @@ -55,6 +55,21 @@ def ray_starts(self) -> torch.Tensor: # ============================== Generic Patterns ============================== +def _sanitize_rays_to_tensor(rays: Sequence[float]) -> torch.Tensor: + tensor = torch.tensor(rays, dtype=gs.tc_float, device=gs.device) + if tensor.ndim < 2 or tensor.shape[-1] != 3: + gs.raise_exception(f"Rays should have shape (..., 3). Got: {tensor.shape}") + return tensor + + +class RaycastCustomPattern(RaycastPattern): + + def __init__(self, ray_dirs: Sequence[float], ray_starts: Sequence[float]): + self._ray_dirs = _sanitize_rays_to_tensor(ray_dirs) + self._ray_starts = _sanitize_rays_to_tensor(ray_starts) + self._return_shape: tuple[int, ...] = ray_dirs.shape[:-1] + + class GridPattern(RaycastPattern): """ Configuration for grid-based ray casting. diff --git a/genesis/options/viewer_interactions.py b/genesis/options/viewer_interactions.py new file mode 100644 index 0000000000..7e72a0e6b2 --- /dev/null +++ b/genesis/options/viewer_interactions.py @@ -0,0 +1,56 @@ +from .options import Options + + +class ViewerInteraction(Options): + """ + Base class for viewer interaction options. + + All viewer interaction option classes should inherit from this base class. + """ + + pass + + +class ViewerDefaultControls(ViewerInteraction): + """ + Default viewer interaction controls with keyboard shortcuts for recording, changing render modes, etc. + + Parameters + ---------- + keybindings : dict[str, int] + Override the default mapping of action names to keyboard key codes (pyglet.window.key.*). + """ + + keybindings: dict[str, int] = None + + +class MouseSpringViewerPlugin(ViewerDefaultControls): + """ + Options for the interactive viewer plugin that allows mouse-based object manipulation. + """ + + pass + + +class MeshPointSelectorPlugin(ViewerDefaultControls): + """ + Options for the mesh point selector plugin that allows selecting points on a mesh. + + Parameters + ---------- + sphere_radius : float + The radius of the sphere used to visualize selected points. + sphere_color : tuple + The color of the sphere used to visualize selected points. + hover_color : tuple + The color of the sphere used to visualize the point and normal when hovering over a mesh. + grid_snap : tuple[float, float, float] + Grid snap spacing for each axis (x, y, z). Any negative value disables snapping for that axis. + Default is (-1.0, -1.0, -1.0) which means no snapping. + """ + + sphere_radius: float = 0.005 + sphere_color: tuple = (0.1, 0.3, 1.0, 1.0) + hover_color: tuple = (0.3, 0.5, 1.0, 1.0) + grid_snap: tuple[float, float, float] = (-1.0, -1.0, -1.0) + output_file: str = "selected_points.csv" diff --git a/genesis/options/vis.py b/genesis/options/vis.py index 939df400c0..46934e8672 100644 --- a/genesis/options/vis.py +++ b/genesis/options/vis.py @@ -3,6 +3,7 @@ import genesis as gs from .options import Options +from .viewer_interactions import ViewerDefaultControls, ViewerInteraction class ViewerOptions(Options): @@ -33,20 +34,19 @@ class ViewerOptions(Options): The up vector of the camera's extrinsic pose. camera_fov : float The field of view (in degrees) of the camera. - disable_keyboard_shortcuts : bool - Whether to disable all keyboard shortcuts in the viewer. Defaults to False. + viewer_plugin : ViewerPluginOptions + Viewer plugin that adds interactive functionality to the viewer. """ - res: Optional[tuple] = None - run_in_thread: Optional[bool] = None + res: tuple | None = None + run_in_thread: bool | None = None refresh_rate: int = 60 - max_FPS: Optional[int] = 60 + max_FPS: int | None = 60 camera_pos: tuple = (3.5, 0.5, 2.5) camera_lookat: tuple = (0.0, 0.0, 0.5) camera_up: tuple = (0.0, 0.0, 1.0) camera_fov: float = 40 - enable_interaction: bool = False - disable_keyboard_shortcuts: bool = False + viewer_plugin: ViewerInteraction = ViewerDefaultControls() class VisOptions(Options): diff --git a/genesis/recorders/plotters.py b/genesis/recorders/plotters.py index 731ed8e372..e9d00cb7b5 100644 --- a/genesis/recorders/plotters.py +++ b/genesis/recorders/plotters.py @@ -16,10 +16,19 @@ from genesis.options.recorders import ( BasePlotterOptions, LinePlotterMixinOptions, - PyQtLinePlot as PyQtLinePlotterOptions, - MPLLinePlot as MPLLinePlotterOptions, +) +from genesis.options.recorders import ( + MPLDepthScatterPlot as MPLDepthScatterPlotterOptions, +) +from genesis.options.recorders import ( MPLImagePlot as MPLImagePlotterOptions, ) +from genesis.options.recorders import ( + MPLLinePlot as MPLLinePlotterOptions, +) +from genesis.options.recorders import ( + PyQtLinePlot as PyQtLinePlotterOptions, +) from genesis.utils import has_display, tensor_to_array from .base_recorder import Recorder @@ -631,3 +640,138 @@ def cleanup(self): self.ax = None self.image_plot = None self.background = None + + +@register_recording(MPLDepthScatterPlotterOptions) +class MPLDepthScatterPlotter(BaseMPLPlotter): + """ + Live depth sensor visualization using matplotlib 3D scatter plot. + + The data should be a tuple of (positions, distances) where: + - positions: array-like with shape (N, 3) for (x, y, z) coordinates + - distances: array-like with shape (N,) for depth/distance values + """ + + def build(self): + super().build() + + import matplotlib.pyplot as plt + + self.scatter = None + self.colorbar = None + self.positions = None + self.distances = None + + # Create 3D figure + self.fig = plt.figure(figsize=self.figsize) + self.ax = self.fig.add_subplot(111, projection="3d") + self.fig.suptitle(self._options.title) + self.ax.set_xlabel(self._options.x_label) + self.ax.set_ylabel(self._options.y_label) + self.ax.set_zlabel(self._options.z_label) + self.ax.grid(True, alpha=0.3) + + # Initialize with empty scatter (will be set on first data) + self.scatter = self.ax.scatter( + [], [], [], c=[], s=self._options.point_size, cmap=self._options.cmap, edgecolors="none" + ) + self.colorbar = self.fig.colorbar( + self.scatter, ax=self.ax, label=self._options.colorbar_label, shrink=0.5, aspect=5 + ) + + self._show_fig() + + def process(self, data, cur_time): + """Process new position and distance data.""" + # Expect data to be a tuple of (positions, distances) + if not isinstance(data, (tuple, list)) or len(data) != 2: + gs.logger.warning(f"[{type(self).__name__}] Data must be a tuple (positions, distances). Got: {type(data)}") + return + + positions, distances = data + + # Convert to numpy arrays + if isinstance(positions, torch.Tensor): + positions = tensor_to_array(positions) + else: + positions = np.asarray(positions) + + if isinstance(distances, torch.Tensor): + distances = tensor_to_array(distances) + else: + distances = np.asarray(distances) + + # Flatten if needed + if positions.ndim > 2: + positions = positions.reshape(-1, positions.shape[-1]) + distances = distances.flatten() + + # Validate shapes + if positions.ndim != 2 or positions.shape[1] < 3: + gs.logger.warning( + f"[{type(self).__name__}] Positions must have shape (N, 3) or (N, D) where D >= 3. " + f"Got shape {positions.shape}" + ) + return + + if len(positions) != len(distances): + gs.logger.warning( + f"[{type(self).__name__}] Number of positions ({len(positions)}) doesn't match " + f"number of distances ({len(distances)})" + ) + return + + self.positions = positions[:, :3] # use (x, y, z) + self.distances = distances + + super().process(data, cur_time) + + def _update_plot(self): + """Update the 3D scatter plot with new positions and distances.""" + if self.positions is None or self.distances is None: + return + + # Update scatter plot data + self.scatter._offsets3d = (self.positions[:, 0], self.positions[:, 1], self.positions[:, 2]) + self.scatter.set_array(self.distances) + + # Update color limits + vmin, vmax = np.min(self.distances), np.max(self.distances) + current_vmin, current_vmax = self.scatter.get_clim() + if abs(vmin - current_vmin) > 1e-6 or abs(vmax - current_vmax) > 1e-6: + self.scatter.set_clim(vmin, vmax) + + # Update axis limits if this is the first update or limits have changed significantly + if not hasattr(self, "_limits_set") or not self._limits_set: + x_min, x_max = self.positions[:, 0].min(), self.positions[:, 0].max() + y_min, y_max = self.positions[:, 1].min(), self.positions[:, 1].max() + z_min, z_max = self.positions[:, 2].min(), self.positions[:, 2].max() + + # Calculate uniform range for all axes + max_range = max(x_max - x_min, y_max - y_min, z_max - z_min) + padding = max_range * 0.1 or 0.1 + + # Center each axis range + x_center = (x_max + x_min) / 2 + y_center = (y_max + y_min) / 2 + z_center = (z_max + z_min) / 2 + + half_range = (max_range + 2 * padding) / 2 + + self.ax.set_xlim(x_center - half_range, x_center + half_range) + self.ax.set_ylim(y_center - half_range, y_center + half_range) + self.ax.set_zlim(z_center - half_range, z_center + half_range) + self._limits_set = True + + self._lock.acquire() + self.fig.canvas.draw() + self.fig.canvas.flush_events() + self._lock.release() + + def cleanup(self): + super().cleanup() + + self.scatter = None + self.colorbar = None + self.positions = None + self.distances = None diff --git a/genesis/utils/raycast.py b/genesis/utils/raycast.py new file mode 100644 index 0000000000..4779622309 --- /dev/null +++ b/genesis/utils/raycast.py @@ -0,0 +1,121 @@ +import gstaichi as ti + +import genesis as gs + + +@ti.func +def ray_triangle_intersection(ray_start, ray_dir, v0, v1, v2): + """ + Moller-Trumbore ray-triangle intersection. + + Returns: vec4(t, u, v, hit) where hit=1.0 if intersection found, 0.0 otherwise + """ + result = ti.Vector.zero(gs.ti_float, 4) + + edge1 = v1 - v0 + edge2 = v2 - v0 + + # Begin calculating determinant - also used to calculate u parameter + h = ray_dir.cross(edge2) + a = edge1.dot(h) + + # Check all conditions in sequence without early returns + valid = True + + t = gs.ti_float(0.0) + u = gs.ti_float(0.0) + v = gs.ti_float(0.0) + f = gs.ti_float(0.0) + s = ti.Vector.zero(gs.ti_float, 3) + q = ti.Vector.zero(gs.ti_float, 3) + + # If determinant is near zero, ray lies in plane of triangle + if ti.abs(a) < gs.EPS: + valid = False + + if valid: + f = 1.0 / a + s = ray_start - v0 + u = f * s.dot(h) + + if u < 0.0 or u > 1.0: + valid = False + + if valid: + q = s.cross(edge1) + v = f * ray_dir.dot(q) + + if v < 0.0 or u + v > 1.0: + valid = False + + if valid: + # At this stage we can compute t to find out where the intersection point is on the line + t = f * edge2.dot(q) + + # Ray intersection + if t <= gs.EPS: + valid = False + + if valid: + result = ti.math.vec4(t, u, v, 1.0) + + return result + + +@ti.func +def ray_aabb_intersection(ray_start, ray_dir, aabb_min, aabb_max): + """ + Fast ray-AABB intersection test. + Returns the t value of intersection, or -1.0 if no intersection. + """ + result = -1.0 + + # Use the slab method for ray-AABB intersection + sign = ti.select(ray_dir >= 0.0, 1.0, -1.0) + ray_dir = sign * ti.max(ti.abs(ray_dir), gs.EPS) + inv_dir = 1.0 / ray_dir + + t1 = (aabb_min - ray_start) * inv_dir + t2 = (aabb_max - ray_start) * inv_dir + + tmin = ti.min(t1, t2) + tmax = ti.max(t1, t2) + + t_near = ti.max(tmin.x, tmin.y, tmin.z, 0.0) + t_far = ti.min(tmax.x, tmax.y, tmax.z) + + # Check if ray intersects AABB + if t_near <= t_far: + result = t_near + + return result + + +@ti.kernel +def kernel_update_aabbs( + free_verts_state: ti.template(), + fixed_verts_state: ti.template(), + verts_info: ti.template(), + faces_info: ti.template(), + # FIXME: can't import array_class since it is before gs.init + # free_verts_state: array_class.VertsState, + # fixed_verts_state: array_class.VertsState, + # verts_info: array_class.VertsInfo, + # faces_info: array_class.FacesInfo, + aabb_state: ti.template(), +): + for i_b, i_f in ti.ndrange(free_verts_state.pos.shape[1], faces_info.verts_idx.shape[0]): + aabb_state.aabbs[i_b, i_f].min.fill(ti.math.inf) + aabb_state.aabbs[i_b, i_f].max.fill(-ti.math.inf) + + for i in ti.static(range(3)): + i_v = faces_info.verts_idx[i_f][i] + i_fv = verts_info.verts_state_idx[i_v] + if verts_info.is_fixed[i_v]: + pos_v = fixed_verts_state.pos[i_fv] + aabb_state.aabbs[i_b, i_f].min = ti.min(aabb_state.aabbs[i_b, i_f].min, pos_v) + aabb_state.aabbs[i_b, i_f].max = ti.max(aabb_state.aabbs[i_b, i_f].max, pos_v) + else: + pos_v = free_verts_state.pos[i_fv, i_b] + aabb_state.aabbs[i_b, i_f].min = ti.min(aabb_state.aabbs[i_b, i_f].min, pos_v) + aabb_state.aabbs[i_b, i_f].max = ti.max(aabb_state.aabbs[i_b, i_f].max, pos_v) diff --git a/genesis/vis/viewer.py b/genesis/vis/viewer.py index 3a0c24d4c8..02613bdf19 100644 --- a/genesis/vis/viewer.py +++ b/genesis/vis/viewer.py @@ -1,6 +1,6 @@ +import importlib import os import threading -import importlib from typing import TYPE_CHECKING import numpy as np @@ -9,11 +9,10 @@ import genesis as gs import genesis.utils.geom as gu - from genesis.ext import pyrender from genesis.repr_base import RBC -from genesis.utils.tools import Rate from genesis.utils.misc import redirect_libc_stderr, tensor_to_array +from genesis.utils.tools import Rate if TYPE_CHECKING: from genesis.options.vis import ViewerOptions @@ -40,16 +39,12 @@ def __init__(self, options: "ViewerOptions", context): self._camera_init_lookat = np.asarray(options.camera_lookat, dtype=gs.np_float) self._camera_up = np.asarray(options.camera_up, dtype=gs.np_float) self._camera_fov = options.camera_fov - self._enable_interaction = options.enable_interaction - self._disable_keyboard_shortcuts = options.disable_keyboard_shortcuts + self._viewer_plugin = options.viewer_plugin # Validate viewer options if any(e.shape != (3,) for e in (self._camera_init_pos, self._camera_init_lookat, self._camera_up)): gs.raise_exception("ViewerOptions.camera_(pos|lookat|up) must be sequences of length 3.") - if options.enable_interaction and gs.backend != gs.cpu: - gs.logger.warning("Interaction code is slow on GPU. Switch to CPU backend or disable interaction.") - self._pyrender_viewer = None self.context = context @@ -100,8 +95,7 @@ def build(self, scene): shadow=self.context.shadow, plane_reflection=self.context.plane_reflection, env_separate_rigid=self.context.env_separate_rigid, - enable_interaction=self._enable_interaction, - disable_keyboard_shortcuts=self._disable_keyboard_shortcuts, + plugin_options=self._viewer_plugin, viewer_flags={ "window_title": f"Genesis {gs.__version__}", "refresh_rate": self._refresh_rate, diff --git a/tests/test_render.py b/tests/test_render.py index c5815720c8..399b0a999e 100644 --- a/tests/test_render.py +++ b/tests/test_render.py @@ -9,7 +9,6 @@ import pyglet import pytest import torch -import OpenGL.error import genesis as gs import genesis.utils.geom as gu @@ -1226,28 +1225,6 @@ def on_key_press(self, symbol: int, modifiers: int): assert f.read() == png_snapshot -@pytest.mark.required -@pytest.mark.parametrize("renderer_type", [RENDERER_TYPE.RASTERIZER]) -@pytest.mark.skipif(not IS_INTERACTIVE_VIEWER_AVAILABLE, reason="Interactive viewer not supported on this platform.") -@pytest.mark.xfail(sys.platform == "win32", raises=OpenGL.error.Error, reason="Invalid OpenGL context.") -def test_interactive_viewer_disable_keyboard_shortcuts(): - """Test that keyboard shortcuts can be disabled in the interactive viewer.""" - - # Test with keyboard shortcuts DISABLED - scene = gs.Scene( - viewer_options=gs.options.ViewerOptions( - disable_keyboard_shortcuts=True, - ), - show_viewer=True, - ) - scene.build() - pyrender_viewer = scene.visualizer.viewer._pyrender_viewer - assert pyrender_viewer.is_active - - # Verify the flag is set correctly - assert pyrender_viewer._disable_keyboard_shortcuts is True - - @pytest.mark.required @pytest.mark.parametrize("renderer_type", [RENDERER_TYPE.RASTERIZER]) def test_camera_gimbal_lock_singularity(renderer, show_viewer): From bfa9ecc9f9e48b115c2ffccd0ce5b7c0d5d96c5b Mon Sep 17 00:00:00 2001 From: Trinity Chung Date: Sat, 3 Jan 2026 23:00:59 -0500 Subject: [PATCH 2/5] implement keybindings --- examples/keyboard_teleop.py | 287 +++++++---------- .../ext/pyrender/interaction/keybindings.py | 107 +++++-- .../interaction/plugins/default_keyboard.py | 213 ------------- .../interaction/plugins/viewer_controls.py | 294 +++++++++--------- genesis/ext/pyrender/viewer.py | 97 +++--- genesis/vis/viewer.py | 11 + 6 files changed, 388 insertions(+), 621 deletions(-) delete mode 100644 genesis/ext/pyrender/interaction/plugins/default_keyboard.py diff --git a/examples/keyboard_teleop.py b/examples/keyboard_teleop.py index 4af79140ea..eaf59d3db4 100644 --- a/examples/keyboard_teleop.py +++ b/examples/keyboard_teleop.py @@ -18,167 +18,10 @@ import random import numpy as np -import pyglet from scipy.spatial.transform import Rotation as R -from typing_extensions import override import genesis as gs -from genesis.ext.pyrender.interaction import register_viewer_plugin -from genesis.ext.pyrender.interaction.plugins.viewer_controls import ViewerDefaultControls -from genesis.options.viewer_interactions import ViewerDefaultControls as ViewerDefaultControlsOptions - - -class FrankaTeleopOptions(ViewerDefaultControlsOptions): - """Options for Franka teleoperation plugin.""" - - pass - - -@register_viewer_plugin(FrankaTeleopOptions) -class FrankaTeleopPlugin(ViewerDefaultControls): - """ - Viewer plugin for teleoperating Franka robot with keyboard. - Extends ViewerDefaultControls to add robot-specific controls. - """ - - def __init__(self, viewer, options=None, camera=None, scene=None, viewport_size=None): - super().__init__(viewer, options, camera, scene, viewport_size) - - # Robot control state - self.robot = None - self.target_entity = None - self.cube_entity = None - self.target_pos = None - self.target_R = None - self.robot_init_pos = np.array([0.5, 0, 0.55]) - self.robot_init_R = R.from_euler("y", np.pi) - - # Control parameters - self.dpos = 0.002 - self.drot = 0.01 - self.is_close_gripper = False - - # keybindings - self.keybindings.extend( - dict( - move_forward=pyglet.window.key.UP, - move_backward=pyglet.window.key.DOWN, - move_left=pyglet.window.key.LEFT, - move_right=pyglet.window.key.RIGHT, - move_up=pyglet.window.key.N, - move_down=pyglet.window.key.M, - rotate_ccw=pyglet.window.key.J, - rotate_cw=pyglet.window.key.K, - reset_scene=pyglet.window.key.U, - close_gripper=pyglet.window.key.SPACE, - ) - ) - self._instr_texts = ( - ["> [i]: show keyboard instructions"], - ["< [i]: hide keyboard instructions"] - + self.keybindings.as_instruction_texts(padding=3, exclude=("toggle_keyboard_instructions")), - ) - - def set_entities(self, robot, target_entity, cube_entity): - """Set references to scene entities.""" - self.robot = robot - self.target_entity = target_entity - self.cube_entity = cube_entity - - # Initialize target pose - self.target_pos = self.robot_init_pos.copy() - self.target_R = self.robot_init_R - - # Get DOF indices - n_dofs = robot.n_dofs - self.motors_dof = np.arange(n_dofs - 2) - self.fingers_dof = np.arange(n_dofs - 2, n_dofs) - self.ee_link = robot.get_link("hand") - - # Reset to initial pose - self._reset_robot() - - def _reset_robot(self): - """Reset robot and cube to initial positions.""" - if self.robot is None: - return - - self.target_pos = self.robot_init_pos.copy() - self.target_R = self.robot_init_R - target_quat = self.target_R.as_quat(scalar_first=True) - self.target_entity.set_qpos(np.concatenate([self.target_pos, target_quat])) - q = self.robot.inverse_kinematics(link=self.ee_link, pos=self.target_pos, quat=target_quat) - self.robot.set_qpos(q[:-2], self.motors_dof) - - # Randomize cube position - self.cube_entity.set_pos((random.uniform(0.2, 0.4), random.uniform(-0.2, 0.2), 0.05)) - self.cube_entity.set_quat(R.from_euler("z", random.uniform(0, np.pi * 2)).as_quat(scalar_first=True)) - - @override - def on_key_press(self, symbol: int, modifiers: int): - # First handle default viewer controls - result = super().on_key_press(symbol, modifiers) - - if self.robot is None: - return result - - # Handle teleoperation controls - if symbol == pyglet.window.key.UP: - self.target_pos[0] -= self.dpos - elif symbol == pyglet.window.key.DOWN: - self.target_pos[0] += self.dpos - elif symbol == pyglet.window.key.RIGHT: - self.target_pos[1] += self.dpos - elif symbol == pyglet.window.key.LEFT: - self.target_pos[1] -= self.dpos - elif symbol == pyglet.window.key.N: - self.target_pos[2] += self.dpos - elif symbol == pyglet.window.key.M: - self.target_pos[2] -= self.dpos - elif symbol == pyglet.window.key.J: - self.target_R = R.from_euler("z", self.drot) * self.target_R - elif symbol == pyglet.window.key.K: - self.target_R = R.from_euler("z", -self.drot) * self.target_R - elif symbol == pyglet.window.key.U: - self._reset_robot() - elif symbol == pyglet.window.key.SPACE: - self.is_close_gripper = True - - return result - - @override - def on_key_release(self, symbol: int, modifiers: int): - result = super().on_key_release(symbol, modifiers) - - if symbol == pyglet.window.key.SPACE: - self.is_close_gripper = False - - return result - - @override - def update_on_sim_step(self): - """Update robot control every simulation step.""" - super().update_on_sim_step() - - if self.robot is None: - return - - # Update target entity visualization - target_quat = self.target_R.as_quat(scalar_first=True) - self.target_entity.set_qpos(np.concatenate([self.target_pos, target_quat])) - - # Control arm with inverse kinematics - q, err = self.robot.inverse_kinematics( - link=self.ee_link, pos=self.target_pos, quat=target_quat, return_error=True - ) - self.robot.control_dofs_position(q[:-2], self.motors_dof) - - # Control gripper - if self.is_close_gripper: - self.robot.control_dofs_force(np.array([-1.0, -1.0]), self.fingers_dof) - else: - self.robot.control_dofs_force(np.array([1.0, 1.0]), self.fingers_dof) - +from genesis.ext.pyrender.interaction.keybindings import KeyAction, Keybind if __name__ == "__main__": ########################## init ########################## @@ -202,7 +45,6 @@ def update_on_sim_step(self): camera_lookat=(0.2, 0.0, 0.1), camera_fov=50, max_FPS=60, - viewer_plugin=FrankaTeleopOptions(), ), show_viewer=True, show_FPS=False, @@ -242,23 +84,114 @@ def update_on_sim_step(self): ########################## build ########################## scene.build() - # Set up the teleoperation plugin with entity references - teleop_plugin = scene.viewer.viewer_interaction - teleop_plugin.set_entities(robot, target, cube) - - print("\nKeyboard Controls:") - print("↑\t- Move Forward (North)") - print("↓\t- Move Backward (South)") - print("←\t- Move Left (West)") - print("→\t- Move Right (East)") - print("n\t- Move Up") - print("m\t- Move Down") - print("j\t- Rotate Counterclockwise") - print("k\t- Rotate Clockwise") - print("u\t- Reset Scene") - print("space\t- Press to close gripper, release to open gripper") - print("\nPress 'i' in the viewer to see all keyboard controls") + # Initialize robot control state + robot_init_pos = np.array([0.5, 0, 0.55]) + robot_init_R = R.from_euler("y", np.pi) + + # Get DOF indices + n_dofs = robot.n_dofs + motors_dof = np.arange(n_dofs - 2) + fingers_dof = np.arange(n_dofs - 2, n_dofs) + ee_link = robot.get_link("hand") + + # Initialize target pose + target_pos = robot_init_pos.copy() + target_R = [robot_init_R] # Use list to make it mutable in closures + + # Control parameters + dpos = 0.002 + drot = 0.01 + + # Helper function to reset robot + def reset_robot(): + """Reset robot and cube to initial positions.""" + target_pos[:] = robot_init_pos.copy() + target_R[0] = robot_init_R + target_quat = target_R[0].as_quat(scalar_first=True) + target.set_qpos(np.concatenate([target_pos, target_quat])) + q = robot.inverse_kinematics(link=ee_link, pos=target_pos, quat=target_quat) + robot.set_qpos(q[:-2], motors_dof) + + # Randomize cube position + cube.set_pos((random.uniform(0.2, 0.4), random.uniform(-0.2, 0.2), 0.05)) + cube.set_quat(R.from_euler("z", random.uniform(0, np.pi * 2)).as_quat(scalar_first=True)) + + # Initialize robot pose + reset_robot() + + # Robot teleoperation callback functions + def move_forward(): + target_pos[0] -= dpos + + def move_backward(): + target_pos[0] += dpos + + def move_left(): + target_pos[1] -= dpos + + def move_right(): + target_pos[1] += dpos + + def move_up(): + target_pos[2] += dpos + + def move_down(): + target_pos[2] -= dpos + + def rotate_ccw(): + target_R[0] = R.from_euler("z", drot) * target_R[0] + + def rotate_cw(): + target_R[0] = R.from_euler("z", -drot) * target_R[0] + + def close_gripper(): + robot.control_dofs_force(np.array([-1.0, -1.0]), fingers_dof) + + def open_gripper(): + robot.control_dofs_force(np.array([1.0, 1.0]), fingers_dof) + + # Register robot teleoperation keybindings + from pyglet.window import key + + scene.viewer.register_keybinds( + ( + Keybind(key_code=key.UP, key_action=KeyAction.HOLD, name="move_forward", callback_func=move_forward), + Keybind(key_code=key.DOWN, key_action=KeyAction.HOLD, name="move_backward", callback_func=move_backward), + Keybind(key_code=key.LEFT, key_action=KeyAction.HOLD, name="move_left", callback_func=move_left), + Keybind(key_code=key.RIGHT, key_action=KeyAction.HOLD, name="move_right", callback_func=move_right), + Keybind(key_code=key.N, key_action=KeyAction.HOLD, name="move_up", callback_func=move_up), + Keybind(key_code=key.M, key_action=KeyAction.HOLD, name="move_down", callback_func=move_down), + Keybind(key_code=key.J, key_action=KeyAction.HOLD, name="rotate_ccw", callback_func=rotate_ccw), + Keybind(key_code=key.K, key_action=KeyAction.HOLD, name="rotate_cw", callback_func=rotate_cw), + Keybind(key_code=key.U, key_action=KeyAction.HOLD, name="reset_scene", callback_func=reset_robot), + Keybind( + key_code=key.SPACE, + name="close_gripper", + callback_func=close_gripper, + key_action=KeyAction.PRESS, + ), + Keybind( + key_code=key.SPACE, + name="open_gripper", + callback_func=open_gripper, + key_action=KeyAction.RELEASE, + ), + ) + ) ########################## run simulation ########################## - while True: - scene.step() + try: + while True: + # Update target entity visualization + target_quat = target_R[0].as_quat(scalar_first=True) + target.set_qpos(np.concatenate([target_pos, target_quat])) + + # Control arm with inverse kinematics + q, err = robot.inverse_kinematics(link=ee_link, pos=target_pos, quat=target_quat, return_error=True) + robot.control_dofs_position(q[:-2], motors_dof) + + scene.step() + except KeyboardInterrupt: + gs.logger.info("Simulation interrupted, exiting.") + finally: + gs.logger.info("Simulation finished.") diff --git a/genesis/ext/pyrender/interaction/keybindings.py b/genesis/ext/pyrender/interaction/keybindings.py index 3cd7ba7990..308bb11031 100644 --- a/genesis/ext/pyrender/interaction/keybindings.py +++ b/genesis/ext/pyrender/interaction/keybindings.py @@ -1,28 +1,85 @@ +from enum import IntEnum +from typing import Callable, NamedTuple + + +class KeyAction(IntEnum): + PRESS = 0 + HOLD = 1 + RELEASE = 2 + +def get_key_hash(key_code: int, modifiers: int | None, action: KeyAction) -> int: + """Generate a unique hash for a key combination. + + Parameters + ---------- + key_code : int + The key code from pyglet. + modifiers : int | None + The modifier keys pressed. + action : KeyAction + The type of key action (press, hold, release). + + Returns + ------- + int + A unique hash for this key combination. + """ + return hash((key_code, modifiers, action)) + +def get_keycode_string(key_code: int) -> str: + from pyglet.window.key import symbol_string + return symbol_string(key_code).lower() + +class Keybind(NamedTuple): + key_code: int + name: str = "" + key_action: KeyAction = KeyAction.PRESS + callback_func: Callable[[], None] | None = None + modifiers: int | None = None + args: tuple = () + kwargs: dict = {} + + def key_hash(self) -> int: + """Generate a unique hash for the keybind based on key code and modifiers.""" + return get_key_hash(self.key_code, self.modifiers, self.key_action) + class Keybindings: - def __init__(self, map: dict[str, int] = {}, **kwargs: dict[str, int]): - self._map: dict[str, int] = {**map, **kwargs} + + def __init__(self, keybinds: tuple[Keybind] = ()): + self._keybinds_map: dict[int, Keybind] = {kb.key_hash(): kb for kb in keybinds} + + def register(self, keybind: Keybind) -> None: + if keybind.key_hash() in self._keybinds_map: + existing_kb = self._keybinds_map[keybind.key_hash()] + raise ValueError( + f"Key '{get_keycode_string(keybind.key_code)}' is already assigned to '{existing_kb.name}'." + ) + self._keybinds_map[keybind.key_hash()] = keybind + + def get(self, key: int, modifiers: int, key_action: KeyAction) -> Keybind | None: + key_hash = get_key_hash(key, modifiers, key_action) + if key_hash in self._keybinds_map: + return self._keybinds_map[key_hash] + + # Try ignoring modifiers (for keybinds where modifiers=None) + key_hash_no_mods = get_key_hash(key, None, key_action) + if key_hash_no_mods in self._keybinds_map: + return self._keybinds_map[key_hash_no_mods] + + return None + + def get_by_name(self, name: str) -> Keybind | None: + for kb in self._keybinds_map.values(): + if kb.name == name: + return kb + return None - def __getattr__(self, name: str) -> int: - if name in self._map: - return self._map[name] - raise AttributeError(f"Action '{name}' not found in keybindings.") - - def as_instruction_texts(self, padding, exclude: tuple[str]) -> list[str]: - from pyglet.window.key import symbol_string - - width = 4 + padding - return [ - f"{'[' + symbol_string(self._map[action]).lower():>{width}}]: " + - action.replace('_', ' ') for action in self._map.keys() if action not in exclude - ] + @property + def keys(self) -> tuple[str]: + """Return a list of all registered keys as ASCII characters.""" + return tuple(get_keycode_string(kb.key_code) for kb in self._keybinds_map.values()) - def extend(self, mapping: dict[str, int], replace_only: bool = False) -> None: - from pyglet.window.key import symbol_string - - current_keys = self._map.keys() - for action, key in mapping.items(): - if replace_only and action not in self._map: - raise KeyError(f"Action '{action}' not found. Available actions: {list(self._map.keys())}") - if key in current_keys: - raise ValueError(f"Key '{symbol_string(key)}' is already assigned to another action.") - self._map[action] = key \ No newline at end of file + @property + def keybinds(self) -> tuple[Keybind]: + """Return a tuple of all registered Keybinds.""" + return tuple(self._keybinds_map.values()) \ No newline at end of file diff --git a/genesis/ext/pyrender/interaction/plugins/default_keyboard.py b/genesis/ext/pyrender/interaction/plugins/default_keyboard.py deleted file mode 100644 index 9c211e84e3..0000000000 --- a/genesis/ext/pyrender/interaction/plugins/default_keyboard.py +++ /dev/null @@ -1,213 +0,0 @@ -import os -from typing import TYPE_CHECKING - -import numpy as np -import pyglet -from typing_extensions import override - -import genesis as gs - -from ...constants import TEXT_PADDING -from ..base_interaction import EVENT_HANDLE_STATE, BaseViewerInteraction - -if TYPE_CHECKING: - from genesis.engine.scene import Scene - from genesis.ext.pyrender.node import Node - - -class ViewerControls(BaseViewerInteraction): - """ - Default keyboard controls for the Genesis viewer. - - This plugin handles the standard viewer keyboard shortcuts for recording, changing render modes, etc. - """ - - def __init__( - self, - viewer, - options=None, - camera: "Node" = None, - scene: "Scene" = None, - viewport_size: tuple[int, int] = None, - ): - super().__init__(viewer, options, camera, scene, viewport_size) - - # Instruction display state - self._display_instr = False - self._instr_texts = [ - ["> [i]: show keyboard instructions"], - [ - "< [i]: hide keyboard instructions", - " [r]: record video", - " [s]: save image", - " [z]: reset camera", - " [a]: camera rotation", - " [h]: shadow", - " [f]: face normal", - " [v]: vertex normal", - " [w]: world frame", - " [l]: link frame", - " [d]: wireframe", - " [c]: camera & frustrum", - " [F11]: full-screen mode", - ], - ] - - @override - def on_key_press(self, symbol: int, modifiers: int) -> EVENT_HANDLE_STATE: - if self.viewer is None: - return None - - # A causes the frame to rotate - self.viewer._message_text = None - if symbol == pyglet.window.key.A: - self.viewer.viewer_flags["rotate"] = not self.viewer.viewer_flags["rotate"] - if self.viewer.viewer_flags["rotate"]: - self.viewer._message_text = "Rotation On" - else: - self.viewer._message_text = "Rotation Off" - - # F11 toggles fullscreen - elif symbol == pyglet.window.key.F11: - self.viewer.viewer_flags["fullscreen"] = not self.viewer.viewer_flags["fullscreen"] - self.viewer.set_fullscreen(self.viewer.viewer_flags["fullscreen"]) - self.viewer.activate() - if self.viewer.viewer_flags["fullscreen"]: - self.viewer._message_text = "Fullscreen On" - else: - self.viewer._message_text = "Fullscreen Off" - - # H toggles shadows - elif symbol == pyglet.window.key.H: - self.viewer.render_flags["shadows"] = not self.viewer.render_flags["shadows"] - if self.viewer.render_flags["shadows"]: - self.viewer._message_text = "Shadows On" - else: - self.viewer._message_text = "Shadows Off" - - # W toggles world frame - elif symbol == pyglet.window.key.W: - if not self.viewer.gs_context.world_frame_shown: - self.viewer.gs_context.on_world_frame() - self.viewer._message_text = "World Frame On" - else: - self.viewer.gs_context.off_world_frame() - self.viewer._message_text = "World Frame Off" - - # L toggles link frame - elif symbol == pyglet.window.key.L: - if not self.viewer.gs_context.link_frame_shown: - self.viewer.gs_context.on_link_frame() - self.viewer._message_text = "Link Frame On" - else: - self.viewer.gs_context.off_link_frame() - self.viewer._message_text = "Link Frame Off" - - # C toggles camera frustum - elif symbol == pyglet.window.key.C: - if not self.viewer.gs_context.camera_frustum_shown: - self.viewer.gs_context.on_camera_frustum() - self.viewer._message_text = "Camera Frustrum On" - else: - self.viewer.gs_context.off_camera_frustum() - self.viewer._message_text = "Camera Frustrum Off" - - # F toggles face normals - elif symbol == pyglet.window.key.F: - self.viewer.render_flags["face_normals"] = not self.viewer.render_flags["face_normals"] - if self.viewer.render_flags["face_normals"]: - self.viewer._message_text = "Face Normals On" - else: - self.viewer._message_text = "Face Normals Off" - - # V toggles vertex normals - elif symbol == pyglet.window.key.V: - self.viewer.render_flags["vertex_normals"] = not self.viewer.render_flags["vertex_normals"] - if self.viewer.render_flags["vertex_normals"]: - self.viewer._message_text = "Vert Normals On" - else: - self.viewer._message_text = "Vert Normals Off" - - # R starts recording frames - elif symbol == pyglet.window.key.R: - if self.viewer.viewer_flags["record"]: - self.viewer.save_video() - self.viewer.set_caption(self.viewer.viewer_flags["window_title"]) - else: - # Importing moviepy is very slow and not used very often. Let's delay import. - from moviepy.video.io.ffmpeg_writer import FFMPEG_VideoWriter - - self.viewer._video_recorder = FFMPEG_VideoWriter( - filename=os.path.join(gs.utils.misc.get_cache_dir(), "tmp_video.mp4"), - fps=self.viewer.viewer_flags["refresh_rate"], - size=self.viewer.viewport_size, - ) - self.viewer.set_caption("{} (RECORDING)".format(self.viewer.viewer_flags["window_title"])) - self.viewer.viewer_flags["record"] = not self.viewer.viewer_flags["record"] - - # S saves the current frame as an image - elif symbol == pyglet.window.key.S: - self.viewer._save_image() - - # D toggles through wireframe modes - elif symbol == pyglet.window.key.D: - if self.viewer.render_flags["flip_wireframe"]: - self.viewer.render_flags["flip_wireframe"] = False - self.viewer.render_flags["all_wireframe"] = True - self.viewer.render_flags["all_solid"] = False - self.viewer._message_text = "All Wireframe" - elif self.viewer.render_flags["all_wireframe"]: - self.viewer.render_flags["flip_wireframe"] = False - self.viewer.render_flags["all_wireframe"] = False - self.viewer.render_flags["all_solid"] = True - self.viewer._message_text = "All Solid" - elif self.viewer.render_flags["all_solid"]: - self.viewer.render_flags["flip_wireframe"] = False - self.viewer.render_flags["all_wireframe"] = False - self.viewer.render_flags["all_solid"] = False - self.viewer._message_text = "Default Wireframe" - else: - self.viewer.render_flags["flip_wireframe"] = True - self.viewer.render_flags["all_wireframe"] = False - self.viewer.render_flags["all_solid"] = False - self.viewer._message_text = "Flip Wireframe" - - # Z resets the camera viewpoint - elif symbol == pyglet.window.key.Z: - self.viewer._reset_view() - - # I toggles instruction display - elif symbol == pyglet.window.key.I: - self._display_instr = not self._display_instr - - # P reloads shader program - elif symbol == pyglet.window.key.P: - self.viewer._renderer.reload_program() - - if self.viewer._message_text is not None: - self.viewer._message_opac = 1.0 + self.viewer._ticks_till_fade - - return None - - @override - def on_draw(self): - """Render keyboard instructions.""" - if self.viewer is None: - return - - if self._display_instr: - self.viewer._renderer.render_texts( - self._instr_texts[1], - TEXT_PADDING, - self.viewer.viewport_size[1] - TEXT_PADDING, - font_pt=26, - color=np.array([1.0, 1.0, 1.0, 0.85]), - ) - else: - self.viewer._renderer.render_texts( - self._instr_texts[0], - TEXT_PADDING, - self.viewer.viewport_size[1] - TEXT_PADDING, - font_pt=26, - color=np.array([1.0, 1.0, 1.0, 0.85]), - ) diff --git a/genesis/ext/pyrender/interaction/plugins/viewer_controls.py b/genesis/ext/pyrender/interaction/plugins/viewer_controls.py index 7857d122ce..5fe67a07e0 100644 --- a/genesis/ext/pyrender/interaction/plugins/viewer_controls.py +++ b/genesis/ext/pyrender/interaction/plugins/viewer_controls.py @@ -3,19 +3,21 @@ import numpy as np import pyglet -from genesis.options.viewer_interactions import ViewerDefaultControls as ViewerDefaultControlsOptions from typing_extensions import override import genesis as gs +from genesis.options.viewer_interactions import ViewerDefaultControls as ViewerDefaultControlsOptions from ...constants import TEXT_PADDING from ..base_interaction import EVENT_HANDLE_STATE, BaseViewerInteraction, register_viewer_plugin -from ..keybindings import Keybindings +from ..keybindings import Keybind, get_keycode_string if TYPE_CHECKING: from genesis.engine.scene import Scene from genesis.ext.pyrender.node import Node +INSTR_KEYBIND_NAME = "toggle_instructions" + @register_viewer_plugin(ViewerDefaultControlsOptions) class ViewerDefaultControls(BaseViewerInteraction): """ @@ -33,163 +35,159 @@ def __init__( viewport_size: tuple[int, int] = None, ): super().__init__(viewer, options, camera, scene, viewport_size) - - self.keybindings: Keybindings = Keybindings( - toggle_keyboard_instructions=pyglet.window.key.I, - record_video=pyglet.window.key.R, - save_image=pyglet.window.key.S, - reset_camera=pyglet.window.key.Z, - camera_rotation=pyglet.window.key.A, - shadow=pyglet.window.key.H, - face_normals=pyglet.window.key.F, - vertex_normals=pyglet.window.key.V, - world_frame=pyglet.window.key.W, - link_frame=pyglet.window.key.L, - wireframe=pyglet.window.key.D, - camera_frustum=pyglet.window.key.C, - fullscreen_mode=pyglet.window.key.F11, - ) - if options and options.keybindings: - self.keybindings.apply_override_mapping(options.keybindings) + self.viewer.register_keybinds(( + Keybind(key_code=pyglet.window.key.I, name=INSTR_KEYBIND_NAME, callback_func=self.toggle_instructions), + Keybind(key_code=pyglet.window.key.R, name="record_video", callback_func=self.toggle_record_video), + Keybind(key_code=pyglet.window.key.S, name="save_image", callback_func=self.save_image), + Keybind(key_code=pyglet.window.key.Z, name="reset_camera", callback_func=self.reset_camera), + Keybind(key_code=pyglet.window.key.A, name="camera_rotation", callback_func=self.toggle_camera_rotation), + Keybind(key_code=pyglet.window.key.H, name="shadow", callback_func=self.toggle_shadow), + Keybind(key_code=pyglet.window.key.F, name="face_normals", callback_func=self.toggle_face_normals), + Keybind(key_code=pyglet.window.key.V, name="vertex_normals", callback_func=self.toggle_vertex_normals), + Keybind(key_code=pyglet.window.key.W, name="world_frame", callback_func=self.toggle_world_frame), + Keybind(key_code=pyglet.window.key.L, name="link_frame", callback_func=self.toggle_link_frame), + Keybind(key_code=pyglet.window.key.D, name="wireframe", callback_func=self.toggle_wireframe), + Keybind(key_code=pyglet.window.key.C, name="camera_frustum", callback_func=self.toggle_camera_frustum), + Keybind(key_code=pyglet.window.key.P, name="reload_shader", callback_func=self.reload_shader), + Keybind(key_code=pyglet.window.key.F11, name="fullscreen_mode", callback_func=self.toggle_fullscreen), + )) self._display_instr = False + self._instr_texts: tuple[list[str], list[str]] = ([], []) + self._update_instr_texts() + + def _update_instr_texts(self): + self.instr_key_str = get_keycode_string(self.viewer._keybindings.get_by_name(INSTR_KEYBIND_NAME).key_code) + kb_texts = [ + f"{'[' + get_keycode_string(kb.key_code):>{7}}]: " + + kb.name.replace('_', ' ') for kb in self.viewer._keybindings.keybinds if kb.name != INSTR_KEYBIND_NAME + ] self._instr_texts = ( - ["> [i]: show keyboard instructions"], - ["< [i]: hide keyboard instructions"] + self.keybindings.as_instruction_texts( - padding=3, exclude=("toggle_keyboard_instructions")), + [f"> [{self.instr_key_str}]: show keyboard instructions"], + [f"< [{self.instr_key_str}]: hide keyboard instructions"] + kb_texts ) + + def toggle_instructions(self): + self._display_instr = not self._display_instr + self._update_instr_texts() + + def toggle_camera_rotation(self): + self.viewer.viewer_flags["rotate"] = not self.viewer.viewer_flags["rotate"] + if self.viewer.viewer_flags["rotate"]: + self.viewer._message_text = "Rotation On" + else: + self.viewer._message_text = "Rotation Off" + + def toggle_fullscreen(self): + self.viewer.viewer_flags["fullscreen"] = not self.viewer.viewer_flags["fullscreen"] + self.viewer.set_fullscreen(self.viewer.viewer_flags["fullscreen"]) + self.viewer.activate() + if self.viewer.viewer_flags["fullscreen"]: + self.viewer._message_text = "Fullscreen On" + else: + self.viewer._message_text = "Fullscreen Off" + + def toggle_shadow(self): + self.viewer.render_flags["shadows"] = not self.viewer.render_flags["shadows"] + if self.viewer.render_flags["shadows"]: + self.viewer._message_text = "Shadows On" + else: + self.viewer._message_text = "Shadows Off" + + def toggle_world_frame(self): + if not self.viewer.gs_context.world_frame_shown: + self.viewer.gs_context.on_world_frame() + self.viewer._message_text = "World Frame On" + else: + self.viewer.gs_context.off_world_frame() + self.viewer._message_text = "World Frame Off" + + def toggle_link_frame(self): + if not self.viewer.gs_context.link_frame_shown: + self.viewer.gs_context.on_link_frame() + self.viewer._message_text = "Link Frame On" + else: + self.viewer.gs_context.off_link_frame() + self.viewer._message_text = "Link Frame Off" + + def toggle_camera_frustum(self): + if not self.viewer.gs_context.camera_frustum_shown: + self.viewer.gs_context.on_camera_frustum() + self.viewer._message_text = "Camera Frustrum On" + else: + self.viewer.gs_context.off_camera_frustum() + self.viewer._message_text = "Camera Frustrum Off" + + def toggle_face_normals(self): + self.viewer.render_flags["face_normals"] = not self.viewer.render_flags["face_normals"] + if self.viewer.render_flags["face_normals"]: + self.viewer._message_text = "Face Normals On" + else: + self.viewer._message_text = "Face Normals Off" + + def toggle_vertex_normals(self): + self.viewer.render_flags["vertex_normals"] = not self.viewer.render_flags["vertex_normals"] + if self.viewer.render_flags["vertex_normals"]: + self.viewer._message_text = "Vert Normals On" + else: + self.viewer._message_text = "Vert Normals Off" + + def toggle_record_video(self): + if self.viewer.viewer_flags["record"]: + self.viewer.save_video() + self.viewer.set_caption(self.viewer.viewer_flags["window_title"]) + else: + # Importing moviepy is very slow and not used very often. Let's delay import. + from moviepy.video.io.ffmpeg_writer import FFMPEG_VideoWriter + + self.viewer._video_recorder = FFMPEG_VideoWriter( + filename=os.path.join(gs.utils.misc.get_cache_dir(), "tmp_video.mp4"), + fps=self.viewer.viewer_flags["refresh_rate"], + size=self.viewer.viewport_size, + ) + self.viewer.set_caption("{} (RECORDING)".format(self.viewer.viewer_flags["window_title"])) + self.viewer.viewer_flags["record"] = not self.viewer.viewer_flags["record"] + + def save_image(self): + self.viewer._save_image() + + def toggle_wireframe(self): + if self.viewer.render_flags["flip_wireframe"]: + self.viewer.render_flags["flip_wireframe"] = False + self.viewer.render_flags["all_wireframe"] = True + self.viewer.render_flags["all_solid"] = False + self.viewer._message_text = "All Wireframe" + elif self.viewer.render_flags["all_wireframe"]: + self.viewer.render_flags["flip_wireframe"] = False + self.viewer.render_flags["all_wireframe"] = False + self.viewer.render_flags["all_solid"] = True + self.viewer._message_text = "All Solid" + elif self.viewer.render_flags["all_solid"]: + self.viewer.render_flags["flip_wireframe"] = False + self.viewer.render_flags["all_wireframe"] = False + self.viewer.render_flags["all_solid"] = False + self.viewer._message_text = "Default Wireframe" + else: + self.viewer.render_flags["flip_wireframe"] = True + self.viewer.render_flags["all_wireframe"] = False + self.viewer.render_flags["all_solid"] = False + self.viewer._message_text = "Flip Wireframe" + + def reset_camera(self): + self.viewer._reset_view() + + def reload_shader(self): + self.viewer._renderer.reload_program() @override def on_key_press(self, symbol: int, modifiers: int) -> EVENT_HANDLE_STATE: if self.viewer is None: return None - # A causes the frame to rotate + # Reset message text and check for keybinding self.viewer._message_text = None - if symbol == self.keybindings.camera_rotation: - self.viewer.viewer_flags["rotate"] = not self.viewer.viewer_flags["rotate"] - if self.viewer.viewer_flags["rotate"]: - self.viewer._message_text = "Rotation On" - else: - self.viewer._message_text = "Rotation Off" - - # F11 toggles fullscreen - elif symbol == self.keybindings.fullscreen_mode: - self.viewer.viewer_flags["fullscreen"] = not self.viewer.viewer_flags["fullscreen"] - self.viewer.set_fullscreen(self.viewer.viewer_flags["fullscreen"]) - self.viewer.activate() - if self.viewer.viewer_flags["fullscreen"]: - self.viewer._message_text = "Fullscreen On" - else: - self.viewer._message_text = "Fullscreen Off" - - # H toggles shadows - elif symbol == self.keybindings.shadow: - self.viewer.render_flags["shadows"] = not self.viewer.render_flags["shadows"] - if self.viewer.render_flags["shadows"]: - self.viewer._message_text = "Shadows On" - else: - self.viewer._message_text = "Shadows Off" - - # W toggles world frame - elif symbol == self.keybindings.world_frame: - if not self.viewer.gs_context.world_frame_shown: - self.viewer.gs_context.on_world_frame() - self.viewer._message_text = "World Frame On" - else: - self.viewer.gs_context.off_world_frame() - self.viewer._message_text = "World Frame Off" - - # L toggles link frame - elif symbol == self.keybindings.link_frame: - if not self.viewer.gs_context.link_frame_shown: - self.viewer.gs_context.on_link_frame() - self.viewer._message_text = "Link Frame On" - else: - self.viewer.gs_context.off_link_frame() - self.viewer._message_text = "Link Frame Off" - - # C toggles camera frustum - elif symbol == self.keybindings.camera_frustum: - if not self.viewer.gs_context.camera_frustum_shown: - self.viewer.gs_context.on_camera_frustum() - self.viewer._message_text = "Camera Frustrum On" - else: - self.viewer.gs_context.off_camera_frustum() - self.viewer._message_text = "Camera Frustrum Off" - - # F toggles face normals - elif symbol == self.keybindings.face_normals: - self.viewer.render_flags["face_normals"] = not self.viewer.render_flags["face_normals"] - if self.viewer.render_flags["face_normals"]: - self.viewer._message_text = "Face Normals On" - else: - self.viewer._message_text = "Face Normals Off" - - # V toggles vertex normals - elif symbol == self.keybindings.vertex_normals: - self.viewer.render_flags["vertex_normals"] = not self.viewer.render_flags["vertex_normals"] - if self.viewer.render_flags["vertex_normals"]: - self.viewer._message_text = "Vert Normals On" - else: - self.viewer._message_text = "Vert Normals Off" - - # R starts recording frames - elif symbol == self.keybindings.record_video: - if self.viewer.viewer_flags["record"]: - self.viewer.save_video() - self.viewer.set_caption(self.viewer.viewer_flags["window_title"]) - else: - # Importing moviepy is very slow and not used very often. Let's delay import. - from moviepy.video.io.ffmpeg_writer import FFMPEG_VideoWriter - - self.viewer._video_recorder = FFMPEG_VideoWriter( - filename=os.path.join(gs.utils.misc.get_cache_dir(), "tmp_video.mp4"), - fps=self.viewer.viewer_flags["refresh_rate"], - size=self.viewer.viewport_size, - ) - self.viewer.set_caption("{} (RECORDING)".format(self.viewer.viewer_flags["window_title"])) - self.viewer.viewer_flags["record"] = not self.viewer.viewer_flags["record"] - - # S saves the current frame as an image - elif symbol == self.keybindings.save_image: - self.viewer._save_image() - - # D toggles through wireframe modes - elif symbol == self.keybindings.wireframe: - if self.viewer.render_flags["flip_wireframe"]: - self.viewer.render_flags["flip_wireframe"] = False - self.viewer.render_flags["all_wireframe"] = True - self.viewer.render_flags["all_solid"] = False - self.viewer._message_text = "All Wireframe" - elif self.viewer.render_flags["all_wireframe"]: - self.viewer.render_flags["flip_wireframe"] = False - self.viewer.render_flags["all_wireframe"] = False - self.viewer.render_flags["all_solid"] = True - self.viewer._message_text = "All Solid" - elif self.viewer.render_flags["all_solid"]: - self.viewer.render_flags["flip_wireframe"] = False - self.viewer.render_flags["all_wireframe"] = False - self.viewer.render_flags["all_solid"] = False - self.viewer._message_text = "Default Wireframe" - else: - self.viewer.render_flags["flip_wireframe"] = True - self.viewer.render_flags["all_wireframe"] = False - self.viewer.render_flags["all_solid"] = False - self.viewer._message_text = "Flip Wireframe" - - # Z resets the camera viewpoint - elif symbol == self.keybindings.reset_camera: - self.viewer._reset_view() - - # I toggles instruction display - elif symbol == self.keybindings.toggle_keyboard_instructions: - self._display_instr = not self._display_instr - - # P reloads shader program - elif symbol == self.keybindings.reload_shader: - self.viewer._renderer.reload_program() - + super().on_key_press(symbol, modifiers) + if self.viewer._message_text is not None: self.viewer._message_opac = 1.0 + self.viewer._ticks_till_fade diff --git a/genesis/ext/pyrender/viewer.py b/genesis/ext/pyrender/viewer.py index 65d3452734..ef277ad7a5 100644 --- a/genesis/ext/pyrender/viewer.py +++ b/genesis/ext/pyrender/viewer.py @@ -16,6 +16,8 @@ import genesis as gs +from .interaction.keybindings import KeyAction, Keybind, Keybindings + # Importing tkinter and creating a first context before importing pyglet is necessary to avoid later segfault on MacOS. # Note that destroying the window will cause segfault at exit. root = None @@ -79,17 +81,7 @@ class Viewer(pyglet.window.Window): viewer_flags : dict A set of flags for controlling the viewer's behavior. Described in the note below. - registered_keys : dict - A map from ASCII key characters to tuples containing: - - - A function to be called whenever the key is pressed, - whose first argument will be the viewer itself. - - (Optionally) A list of additional positional arguments - to be passed to the function. - - (Optionally) A dict of keyword arguments to be passed - to the function. - - kwargs : dict + **kwargs : dict Any keyword arguments left over will be interpreted as belonging to either the :attr:`.Viewer.render_flags` or :attr:`.Viewer.viewer_flags` dictionaries. Those flag sets will be updated appropriately. @@ -198,13 +190,12 @@ def __init__( viewport_size=None, render_flags=None, viewer_flags=None, - registered_keys=None, run_in_thread=False, auto_start=True, shadow=False, plane_reflection=False, env_separate_rigid=False, - plugin_options=gs.options.viewer_interactions.ViewerDefaultControls(), + plugin_options=None, **kwargs, ): ####################################################################### @@ -280,10 +271,8 @@ def __init__( elif key in self.viewer_flags: self._viewer_flags[key] = kwargs[key] - self._registered_keys = {} - if registered_keys is not None: - self._registered_keys = {ord(k.lower()): registered_keys[k] for k in registered_keys} - + self._keybindings: Keybindings = Keybindings() + self._held_keys: dict[tuple[int, int], bool] = {} # Track held keys: (symbol, modifiers) -> True ####################################################################### # Save internal settings ####################################################################### @@ -357,7 +346,10 @@ def __init__( # Note: context.scene is genesis.engine.scene.Scene # Note: context._scene is genesis.ext.pyrender.scene.Scene - # Setup viewer plugin + # Setup viewer interaction + if plugin_options is None: + plugin_options = gs.options.viewer_interactions.ViewerDefaultControls() + plugin_cls = VIEWER_PLUGIN_MAP.get(type(plugin_options)) if plugin_cls is None: gs.raise_exception( @@ -499,26 +491,18 @@ def viewer_flags(self): @viewer_flags.setter def viewer_flags(self, value): self._viewer_flags = value - - @property - def registered_keys(self): - """dict : Map from ASCII key character to a handler function. - - This is a map from ASCII key characters to tuples containing: - - - A function to be called whenever the key is pressed, - whose first argument will be the viewer itself. - - (Optionally) A list of additional positional arguments - to be passed to the function. - - (Optionally) A dict of keyword arguments to be passed - to the function. - + + def register_keybinds(self, keybinds: tuple[Keybind]) -> None: """ - return self._registered_keys + Add a key handler to call a function when the given ASCII key is pressed. - @registered_keys.setter - def registered_keys(self, value): - self._registered_keys = value + Parameters + ---------- + keybinds : tuple[Keybind] + A tuple of Keybind objects to register. + """ + for keybind in keybinds: + self._keybindings.register(keybind) def close(self): """Close the viewer. @@ -714,9 +698,7 @@ def on_draw(self): self.clear() self._render() - self.viewer_interaction.on_draw() - - self.interaction_plugin.on_draw() + self.interaction_plugin.on_draw() if self.viewer_flags["caption"] is not None: for caption in self.viewer_flags["caption"]: @@ -803,33 +785,29 @@ def on_mouse_scroll(self, x, y, dx, dy): ymag = max(c.ymag * sf, 1e-8 * c.ymag / c.xmag) c.xmag = xmag c.ymag = ymag + + def _call_keybind_callback(self, symbol: int, modifiers: int, action: KeyAction) -> None: + """Call registered keybind callbacks for the given key event.""" + keybind: Keybind = self._keybindings.get(symbol, modifiers, action) + if keybind is not None and keybind.callback_func is not None: + keybind.callback_func(*keybind.args, **keybind.kwargs) def on_key_press(self, symbol: int, modifiers: int) -> EVENT_HANDLE_STATE: """Record a key press.""" - # First, check for registered key callbacks - if symbol in self.registered_keys: - tup = self.registered_keys[symbol] - callback = None - args = [] - kwargs = {} - if not isinstance(tup, (list, tuple, np.ndarray)): - callback = tup - else: - callback = tup[0] - if len(tup) == 2: - args = tup[1] - if len(tup) == 3: - kwargs = tup[2] - callback(self, *args, **kwargs) - # Continue to plugins after registered callback - - # Delegate to viewer plugin + self._held_keys[(symbol, modifiers)] = True + self._call_keybind_callback(symbol, modifiers, KeyAction.PRESS) return self.interaction_plugin.on_key_press(symbol, modifiers) def on_key_release(self, symbol: int, modifiers: int) -> EVENT_HANDLE_STATE: """Record a key release.""" + self._held_keys.pop((symbol, modifiers), None) + self._call_keybind_callback(symbol, modifiers, KeyAction.RELEASE) return self.interaction_plugin.on_key_release(symbol, modifiers) + def on_deactivate(self) -> EVENT_HANDLE_STATE: + """Clear held keys when window loses focus.""" + self._held_keys.clear() + @staticmethod def _time_event(dt, self): """The timer callback.""" @@ -1161,7 +1139,7 @@ def start(self, auto_refresh=True): # The viewer can be considered as fully initialized at this point if not self._initialized_event.is_set(): self._initialized_event.set() - + if auto_refresh: while self._is_active: try: @@ -1213,6 +1191,9 @@ def refresh(self): self.flip() def update_on_sim_step(self): + # Call HOLD callbacks for all currently held keys + for (symbol, modifiers) in list(self._held_keys.keys()): + self._call_keybind_callback(symbol, modifiers, KeyAction.HOLD) self.interaction_plugin.update_on_sim_step() def _compute_initial_camera_pose(self): diff --git a/genesis/vis/viewer.py b/genesis/vis/viewer.py index 02613bdf19..ce2da98c63 100644 --- a/genesis/vis/viewer.py +++ b/genesis/vis/viewer.py @@ -259,6 +259,17 @@ def update_following(self): else: self.set_camera_pose(pos=camera_pos, lookat=self._follow_lookat) + def register_keybinds(self, keybinds: tuple[pyrender.interaction.keybindings.Keybind]) -> None: + """ + Register a callback function to be called when a key is pressed. + + Parameters + ---------- + keybinds : tuple[gs.ext.pyrender.interaction.keybindings.Keybind] + The Keybind objects to register. See Keybind documentation for usage. + """ + self._pyrender_viewer.register_keybinds(keybinds) + # ------------------------------------------------------------------------------------ # ----------------------------------- properties ------------------------------------- # ------------------------------------------------------------------------------------ From 69e569c8b9e72be86273061272658f147765d04c Mon Sep 17 00:00:00 2001 From: Trinity Chung Date: Sun, 4 Jan 2026 23:21:08 -0500 Subject: [PATCH 3/5] update examples and refactor out help text --- .github/workflows/examples.yml | 1 - examples/IPC_Solver/ipc_arm_cloth.py | 267 +++++++++--------- examples/drone/interactive_drone.py | 183 ++++++------ examples/interactive_viewer/mouse_spring.py | 2 +- examples/sensors/lidar_teleop.py | 170 +++++------ genesis/ext/pyrender/interaction/__init__.py | 11 +- .../ext/pyrender/interaction/keybindings.py | 35 ++- .../pyrender/interaction/plugins/__init__.py | 10 +- ...viewer_controls.py => default_controls.py} | 164 ++++------- .../pyrender/interaction/plugins/help_text.py | 104 +++++++ .../interaction/plugins/mesh_selector.py | 9 +- .../{mouse_interaction.py => mouse_spring.py} | 11 +- .../{base_interaction.py => viewer_plugin.py} | 8 +- genesis/ext/pyrender/viewer.py | 61 ++-- ...ewer_interactions.py => viewer_plugins.py} | 30 +- genesis/options/vis.py | 6 +- genesis/vis/viewer.py | 32 ++- 17 files changed, 583 insertions(+), 521 deletions(-) rename genesis/ext/pyrender/interaction/plugins/{viewer_controls.py => default_controls.py} (50%) create mode 100644 genesis/ext/pyrender/interaction/plugins/help_text.py rename genesis/ext/pyrender/interaction/plugins/{mouse_interaction.py => mouse_spring.py} (96%) rename genesis/ext/pyrender/interaction/{base_interaction.py => viewer_plugin.py} (93%) rename genesis/options/{viewer_interactions.py => viewer_plugins.py} (73%) diff --git a/.github/workflows/examples.yml b/.github/workflows/examples.yml index 14df548019..1d58431e64 100644 --- a/.github/workflows/examples.yml +++ b/.github/workflows/examples.yml @@ -51,7 +51,6 @@ jobs: run: | pip install --upgrade pip setuptools wheel pip install torch --index-url https://download.pytorch.org/whl/cpu - pip install -e '.[dev]' pynput - name: Get gstaichi version id: gstaichi_version diff --git a/examples/IPC_Solver/ipc_arm_cloth.py b/examples/IPC_Solver/ipc_arm_cloth.py index 20df431765..a8fe239940 100644 --- a/examples/IPC_Solver/ipc_arm_cloth.py +++ b/examples/IPC_Solver/ipc_arm_cloth.py @@ -14,45 +14,21 @@ ; - Roll Right (Rotate around X) u - Reset Scene space - Press to close gripper, release to open gripper -esc - Quit + +Plus all default viewer controls (press 'i' to see them) """ -import threading import argparse -import numpy as np import csv import os from datetime import datetime -from pynput import keyboard -from scipy.spatial.transform import Rotation as R + +import numpy as np from huggingface_hub import snapshot_download +from scipy.spatial.transform import Rotation as R import genesis as gs - - -class KeyboardDevice: - def __init__(self): - self.pressed_keys = set() - self.lock = threading.Lock() - self.listener = keyboard.Listener(on_press=self.on_press, on_release=self.on_release) - - def start(self): - self.listener.start() - - def stop(self): - self.listener.stop() - self.listener.join() - - def on_press(self, key: keyboard.Key): - with self.lock: - self.pressed_keys.add(key) - - def on_release(self, key: keyboard.Key): - with self.lock: - self.pressed_keys.discard(key) - - def get_cmd(self): - return self.pressed_keys +from genesis.ext.pyrender.interaction.keybindings import KeyAction, Keybind def build_scene(use_ipc=False, show_viewer=False, enable_ipc_gui=False): @@ -171,14 +147,14 @@ def build_scene(use_ipc=False, show_viewer=False, enable_ipc_gui=False): return scene, entities -def run_sim(scene, entities, clients, mode="interactive", trajectory_file=None): +def run_sim(scene, entities, mode="interactive", trajectory_file=None): robot = entities["robot"] target_entity = entities["target"] robot_init_pos = np.array([0.5, 0, 0.55]) robot_init_R = R.from_euler("y", np.pi) - target_pos = robot_init_pos.copy() - target_R = robot_init_R + target_pos = [robot_init_pos.copy()] # Use list for mutability in closures + target_R = [robot_init_R] # Use list for mutability in closures n_dofs = robot.n_dofs motors_dof = np.arange(n_dofs - 2) @@ -189,18 +165,97 @@ def run_sim(scene, entities, clients, mode="interactive", trajectory_file=None): trajectory = [] recording = mode == "record" + # Gripper state (use list for mutability in closures) + gripper_closed = [False] + + # Control parameters + dpos = 0.002 + drot = 0.01 + def reset_scene(): - nonlocal target_pos, target_R - target_pos = robot_init_pos.copy() - target_R = robot_init_R - target_quat = target_R.as_quat(scalar_first=True) - target_entity.set_qpos(np.concatenate([target_pos, target_quat])) - q = robot.inverse_kinematics(link=ee_link, pos=target_pos, quat=target_quat) + target_pos[0] = robot_init_pos.copy() + target_R[0] = robot_init_R + target_quat = target_R[0].as_quat(scalar_first=True) + target_entity.set_qpos(np.concatenate([target_pos[0], target_quat])) + q = robot.inverse_kinematics(link=ee_link, pos=target_pos[0], quat=target_quat) robot.set_qpos(q[:-2], motors_dof) # entities["cube"].set_pos((random.uniform(0.2, 0.4), random.uniform(-0.2, 0.2), 0.05)) # entities["cube"].set_quat(R.from_euler("z", random.uniform(0, np.pi * 2)).as_quat(scalar_first=True)) + # Define movement callbacks + def move_forward(): + target_pos[0][0] -= dpos + + def move_backward(): + target_pos[0][0] += dpos + + def move_left(): + target_pos[0][1] -= dpos + + def move_right(): + target_pos[0][1] += dpos + + def move_up(): + target_pos[0][2] += dpos + + def move_down(): + target_pos[0][2] -= dpos + + def yaw_left(): + target_R[0] = R.from_euler("z", drot) * target_R[0] + + def yaw_right(): + target_R[0] = R.from_euler("z", -drot) * target_R[0] + + def pitch_up(): + target_R[0] = R.from_euler("y", drot) * target_R[0] + + def pitch_down(): + target_R[0] = R.from_euler("y", -drot) * target_R[0] + + def roll_left(): + target_R[0] = R.from_euler("x", drot) * target_R[0] + + def roll_right(): + target_R[0] = R.from_euler("x", -drot) * target_R[0] + + def close_gripper(): + gripper_closed[0] = True + + def open_gripper(): + gripper_closed[0] = False + + # Register keybindings (only for interactive and record modes) + if mode in ["interactive", "record"]: + from pyglet.window import key + + scene.viewer.register_keybinds( + ( + Keybind(key_code=key.UP, key_action=KeyAction.HOLD, name="move_forward", callback_func=move_forward), + Keybind( + key_code=key.DOWN, key_action=KeyAction.HOLD, name="move_backward", callback_func=move_backward + ), + Keybind(key_code=key.LEFT, key_action=KeyAction.HOLD, name="move_left", callback_func=move_left), + Keybind(key_code=key.RIGHT, key_action=KeyAction.HOLD, name="move_right", callback_func=move_right), + Keybind(key_code=key.N, key_action=KeyAction.HOLD, name="move_up", callback_func=move_up), + Keybind(key_code=key.M, key_action=KeyAction.HOLD, name="move_down", callback_func=move_down), + Keybind(key_code=key.J, key_action=KeyAction.HOLD, name="yaw_left", callback_func=yaw_left), + Keybind(key_code=key.K, key_action=KeyAction.HOLD, name="yaw_right", callback_func=yaw_right), + Keybind(key_code=key.I, key_action=KeyAction.HOLD, name="pitch_up", callback_func=pitch_up), + Keybind(key_code=key.O, key_action=KeyAction.HOLD, name="pitch_down", callback_func=pitch_down), + Keybind(key_code=key.L, key_action=KeyAction.HOLD, name="roll_left", callback_func=roll_left), + Keybind(key_code=key.SEMICOLON, key_action=KeyAction.HOLD, name="roll_right", callback_func=roll_right), + Keybind(key_code=key.U, key_action=KeyAction.HOLD, name="reset_scene", callback_func=reset_scene), + Keybind( + key_code=key.SPACE, key_action=KeyAction.PRESS, name="close_gripper", callback_func=close_gripper + ), + Keybind( + key_code=key.SPACE, key_action=KeyAction.RELEASE, name="open_gripper", callback_func=open_gripper + ), + ) + ) + # Load trajectory if in playback mode if mode == "playback": if not trajectory_file or not os.path.exists(trajectory_file): @@ -232,7 +287,7 @@ def reset_scene(): print(f"\nMode: {mode.upper()}") if mode == "record": - print("Recording trajectory... Press ESC to stop and save.") + print("Recording trajectory...") elif mode == "playback": print("Playing back trajectory...") @@ -248,99 +303,57 @@ def reset_scene(): print("l/;\t- Roll Left/Right (Rotate around X axis)") print("u\t- Reset Scene") print("space\t- Press to close gripper, release to open gripper") - print("esc\t- Quit") + if mode in ["interactive", "record"]: + print("\nPlus all default viewer controls (press 'i' to see them)") # reset scene before starting teleoperation reset_scene() # start teleoperation or playback - stop = False step_count = 0 - while not stop: - if mode == "playback": - # Playback mode: replay recorded trajectory - if step_count < len(trajectory): - step_data = trajectory[step_count] - target_pos = step_data["target_pos"] - target_R = R.from_quat(step_data["target_quat"]) - is_close_gripper = step_data["gripper_closed"] - step_count += 1 - print(f"\rPlayback step: {step_count}/{len(trajectory)}", end="") - # Check if user wants to stop playback - pressed_keys = clients["keyboard"].pressed_keys.copy() - stop = keyboard.Key.esc in pressed_keys + try: + while True: + if mode == "playback": + # Playback mode: replay recorded trajectory + if step_count < len(trajectory): + step_data = trajectory[step_count] + target_pos[0] = step_data["target_pos"] + target_R[0] = R.from_quat(step_data["target_quat"]) + gripper_closed[0] = step_data["gripper_closed"] + step_count += 1 + print(f"\rPlayback step: {step_count}/{len(trajectory)}", end="") + else: + print("\nPlayback finished!") + break else: - print("\nPlayback finished!") - break - else: - # Interactive or recording mode - pressed_keys = clients["keyboard"].pressed_keys.copy() - - # reset scene: - reset_flag = False - reset_flag |= keyboard.KeyCode.from_char("u") in pressed_keys - if reset_flag: - reset_scene() - - # stop teleoperation - stop = keyboard.Key.esc in pressed_keys - - # get ee target pose - is_close_gripper = False - dpos = 0.002 - drot = 0.01 - for key in pressed_keys: - if key == keyboard.Key.up: - target_pos[0] -= dpos - elif key == keyboard.Key.down: - target_pos[0] += dpos - elif key == keyboard.Key.right: - target_pos[1] += dpos - elif key == keyboard.Key.left: - target_pos[1] -= dpos - elif key == keyboard.KeyCode.from_char("n"): - target_pos[2] += dpos - elif key == keyboard.KeyCode.from_char("m"): - target_pos[2] -= dpos - elif key == keyboard.KeyCode.from_char("j"): - target_R = R.from_euler("z", drot) * target_R - elif key == keyboard.KeyCode.from_char("k"): - target_R = R.from_euler("z", -drot) * target_R - elif key == keyboard.KeyCode.from_char("i"): - target_R = R.from_euler("y", drot) * target_R - elif key == keyboard.KeyCode.from_char("o"): - target_R = R.from_euler("y", -drot) * target_R - elif key == keyboard.KeyCode.from_char("l"): - target_R = R.from_euler("x", drot) * target_R - elif key == keyboard.KeyCode.from_char(";"): - target_R = R.from_euler("x", -drot) * target_R - elif key == keyboard.Key.space: - is_close_gripper = True - - # Record current state if recording - if recording: - step_data = { - "target_pos": target_pos.copy(), - "target_quat": target_R.as_quat(), # x,y,z,w format - "gripper_closed": is_close_gripper, - "step": step_count, - } - trajectory.append(step_data) - - # control arm - target_quat = target_R.as_quat(scalar_first=True) - target_entity.set_qpos(np.concatenate([target_pos, target_quat])) - q, err = robot.inverse_kinematics(link=ee_link, pos=target_pos, quat=target_quat, return_error=True) - robot.control_dofs_position(q[:-2], motors_dof) - # control gripper - if is_close_gripper: - robot.control_dofs_force(np.array([-1.0, -1.0]), fingers_dof) - else: - robot.control_dofs_force(np.array([1.0, 1.0]), fingers_dof) + # Interactive or recording mode + # Movement is handled by keybinding callbacks + # Record current state if recording + if recording: + step_data = { + "target_pos": target_pos[0].copy(), + "target_quat": target_R[0].as_quat(), # x,y,z,w format + "gripper_closed": gripper_closed[0], + "step": step_count, + } + trajectory.append(step_data) + + # control arm + target_quat = target_R[0].as_quat(scalar_first=True) + target_entity.set_qpos(np.concatenate([target_pos[0], target_quat])) + q, err = robot.inverse_kinematics(link=ee_link, pos=target_pos[0], quat=target_quat, return_error=True) + robot.control_dofs_position(q[:-2], motors_dof) + # control gripper + if gripper_closed[0]: + robot.control_dofs_force(np.array([-1.0, -1.0]), fingers_dof) + else: + robot.control_dofs_force(np.array([1.0, 1.0]), fingers_dof) - scene.step() - step_count += 1 + scene.step() + step_count += 1 + except KeyboardInterrupt: + gs.logger.info("Simulation interrupted, exiting.") # Save trajectory if recording if recording and len(trajectory) > 0: @@ -436,12 +449,8 @@ def main(): elif not os.path.isabs(trajectory_file): trajectory_file = os.path.join(traj_dir, os.path.basename(trajectory_file)) - clients = dict() - clients["keyboard"] = KeyboardDevice() - clients["keyboard"].start() - scene, entities = build_scene(use_ipc=args.ipc, show_viewer=args.vis, enable_ipc_gui=False) - run_sim(scene, entities, clients, mode=args.mode, trajectory_file=trajectory_file) + run_sim(scene, entities, mode=args.mode, trajectory_file=trajectory_file) if __name__ == "__main__": diff --git a/examples/drone/interactive_drone.py b/examples/drone/interactive_drone.py index cdc6713d6e..f18f353f10 100644 --- a/examples/drone/interactive_drone.py +++ b/examples/drone/interactive_drone.py @@ -1,11 +1,9 @@ import os -import time -import threading -from pynput import keyboard import numpy as np import genesis as gs +from genesis.ext.pyrender.interaction.keybindings import KeyAction, Keybind class DroneController: @@ -13,100 +11,57 @@ def __init__(self): self.thrust = 14475.8 # Base hover RPM - constant hover self.rotation_delta = 200.0 # Differential RPM for rotation self.thrust_delta = 10.0 # Amount to change thrust by when accelerating/decelerating - self.running = True self.rpms = [self.thrust] * 4 - self.pressed_keys = set() - - def on_press(self, key): - try: - if key == keyboard.Key.esc: - self.running = False - return False - self.pressed_keys.add(key) - print(f"Key pressed: {key}") - except AttributeError: - pass - - def on_release(self, key): - try: - self.pressed_keys.discard(key) - except KeyError: - pass def update_thrust(self): - # Store previous RPMs for debugging - prev_rpms = self.rpms.copy() - # Reset RPMs to hover thrust self.rpms = [self.thrust] * 4 + return self.rpms - # Acceleration (Spacebar) - All rotors spin faster - if keyboard.Key.space in self.pressed_keys: - self.thrust += self.thrust_delta - self.rpms = [self.thrust] * 4 - print("Accelerating") - - # Deceleration (Left Shift) - All rotors spin slower - if keyboard.Key.shift in self.pressed_keys: - self.thrust -= self.thrust_delta - self.rpms = [self.thrust] * 4 - print("Decelerating") - - # Forward (North) - Front rotors spin faster - if keyboard.Key.up in self.pressed_keys: - self.rpms[0] += self.rotation_delta # Front left - self.rpms[1] += self.rotation_delta # Front right - self.rpms[2] -= self.rotation_delta # Back left - self.rpms[3] -= self.rotation_delta # Back right - print("Moving Forward") - - # Backward (South) - Back rotors spin faster - if keyboard.Key.down in self.pressed_keys: - self.rpms[0] -= self.rotation_delta # Front left - self.rpms[1] -= self.rotation_delta # Front right - self.rpms[2] += self.rotation_delta # Back left - self.rpms[3] += self.rotation_delta # Back right - print("Moving Backward") - - # Left (West) - Left rotors spin faster - if keyboard.Key.left in self.pressed_keys: - self.rpms[0] -= self.rotation_delta # Front left - self.rpms[2] -= self.rotation_delta # Back left - self.rpms[1] += self.rotation_delta # Front right - self.rpms[3] += self.rotation_delta # Back right - print("Moving Left") - - # Right (East) - Right rotors spin faster - if keyboard.Key.right in self.pressed_keys: - self.rpms[0] += self.rotation_delta # Front left - self.rpms[2] += self.rotation_delta # Back left - self.rpms[1] -= self.rotation_delta # Front right - self.rpms[3] -= self.rotation_delta # Back right - print("Moving Right") - + def move_forward(self): + """Front rotors spin faster""" + self.rpms[0] += self.rotation_delta # Front left + self.rpms[1] += self.rotation_delta # Front right + self.rpms[2] -= self.rotation_delta # Back left + self.rpms[3] -= self.rotation_delta # Back right self.rpms = np.clip(self.rpms, 0, 25000) - # Debug print if any RPMs changed - if not np.array_equal(prev_rpms, self.rpms): - print(f"RPMs changed from {prev_rpms} to {self.rpms}") - - return self.rpms - + def move_backward(self): + """Back rotors spin faster""" + self.rpms[0] -= self.rotation_delta # Front left + self.rpms[1] -= self.rotation_delta # Front right + self.rpms[2] += self.rotation_delta # Back left + self.rpms[3] += self.rotation_delta # Back right + self.rpms = np.clip(self.rpms, 0, 25000) -def run_sim(scene, drone, controller): - while controller.running: - # Update drone with current RPMs - rpms = controller.update_thrust() - drone.set_propellels_rpm(rpms) + def move_left(self): + """Left rotors spin faster""" + self.rpms[0] -= self.rotation_delta # Front left + self.rpms[2] -= self.rotation_delta # Back left + self.rpms[1] += self.rotation_delta # Front right + self.rpms[3] += self.rotation_delta # Back right + self.rpms = np.clip(self.rpms, 0, 25000) - # Update physics - scene.step(refresh_visualizer=False) + def move_right(self): + """Right rotors spin faster""" + print("move right") + self.rpms[0] += self.rotation_delta # Front left + self.rpms[2] += self.rotation_delta # Back left + self.rpms[1] -= self.rotation_delta # Front right + self.rpms[3] -= self.rotation_delta # Back right + self.rpms = np.clip(self.rpms, 0, 25000) - # Limit simulation rate - time.sleep(1.0 / scene.viewer.max_FPS) + def accelerate(self): + """All rotors spin faster""" + self.thrust += self.thrust_delta + self.rpms = [self.thrust] * 4 + self.rpms = np.clip(self.rpms, 0, 25000) - if "PYTEST_VERSION" in os.environ: - break + def decelerate(self): + """All rotors spin slower""" + self.thrust -= self.thrust_delta + self.rpms = [self.thrust] * 4 + self.rpms = np.clip(self.rpms, 0, 25000) def main(): @@ -145,14 +100,36 @@ def main(): # Initialize controller controller = DroneController() - # Start keyboard listener. - # Note that instantiating the listener after building the scene causes segfault on MacOS. - listener = keyboard.Listener(on_press=controller.on_press, on_release=controller.on_release) - listener.start() - # Build scene scene.build() + # Register keybindings + from pyglet.window import key + + scene.viewer.register_keybinds( + ( + Keybind( + key_code=key.UP, key_action=KeyAction.HOLD, name="move_forward", callback_func=controller.move_forward + ), + Keybind( + key_code=key.DOWN, + key_action=KeyAction.HOLD, + name="move_backward", + callback_func=controller.move_backward, + ), + Keybind(key_code=key.LEFT, key_action=KeyAction.HOLD, name="move_left", callback_func=controller.move_left), + Keybind( + key_code=key.RIGHT, key_action=KeyAction.HOLD, name="move_right", callback_func=controller.move_right + ), + Keybind( + key_code=key.SPACE, key_action=KeyAction.HOLD, name="accelerate", callback_func=controller.accelerate + ), + Keybind( + key_code=key.LSHIFT, key_action=KeyAction.HOLD, name="decelerate", callback_func=controller.decelerate + ), + ) + ) + # Print control instructions print("\nDrone Controls:") print("↑ - Move Forward (North)") @@ -161,19 +138,25 @@ def main(): print("→ - Move Right (East)") print("space - Increase RPM") print("shift - Decrease RPM") - print("ESC - Quit\n") + print("\nPlus all default viewer controls (press 'i' to see them)\n") print("Initial hover RPM:", controller.thrust) - # Run simulation in another thread - threading.Thread(target=run_sim, args=(scene, drone, controller)).start() - if "PYTEST_VERSION" not in os.environ: - scene.viewer.run() - + # Run simulation try: - listener.stop() - except NotImplementedError: - # Dummy backend does not implement stop - pass + while True: + # Update drone with current RPMs + rpms = controller.update_thrust() + drone.set_propellels_rpm(rpms) + + # Update physics + scene.step() + + if "PYTEST_VERSION" in os.environ: + break + except KeyboardInterrupt: + gs.logger.info("Simulation interrupted, exiting.") + finally: + gs.logger.info("Simulation finished.") if __name__ == "__main__": diff --git a/examples/interactive_viewer/mouse_spring.py b/examples/interactive_viewer/mouse_spring.py index 1421bbcaab..d88647e131 100644 --- a/examples/interactive_viewer/mouse_spring.py +++ b/examples/interactive_viewer/mouse_spring.py @@ -11,7 +11,7 @@ camera_pos=(3.5, 0.0, 2.5), camera_lookat=(0.0, 0.0, 0.5), camera_fov=40, - viewer_plugin=gs.options.viewer_interactions.MouseSpringViewerPlugin(), + viewer_plugin=gs.options.viewer_plugins.MouseSpringPlugin(), ), profiling_options=gs.options.ProfilingOptions( show_FPS=False, diff --git a/examples/sensors/lidar_teleop.py b/examples/sensors/lidar_teleop.py index 5336114018..6f5f0c8623 100644 --- a/examples/sensors/lidar_teleop.py +++ b/examples/sensors/lidar_teleop.py @@ -1,28 +1,16 @@ import argparse import os -import threading import numpy as np import genesis as gs +from genesis.ext.pyrender.interaction.keybindings import KeyAction, Keybind from genesis.utils.geom import euler_to_quat -IS_PYNPUT_AVAILABLE = False -try: - from pynput import keyboard - - IS_PYNPUT_AVAILABLE = True -except ImportError: - pass - # Position and angle increments for keyboard teleop control KEY_DPOS = 0.1 KEY_DANGLE = 0.1 -# Movement when no keyboard control is available -MOVE_RADIUS = 1.0 -MOVE_RATE = 1.0 / 100.0 - # Number of obstacles to create in a ring around the robot NUM_CYLINDERS = 8 NUM_BOXES = 6 @@ -30,35 +18,6 @@ BOX_RING_RADIUS = 5.0 -class KeyboardDevice: - def __init__(self): - self.pressed_keys = set() - self.lock = threading.Lock() - self.listener = keyboard.Listener(on_press=self.on_press, on_release=self.on_release) - - def start(self): - self.listener.start() - - def stop(self): - try: - self.listener.stop() - except NotImplementedError: - # Dummy backend does not implement stop - pass - self.listener.join() - - def on_press(self, key: "keyboard.Key"): - with self.lock: - self.pressed_keys.add(key) - - def on_release(self, key: "keyboard.Key"): - with self.lock: - self.pressed_keys.discard(key) - - def get_cmd(self): - return self.pressed_keys - - def main(): parser = argparse.ArgumentParser(description="Genesis LiDAR/Depth Camera Visualization with Keyboard Teleop") parser.add_argument("-B", "--n_envs", type=int, default=0, help="Number of environments to replicate") @@ -69,12 +28,6 @@ def main(): ) args = parser.parse_args() - if IS_PYNPUT_AVAILABLE: - kb = KeyboardDevice() - kb.start() - else: - print("Keyboard teleop is disabled since pynput is not installed. To install, run `pip install pynput`.") - gs.init(backend=gs.cpu if args.cpu else gs.gpu, precision="32", logging_level="info") scene = gs.Scene( @@ -169,17 +122,7 @@ def main(): scene.build(n_envs=args.n_envs) - if IS_PYNPUT_AVAILABLE: - # Avoid using same keys as interactive viewer keyboard controls - print("Keyboard Controls:") - print("[↑/↓/←/→]: Move XY") - print("[j/k]: Down/Up") - print("[n/m]: Roll CCW/CW") - print("[,/.]: Pitch Up/Down") - print("[o/p]: Yaw CCW/CW") - print("[\\]: Reset") - print("[esc]: Quit") - + # Initialize pose state init_pos = np.array([0.0, 0.0, 0.35], dtype=np.float32) init_euler = np.array([0.0, 0.0, 0.0], dtype=np.float32) @@ -193,48 +136,81 @@ def apply_pose_to_all_envs(pos_np: np.ndarray, quat_np: np.ndarray): robot.set_pos(pos_np) robot.set_quat(quat_np) + # Define control callbacks + def reset_pose(): + target_pos[:] = init_pos + target_euler[:] = init_euler + + def move_forward(): + target_pos[0] += KEY_DPOS + + def move_backward(): + target_pos[0] -= KEY_DPOS + + def move_right(): + target_pos[1] -= KEY_DPOS + + def move_left(): + target_pos[1] += KEY_DPOS + + def move_down(): + target_pos[2] -= KEY_DPOS + + def move_up(): + target_pos[2] += KEY_DPOS + + def roll_ccw(): + target_euler[0] += KEY_DANGLE + + def roll_cw(): + target_euler[0] -= KEY_DANGLE + + def pitch_up(): + target_euler[1] += KEY_DANGLE + + def pitch_down(): + target_euler[1] -= KEY_DANGLE + + def yaw_ccw(): + target_euler[2] += KEY_DANGLE + + def yaw_cw(): + target_euler[2] -= KEY_DANGLE + + # Register keybindings + from pyglet.window import key + + scene.viewer.register_keybinds( + ( + Keybind(key_code=key.UP, key_action=KeyAction.HOLD, name="move_forward", callback_func=move_forward), + Keybind(key_code=key.DOWN, key_action=KeyAction.HOLD, name="move_backward", callback_func=move_backward), + Keybind(key_code=key.RIGHT, key_action=KeyAction.HOLD, name="move_right", callback_func=move_right), + Keybind(key_code=key.LEFT, key_action=KeyAction.HOLD, name="move_left", callback_func=move_left), + Keybind(key_code=key.J, key_action=KeyAction.HOLD, name="move_down", callback_func=move_down), + Keybind(key_code=key.K, key_action=KeyAction.HOLD, name="move_up", callback_func=move_up), + Keybind(key_code=key.N, key_action=KeyAction.HOLD, name="roll_ccw", callback_func=roll_ccw), + Keybind(key_code=key.M, key_action=KeyAction.HOLD, name="roll_cw", callback_func=roll_cw), + Keybind(key_code=key.COMMA, key_action=KeyAction.HOLD, name="pitch_up", callback_func=pitch_up), + Keybind(key_code=key.PERIOD, key_action=KeyAction.HOLD, name="pitch_down", callback_func=pitch_down), + Keybind(key_code=key.O, key_action=KeyAction.HOLD, name="yaw_ccw", callback_func=yaw_ccw), + Keybind(key_code=key.P, key_action=KeyAction.HOLD, name="yaw_cw", callback_func=yaw_cw), + Keybind(key_code=key.BACKSLASH, key_action=KeyAction.HOLD, name="reset", callback_func=reset_pose), + ) + ) + + # Print controls + print("Keyboard Controls:") + print("[↑/↓/←/→]: Move XY") + print("[j/k]: Down/Up") + print("[n/m]: Roll CCW/CW") + print("[,/.]: Pitch Up/Down") + print("[o/p]: Yaw CCW/CW") + print("[\\]: Reset") + apply_pose_to_all_envs(target_pos, euler_to_quat(target_euler)) try: while True: - if IS_PYNPUT_AVAILABLE: - pressed = kb.pressed_keys.copy() - if keyboard.Key.esc in pressed: - break - if keyboard.KeyCode.from_char("\\") in pressed: - target_pos[:] = init_pos - target_euler[:] = init_euler - - if keyboard.Key.up in pressed: - target_pos[0] += KEY_DPOS - if keyboard.Key.down in pressed: - target_pos[0] -= KEY_DPOS - if keyboard.Key.right in pressed: - target_pos[1] -= KEY_DPOS - if keyboard.Key.left in pressed: - target_pos[1] += KEY_DPOS - if keyboard.KeyCode.from_char("j") in pressed: - target_pos[2] -= KEY_DPOS - if keyboard.KeyCode.from_char("k") in pressed: - target_pos[2] += KEY_DPOS - - if keyboard.KeyCode.from_char("n") in pressed: - target_euler[0] += KEY_DANGLE # roll CCW around +X - if keyboard.KeyCode.from_char("m") in pressed: - target_euler[0] -= KEY_DANGLE # roll CW around +X - if keyboard.KeyCode.from_char(",") in pressed: - target_euler[1] += KEY_DANGLE # pitch up around +Y - if keyboard.KeyCode.from_char(".") in pressed: - target_euler[1] -= KEY_DANGLE # pitch down around +Y - if keyboard.KeyCode.from_char("o") in pressed: - target_euler[2] += KEY_DANGLE # yaw CCW around +Z - if keyboard.KeyCode.from_char("p") in pressed: - target_euler[2] -= KEY_DANGLE # yaw CW around +Z - else: - # move in a circle if no keyboard control - target_pos[0] = MOVE_RADIUS * np.cos(scene.t * MOVE_RATE) - target_pos[1] = MOVE_RADIUS * np.sin(scene.t * MOVE_RATE) - apply_pose_to_all_envs(target_pos, euler_to_quat(target_euler)) scene.step() diff --git a/genesis/ext/pyrender/interaction/__init__.py b/genesis/ext/pyrender/interaction/__init__.py index 2bb9b27154..24ffef5aca 100644 --- a/genesis/ext/pyrender/interaction/__init__.py +++ b/genesis/ext/pyrender/interaction/__init__.py @@ -1,10 +1,11 @@ -from .base_interaction import ( +from .keybindings import KeyAction, Keybind, Keybindings, get_keycode_string +from .plugins.default_controls import DefaultControls +from .plugins.mesh_selector import MeshPointSelectorPlugin +from .plugins.mouse_spring import MouseSpringPlugin +from .viewer_plugin import ( EVENT_HANDLE_STATE, EVENT_HANDLED, VIEWER_PLUGIN_MAP, - BaseViewerInteraction, + ViewerPlugin, register_viewer_plugin, ) -from .plugins.mesh_selector import MeshPointSelectorPlugin -from .plugins.mouse_interaction import MouseSpringViewerPlugin -from .plugins.viewer_controls import ViewerDefaultControls diff --git a/genesis/ext/pyrender/interaction/keybindings.py b/genesis/ext/pyrender/interaction/keybindings.py index 308bb11031..8bb8ce4e7e 100644 --- a/genesis/ext/pyrender/interaction/keybindings.py +++ b/genesis/ext/pyrender/interaction/keybindings.py @@ -1,6 +1,17 @@ from enum import IntEnum from typing import Callable, NamedTuple +KEY_STRING_TO_CHAR = { + "backslash": "\\", + "slash": "/", + "comma": ",", + "period": ".", + "bracketleft": "[", + "bracketright": "]", + "semicolon": ";", + "minus": "-", + "equal": "=", +} class KeyAction(IntEnum): PRESS = 0 @@ -18,7 +29,7 @@ def get_key_hash(key_code: int, modifiers: int | None, action: KeyAction) -> int The modifier keys pressed. action : KeyAction The type of key action (press, hold, release). - + Returns ------- int @@ -28,7 +39,10 @@ def get_key_hash(key_code: int, modifiers: int | None, action: KeyAction) -> int def get_keycode_string(key_code: int) -> str: from pyglet.window.key import symbol_string - return symbol_string(key_code).lower() + symbol = symbol_string(key_code).lower() + if symbol in KEY_STRING_TO_CHAR: + return KEY_STRING_TO_CHAR[symbol] + return symbol class Keybind(NamedTuple): key_code: int @@ -56,6 +70,23 @@ def register(self, keybind: Keybind) -> None: ) self._keybinds_map[keybind.key_hash()] = keybind + def rebind(self, name: str, new_key_code: int | None, new_modifiers: int | None = None, new_key_action: KeyAction | None = None) -> None: + for kb in self._keybinds_map.values(): + if kb.name == name: + new_kb = Keybind( + name=kb.name, + key_code=new_key_code or kb.key_code, + key_action=new_key_action or kb.key_action, + modifiers=new_modifiers or kb.modifiers, + callback_func=kb.callback_func, + args=kb.args, + kwargs=kb.kwargs, + ) + del self._keybinds_map[kb.key_hash()] + self._keybinds_map[new_kb.key_hash()] = new_kb + return + raise ValueError(f"No keybind found with name '{name}'.") + def get(self, key: int, modifiers: int, key_action: KeyAction) -> Keybind | None: key_hash = get_key_hash(key, modifiers, key_action) if key_hash in self._keybinds_map: diff --git a/genesis/ext/pyrender/interaction/plugins/__init__.py b/genesis/ext/pyrender/interaction/plugins/__init__.py index 4d9ca6c59c..b5cb306ef0 100644 --- a/genesis/ext/pyrender/interaction/plugins/__init__.py +++ b/genesis/ext/pyrender/interaction/plugins/__init__.py @@ -1,9 +1,11 @@ +from .default_controls import DefaultControls +from .help_text import HelpTextPlugin from .mesh_selector import MeshPointSelectorPlugin -from .mouse_interaction import MouseSpringViewerPlugin -from .viewer_controls import ViewerDefaultControls +from .mouse_spring import MouseSpringPlugin __all__ = [ - "ViewerDefaultControls", + "DefaultControls", + "HelpTextPlugin", "MeshPointSelectorPlugin", - "MouseSpringViewerPlugin", + "MouseSpringPlugin", ] diff --git a/genesis/ext/pyrender/interaction/plugins/viewer_controls.py b/genesis/ext/pyrender/interaction/plugins/default_controls.py similarity index 50% rename from genesis/ext/pyrender/interaction/plugins/viewer_controls.py rename to genesis/ext/pyrender/interaction/plugins/default_controls.py index 5fe67a07e0..48fbb91479 100644 --- a/genesis/ext/pyrender/interaction/plugins/viewer_controls.py +++ b/genesis/ext/pyrender/interaction/plugins/default_controls.py @@ -1,16 +1,14 @@ import os from typing import TYPE_CHECKING -import numpy as np import pyglet -from typing_extensions import override import genesis as gs -from genesis.options.viewer_interactions import ViewerDefaultControls as ViewerDefaultControlsOptions +from genesis.options.viewer_plugins import DefaultControlsPlugin as DefaultControlsOptions -from ...constants import TEXT_PADDING -from ..base_interaction import EVENT_HANDLE_STATE, BaseViewerInteraction, register_viewer_plugin -from ..keybindings import Keybind, get_keycode_string +from ..keybindings import Keybind +from ..viewer_plugin import register_viewer_plugin +from .help_text import HelpTextPlugin if TYPE_CHECKING: from genesis.engine.scene import Scene @@ -18,8 +16,8 @@ INSTR_KEYBIND_NAME = "toggle_instructions" -@register_viewer_plugin(ViewerDefaultControlsOptions) -class ViewerDefaultControls(BaseViewerInteraction): +@register_viewer_plugin(DefaultControlsOptions) +class DefaultControls(HelpTextPlugin): """ Default keyboard controls for the Genesis viewer. @@ -37,102 +35,83 @@ def __init__( super().__init__(viewer, options, camera, scene, viewport_size) self.viewer.register_keybinds(( - Keybind(key_code=pyglet.window.key.I, name=INSTR_KEYBIND_NAME, callback_func=self.toggle_instructions), - Keybind(key_code=pyglet.window.key.R, name="record_video", callback_func=self.toggle_record_video), - Keybind(key_code=pyglet.window.key.S, name="save_image", callback_func=self.save_image), - Keybind(key_code=pyglet.window.key.Z, name="reset_camera", callback_func=self.reset_camera), - Keybind(key_code=pyglet.window.key.A, name="camera_rotation", callback_func=self.toggle_camera_rotation), - Keybind(key_code=pyglet.window.key.H, name="shadow", callback_func=self.toggle_shadow), - Keybind(key_code=pyglet.window.key.F, name="face_normals", callback_func=self.toggle_face_normals), - Keybind(key_code=pyglet.window.key.V, name="vertex_normals", callback_func=self.toggle_vertex_normals), - Keybind(key_code=pyglet.window.key.W, name="world_frame", callback_func=self.toggle_world_frame), - Keybind(key_code=pyglet.window.key.L, name="link_frame", callback_func=self.toggle_link_frame), - Keybind(key_code=pyglet.window.key.D, name="wireframe", callback_func=self.toggle_wireframe), - Keybind(key_code=pyglet.window.key.C, name="camera_frustum", callback_func=self.toggle_camera_frustum), - Keybind(key_code=pyglet.window.key.P, name="reload_shader", callback_func=self.reload_shader), - Keybind(key_code=pyglet.window.key.F11, name="fullscreen_mode", callback_func=self.toggle_fullscreen), + Keybind(key_code=pyglet.window.key.R, name="record_video", callback_func=self._toggle_record_video), + Keybind(key_code=pyglet.window.key.S, name="save_image", callback_func=self._save_image), + Keybind(key_code=pyglet.window.key.Z, name="reset_camera", callback_func=self._reset_camera), + Keybind(key_code=pyglet.window.key.A, name="camera_rotation", callback_func=self._toggle_camera_rotation), + Keybind(key_code=pyglet.window.key.H, name="shadow", callback_func=self._toggle_shadow), + Keybind(key_code=pyglet.window.key.F, name="face_normals", callback_func=self._toggle_face_normals), + Keybind(key_code=pyglet.window.key.V, name="vertex_normals", callback_func=self._toggle_vertex_normals), + Keybind(key_code=pyglet.window.key.W, name="world_frame", callback_func=self._toggle_world_frame), + Keybind(key_code=pyglet.window.key.L, name="link_frame", callback_func=self._toggle_link_frame), + Keybind(key_code=pyglet.window.key.D, name="wireframe", callback_func=self._toggle_wireframe), + Keybind(key_code=pyglet.window.key.C, name="camera_frustum", callback_func=self._toggle_camera_frustum), + Keybind(key_code=pyglet.window.key.P, name="reload_shader", callback_func=self._reload_shader), + Keybind(key_code=pyglet.window.key.F11, name="fullscreen_mode", callback_func=self._toggle_fullscreen), )) - self._display_instr = False - self._instr_texts: tuple[list[str], list[str]] = ([], []) - self._update_instr_texts() - def _update_instr_texts(self): - self.instr_key_str = get_keycode_string(self.viewer._keybindings.get_by_name(INSTR_KEYBIND_NAME).key_code) - kb_texts = [ - f"{'[' + get_keycode_string(kb.key_code):>{7}}]: " + - kb.name.replace('_', ' ') for kb in self.viewer._keybindings.keybinds if kb.name != INSTR_KEYBIND_NAME - ] - self._instr_texts = ( - [f"> [{self.instr_key_str}]: show keyboard instructions"], - [f"< [{self.instr_key_str}]: hide keyboard instructions"] + kb_texts - ) - - def toggle_instructions(self): - self._display_instr = not self._display_instr - self._update_instr_texts() - - def toggle_camera_rotation(self): + def _toggle_camera_rotation(self): self.viewer.viewer_flags["rotate"] = not self.viewer.viewer_flags["rotate"] if self.viewer.viewer_flags["rotate"]: - self.viewer._message_text = "Rotation On" + self.set_message_text("Rotation On") else: - self.viewer._message_text = "Rotation Off" + self.set_message_text("Rotation Off") - def toggle_fullscreen(self): + def _toggle_fullscreen(self): self.viewer.viewer_flags["fullscreen"] = not self.viewer.viewer_flags["fullscreen"] self.viewer.set_fullscreen(self.viewer.viewer_flags["fullscreen"]) self.viewer.activate() if self.viewer.viewer_flags["fullscreen"]: - self.viewer._message_text = "Fullscreen On" + self.set_message_text("Fullscreen On") else: - self.viewer._message_text = "Fullscreen Off" + self.set_message_text("Fullscreen Off") - def toggle_shadow(self): + def _toggle_shadow(self): self.viewer.render_flags["shadows"] = not self.viewer.render_flags["shadows"] if self.viewer.render_flags["shadows"]: - self.viewer._message_text = "Shadows On" + self.set_message_text("Shadows On") else: - self.viewer._message_text = "Shadows Off" + self.set_message_text("Shadows Off") - def toggle_world_frame(self): + def _toggle_world_frame(self): if not self.viewer.gs_context.world_frame_shown: self.viewer.gs_context.on_world_frame() - self.viewer._message_text = "World Frame On" + self.set_message_text("World Frame On") else: self.viewer.gs_context.off_world_frame() - self.viewer._message_text = "World Frame Off" + self.set_message_text("World Frame Off") - def toggle_link_frame(self): + def _toggle_link_frame(self): if not self.viewer.gs_context.link_frame_shown: self.viewer.gs_context.on_link_frame() - self.viewer._message_text = "Link Frame On" + self.set_message_text("Link Frame On") else: self.viewer.gs_context.off_link_frame() - self.viewer._message_text = "Link Frame Off" + self.set_message_text("Link Frame Off") - def toggle_camera_frustum(self): + def _toggle_camera_frustum(self): if not self.viewer.gs_context.camera_frustum_shown: self.viewer.gs_context.on_camera_frustum() - self.viewer._message_text = "Camera Frustrum On" + self.set_message_text("Camera Frustum On") else: self.viewer.gs_context.off_camera_frustum() - self.viewer._message_text = "Camera Frustrum Off" + self.set_message_text("Camera Frustum Off") - def toggle_face_normals(self): + def _toggle_face_normals(self): self.viewer.render_flags["face_normals"] = not self.viewer.render_flags["face_normals"] if self.viewer.render_flags["face_normals"]: - self.viewer._message_text = "Face Normals On" + self.set_message_text("Face Normals On") else: - self.viewer._message_text = "Face Normals Off" + self.set_message_text("Face Normals Off") - def toggle_vertex_normals(self): + def _toggle_vertex_normals(self): self.viewer.render_flags["vertex_normals"] = not self.viewer.render_flags["vertex_normals"] if self.viewer.render_flags["vertex_normals"]: - self.viewer._message_text = "Vert Normals On" + self.set_message_text("Vert Normals On") else: - self.viewer._message_text = "Vert Normals Off" + self.set_message_text("Vert Normals Off") - def toggle_record_video(self): + def _toggle_record_video(self): if self.viewer.viewer_flags["record"]: self.viewer.save_video() self.viewer.set_caption(self.viewer.viewer_flags["window_title"]) @@ -148,70 +127,33 @@ def toggle_record_video(self): self.viewer.set_caption("{} (RECORDING)".format(self.viewer.viewer_flags["window_title"])) self.viewer.viewer_flags["record"] = not self.viewer.viewer_flags["record"] - def save_image(self): + def _save_image(self): self.viewer._save_image() - def toggle_wireframe(self): + def _toggle_wireframe(self): if self.viewer.render_flags["flip_wireframe"]: self.viewer.render_flags["flip_wireframe"] = False self.viewer.render_flags["all_wireframe"] = True self.viewer.render_flags["all_solid"] = False - self.viewer._message_text = "All Wireframe" + self.set_message_text("All Wireframe") elif self.viewer.render_flags["all_wireframe"]: self.viewer.render_flags["flip_wireframe"] = False self.viewer.render_flags["all_wireframe"] = False self.viewer.render_flags["all_solid"] = True - self.viewer._message_text = "All Solid" + self.set_message_text("All Solid") elif self.viewer.render_flags["all_solid"]: self.viewer.render_flags["flip_wireframe"] = False self.viewer.render_flags["all_wireframe"] = False self.viewer.render_flags["all_solid"] = False - self.viewer._message_text = "Default Wireframe" + self.set_message_text("Default Wireframe") else: self.viewer.render_flags["flip_wireframe"] = True self.viewer.render_flags["all_wireframe"] = False self.viewer.render_flags["all_solid"] = False - self.viewer._message_text = "Flip Wireframe" + self.set_message_text("Flip Wireframe") - def reset_camera(self): + def _reset_camera(self): self.viewer._reset_view() - def reload_shader(self): - self.viewer._renderer.reload_program() - - @override - def on_key_press(self, symbol: int, modifiers: int) -> EVENT_HANDLE_STATE: - if self.viewer is None: - return None - - # Reset message text and check for keybinding - self.viewer._message_text = None - super().on_key_press(symbol, modifiers) - - if self.viewer._message_text is not None: - self.viewer._message_opac = 1.0 + self.viewer._ticks_till_fade - - return None - - @override - def on_draw(self): - """Render keyboard instructions.""" - if self.viewer is None: - return - - if self._display_instr: - self.viewer._renderer.render_texts( - self._instr_texts[1], - TEXT_PADDING, - self.viewer.viewport_size[1] - TEXT_PADDING, - font_pt=26, - color=np.array([1.0, 1.0, 1.0, 0.85]), - ) - else: - self.viewer._renderer.render_texts( - self._instr_texts[0], - TEXT_PADDING, - self.viewer.viewport_size[1] - TEXT_PADDING, - font_pt=26, - color=np.array([1.0, 1.0, 1.0, 0.85]), - ) + def _reload_shader(self): + self.viewer._renderer.reload_program() \ No newline at end of file diff --git a/genesis/ext/pyrender/interaction/plugins/help_text.py b/genesis/ext/pyrender/interaction/plugins/help_text.py new file mode 100644 index 0000000000..ff891f8fbd --- /dev/null +++ b/genesis/ext/pyrender/interaction/plugins/help_text.py @@ -0,0 +1,104 @@ +from typing import TYPE_CHECKING + +import numpy as np +import pyglet +from typing_extensions import override + +from genesis.options.viewer_plugins import HelpTextPlugin as HelpTextPluginOptions + +from ...constants import TEXT_PADDING, TextAlign +from ..keybindings import Keybind, get_keycode_string +from ..viewer_plugin import ViewerPlugin, register_viewer_plugin + +if TYPE_CHECKING: + from genesis.engine.scene import Scene + from genesis.ext.pyrender.node import Node + from genesis.options.viewer_plugins import ViewerPlugin as ViewerPluginOptions + +INSTR_KEYBIND_NAME = "toggle_instructions" + +@register_viewer_plugin(HelpTextPluginOptions) +class HelpTextPlugin(ViewerPlugin): + """ + Default keyboard controls for the Genesis viewer. + """ + + def __init__( + self, + viewer, + options: "ViewerPluginOptions", + camera: "Node" = None, + scene: "Scene" = None, + viewport_size: tuple[int, int] = None, + ): + super().__init__(viewer, options, camera, scene, viewport_size) + + self.viewer.register_keybinds(( + Keybind(key_code=pyglet.window.key.I, name=INSTR_KEYBIND_NAME, callback_func=self._toggle_instructions), + )) + self._collapse_instructions = True + self._instr_texts: tuple[list[str], list[str]] = ([], []) + self._update_instr_texts() + + self._message_text = None + self._ticks_till_fade = 2.0 / 3.0 * self.viewer.viewer_flags["refresh_rate"] + self._message_opac = 1.0 + self._ticks_till_fade + + def _update_instr_texts(self): + self.instr_key_str = get_keycode_string(self.viewer._keybindings.get_by_name(INSTR_KEYBIND_NAME).key_code) + kb_texts = [ + f"{'[' + get_keycode_string(kb.key_code):>{7}}]: " + + kb.name.replace('_', ' ') for kb in self.viewer._keybindings.keybinds if kb.name != INSTR_KEYBIND_NAME + ] + self._instr_texts = ( + [f"> [{self.instr_key_str}]: show keyboard instructions"], + [f"< [{self.instr_key_str}]: hide keyboard instructions"] + kb_texts + ) + + def _toggle_instructions(self): + self._collapse_instructions = not self._collapse_instructions + self._update_instr_texts() + + def set_message_text(self, text: str): + self._message_text = text + self._message_opac = 1.0 + self._ticks_till_fade + + @override + def on_draw(self): + if self._message_text is not None: + self.viewer._renderer.render_text( + self._message_text, + self.viewport_size[0] - TEXT_PADDING, + TEXT_PADDING, + font_pt=self.options.font_size, + color=np.array([0.1, 0.7, 0.2, np.clip(self._message_opac, 0.0, 1.0)]), + align=TextAlign.BOTTOM_RIGHT, + ) + + if self._message_opac > 1.0: + self._message_opac -= 1.0 + else: + self._message_opac *= 0.90 + + if self._message_opac < 0.05: + self._message_opac = 1.0 + self._ticks_till_fade + self._message_text = None + + + if self._collapse_instructions: + self.viewer._renderer.render_texts( + self._instr_texts[0], + TEXT_PADDING, + self.viewer.viewport_size[1] - TEXT_PADDING, + font_pt=self.options.font_size, + color=np.array([1.0, 1.0, 1.0, 0.85]), + ) + else: + self.viewer._renderer.render_texts( + self._instr_texts[1], + TEXT_PADDING, + self.viewer.viewport_size[1] - TEXT_PADDING, + font_pt=self.options.font_size, + color=np.array([1.0, 1.0, 1.0, 0.85]), + ) + \ No newline at end of file diff --git a/genesis/ext/pyrender/interaction/plugins/mesh_selector.py b/genesis/ext/pyrender/interaction/plugins/mesh_selector.py index c2399a5e21..55bc4b1bb8 100644 --- a/genesis/ext/pyrender/interaction/plugins/mesh_selector.py +++ b/genesis/ext/pyrender/interaction/plugins/mesh_selector.py @@ -4,11 +4,11 @@ from typing_extensions import override import genesis as gs -from genesis.options.viewer_interactions import MeshPointSelectorPlugin as MeshPointSelectorPluginOptions +from genesis.options.viewer_plugins import MeshPointSelectorPlugin as MeshPointSelectorPluginOptions -from ..base_interaction import EVENT_HANDLE_STATE, EVENT_HANDLED, register_viewer_plugin from ..utils import Pose, Ray, Vec3, ViewerRaycaster -from .viewer_controls import ViewerDefaultControls +from ..viewer_plugin import EVENT_HANDLE_STATE, EVENT_HANDLED, register_viewer_plugin +from .help_text import HelpTextPlugin if TYPE_CHECKING: from genesis.engine.entities.rigid_entity import RigidLink @@ -34,9 +34,8 @@ class SelectedPoint(NamedTuple): local_normal: Vec3 - @register_viewer_plugin(MeshPointSelectorPluginOptions) -class MeshPointSelectorPlugin(ViewerDefaultControls): +class MeshPointSelectorPlugin(HelpTextPlugin): """ Interactive viewer plugin that enables using mouse clicks to select points on rigid meshes. Selected points are stored in local coordinates relative to their link's frame. diff --git a/genesis/ext/pyrender/interaction/plugins/mouse_interaction.py b/genesis/ext/pyrender/interaction/plugins/mouse_spring.py similarity index 96% rename from genesis/ext/pyrender/interaction/plugins/mouse_interaction.py rename to genesis/ext/pyrender/interaction/plugins/mouse_spring.py index f88b0a9db0..883e793f4c 100644 --- a/genesis/ext/pyrender/interaction/plugins/mouse_interaction.py +++ b/genesis/ext/pyrender/interaction/plugins/mouse_spring.py @@ -5,11 +5,10 @@ from typing_extensions import override # Made it into standard lib from Python 3.12 import genesis as gs -from genesis.options.viewer_interactions import MouseSpringViewerPlugin as MouseSpringViewerPluginOptions +from genesis.options.viewer_plugins import MouseSpringPlugin as MouseSpringPluginOptions -from ..base_interaction import EVENT_HANDLE_STATE, EVENT_HANDLED, register_viewer_plugin from ..utils import AABB, OBB, Color, Plane, Pose, Quat, Ray, RayHit, Vec3, ViewerRaycaster -from .viewer_controls import ViewerDefaultControls +from ..viewer_plugin import EVENT_HANDLE_STATE, EVENT_HANDLED, ViewerPlugin, register_viewer_plugin if TYPE_CHECKING: from genesis.engine.entities.rigid_entity import RigidEntity, RigidGeom, RigidLink @@ -102,8 +101,8 @@ def is_attached(self) -> bool: return self.held_link is not None -@register_viewer_plugin(MouseSpringViewerPluginOptions) -class MouseSpringViewerPlugin(ViewerDefaultControls): +@register_viewer_plugin(MouseSpringPluginOptions) +class MouseSpringPlugin(ViewerPlugin): """ Basic interactive viewer plugin that enables using mouse to apply spring force on rigid entities. """ @@ -111,7 +110,7 @@ class MouseSpringViewerPlugin(ViewerDefaultControls): def __init__( self, viewer, - options: MouseSpringViewerPluginOptions, + options: MouseSpringPluginOptions, camera: "Node", scene: "Scene", viewport_size: tuple[int, int], diff --git a/genesis/ext/pyrender/interaction/base_interaction.py b/genesis/ext/pyrender/interaction/viewer_plugin.py similarity index 93% rename from genesis/ext/pyrender/interaction/base_interaction.py rename to genesis/ext/pyrender/interaction/viewer_plugin.py index 15082e2004..e6708e2dfd 100644 --- a/genesis/ext/pyrender/interaction/base_interaction.py +++ b/genesis/ext/pyrender/interaction/viewer_plugin.py @@ -7,14 +7,14 @@ if TYPE_CHECKING: from genesis.engine.scene import Scene from genesis.ext.pyrender.node import Node - from genesis.options.viewer_interactions import ViewerInteraction as ViewerPluginOptions + from genesis.options.viewer_plugins import ViewerPlugin as ViewerPluginOptions EVENT_HANDLE_STATE = Literal[True] | None EVENT_HANDLED: Literal[True] = True # Global map from options class to viewer plugin class -VIEWER_PLUGIN_MAP: dict[Type["ViewerPluginOptions"], Type["BaseViewerInteraction"]] = {} +VIEWER_PLUGIN_MAP: dict[Type["ViewerPluginOptions"], Type["ViewerPlugin"]] = {} def register_viewer_plugin(options_cls: Type["ViewerPluginOptions"]): @@ -37,14 +37,14 @@ def register_viewer_plugin(options_cls: Type["ViewerPluginOptions"]): class ViewerInteraction(ViewerInteractionBase): ... """ - def _impl(plugin_cls: Type["BaseViewerInteraction"]): + def _impl(plugin_cls: Type["ViewerPlugin"]): VIEWER_PLUGIN_MAP[options_cls] = plugin_cls return plugin_cls return _impl # Note: Viewer window is based on pyglet.window.Window, mouse events are defined in pyglet.window.BaseWindow -class BaseViewerInteraction(): +class ViewerPlugin(): """ Base class for handling pyglet.window.Window events. """ diff --git a/genesis/ext/pyrender/viewer.py b/genesis/ext/pyrender/viewer.py index ef277ad7a5..a2cd20ba90 100644 --- a/genesis/ext/pyrender/viewer.py +++ b/genesis/ext/pyrender/viewer.py @@ -220,7 +220,6 @@ def __init__( self._offscreen_semaphore = Semaphore(0) self._offscreen_result = None - self._video_saver = None self._video_recorder = None self._default_render_flags = { @@ -272,16 +271,11 @@ def __init__( self._viewer_flags[key] = kwargs[key] self._keybindings: Keybindings = Keybindings() - self._held_keys: dict[tuple[int, int], bool] = {} # Track held keys: (symbol, modifiers) -> True + self._held_keys: dict[tuple[int, int], bool] = {} ####################################################################### # Save internal settings ####################################################################### - # Set up caption stuff - self._message_text = None - self._ticks_till_fade = 2.0 / 3.0 * self.viewer_flags["refresh_rate"] - self._message_opac = 1.0 + self._ticks_till_fade - # Set up raymond lights and direct lights self._raymond_lights = self._create_raymond_lights() self._direct_light = self._create_direct_light() @@ -348,7 +342,7 @@ def __init__( # Setup viewer interaction if plugin_options is None: - plugin_options = gs.options.viewer_interactions.ViewerDefaultControls() + plugin_options = gs.options.viewer_plugins.ViewerDefaultControls() plugin_cls = VIEWER_PLUGIN_MAP.get(type(plugin_options)) if plugin_cls is None: @@ -494,7 +488,7 @@ def viewer_flags(self, value): def register_keybinds(self, keybinds: tuple[Keybind]) -> None: """ - Add a key handler to call a function when the given ASCII key is pressed. + Add a key handler to call a function when the given key is pressed. Parameters ---------- @@ -503,6 +497,7 @@ def register_keybinds(self, keybinds: tuple[Keybind]) -> None: """ for keybind in keybinds: self._keybindings.register(keybind) + def close(self): """Close the viewer. @@ -735,6 +730,13 @@ def on_mouse_motion(self, x: int, y: int, dx: int, dy: int) -> EVENT_HANDLE_STAT def on_mouse_press(self, x: int, y: int, button: int, modifiers: int) -> EVENT_HANDLE_STATE: """Record an initial mouse press.""" + # Stop animating while using the mouse + self.viewer_flags["mouse_pressed"] = True + + result = self.interaction_plugin.on_mouse_press(x, y, button, modifiers) + if result is EVENT_HANDLED: + return result + self._trackball.set_state(Trackball.STATE_ROTATE) if button == pyglet.window.mouse.LEFT: ctrl = modifiers & pyglet.window.key.MOD_CTRL @@ -751,24 +753,27 @@ def on_mouse_press(self, x: int, y: int, button: int, modifiers: int) -> EVENT_H self._trackball.down(np.array([x, y])) - # Stop animating while using the mouse - self.viewer_flags["mouse_pressed"] = True - return self.interaction_plugin.on_mouse_press(x, y, button, modifiers) + return EVENT_HANDLED def on_mouse_drag(self, x: int, y: int, dx: int, dy: int, buttons: int, modifiers: int) -> EVENT_HANDLE_STATE: """The mouse was moved with one or more buttons held down.""" result = self.interaction_plugin.on_mouse_drag(x, y, dx, dy, buttons, modifiers) - if result is not EVENT_HANDLED: - result = self._trackball.drag(np.array([x, y])) - return result + if result is EVENT_HANDLED: + return result + + result = self._trackball.drag(np.array([x, y])) + return EVENT_HANDLED def on_mouse_release(self, x: int, y: int, button: int, modifiers: int) -> EVENT_HANDLE_STATE: """Record a mouse release.""" self.viewer_flags["mouse_pressed"] = False return self.interaction_plugin.on_mouse_release(x, y, button, modifiers) - def on_mouse_scroll(self, x, y, dx, dy): + def on_mouse_scroll(self, x, y, dx, dy) -> EVENT_HANDLE_STATE: """Record a mouse scroll.""" + if self.interaction_plugin.on_mouse_scroll(x, y, dx, dy) == EVENT_HANDLED: + return EVENT_HANDLED + if self.viewer_flags["use_perspective_cam"]: self._trackball.scroll(dy) else: @@ -785,6 +790,8 @@ def on_mouse_scroll(self, x, y, dx, dy): ymag = max(c.ymag * sf, 1e-8 * c.ymag / c.xmag) c.xmag = xmag c.ymag = ymag + + return EVENT_HANDLED def _call_keybind_callback(self, symbol: int, modifiers: int, action: KeyAction) -> None: """Call registered keybind callbacks for the given key event.""" @@ -795,12 +802,14 @@ def _call_keybind_callback(self, symbol: int, modifiers: int, action: KeyAction) def on_key_press(self, symbol: int, modifiers: int) -> EVENT_HANDLE_STATE: """Record a key press.""" self._held_keys[(symbol, modifiers)] = True + self._call_keybind_callback(symbol, modifiers, KeyAction.PRESS) return self.interaction_plugin.on_key_press(symbol, modifiers) def on_key_release(self, symbol: int, modifiers: int) -> EVENT_HANDLE_STATE: """Record a key release.""" self._held_keys.pop((symbol, modifiers), None) + self._call_keybind_callback(symbol, modifiers, KeyAction.RELEASE) return self.interaction_plugin.on_key_release(symbol, modifiers) @@ -820,26 +829,6 @@ def _time_event(dt, self): if self.viewer_flags["rotate"] and not self.viewer_flags["mouse_pressed"]: self._rotate() - # Manage message opacity - if self._message_text is not None: - if self._message_opac > 1.0: - self._message_opac -= 1.0 - else: - self._message_opac *= 0.90 - if self._message_opac < 0.05: - self._message_opac = 1.0 + self._ticks_till_fade - self._message_text = None - - # video saving warning - if self._video_saver is not None: - if self._video_saver.is_alive(): - self._message_text = "Saving video... Please don't exit." - self._message_opac = 1.0 - else: - self._message_text = f"Video saved to {self._video_file_name}" - self._message_opac = self.viewer_flags["refresh_rate"] * 2 - self._video_saver = None - self.on_draw() def _reset_view(self): diff --git a/genesis/options/viewer_interactions.py b/genesis/options/viewer_plugins.py similarity index 73% rename from genesis/options/viewer_interactions.py rename to genesis/options/viewer_plugins.py index 7e72a0e6b2..ab8a10c0c1 100644 --- a/genesis/options/viewer_interactions.py +++ b/genesis/options/viewer_plugins.py @@ -1,38 +1,36 @@ from .options import Options -class ViewerInteraction(Options): +class ViewerPlugin(Options): """ Base class for viewer interaction options. All viewer interaction option classes should inherit from this base class. """ - pass +class HelpTextPlugin(ViewerPlugin): + """ + Displays keyboard instructions in the viewer. + """ + + display_instructions: bool = True + font_size: int = 26 -class ViewerDefaultControls(ViewerInteraction): + +class DefaultControlsPlugin(HelpTextPlugin): """ Default viewer interaction controls with keyboard shortcuts for recording, changing render modes, etc. - - Parameters - ---------- - keybindings : dict[str, int] - Override the default mapping of action names to keyboard key codes (pyglet.window.key.*). """ - keybindings: dict[str, int] = None - -class MouseSpringViewerPlugin(ViewerDefaultControls): +class MouseSpringPlugin(HelpTextPlugin): """ Options for the interactive viewer plugin that allows mouse-based object manipulation. """ - pass - -class MeshPointSelectorPlugin(ViewerDefaultControls): +class MeshPointSelectorPlugin(HelpTextPlugin): """ Options for the mesh point selector plugin that allows selecting points on a mesh. @@ -40,9 +38,9 @@ class MeshPointSelectorPlugin(ViewerDefaultControls): ---------- sphere_radius : float The radius of the sphere used to visualize selected points. - sphere_color : tuple + sphere_color : tuple[float, float, float, float] The color of the sphere used to visualize selected points. - hover_color : tuple + hover_color : tuple[float, float, float, float] The color of the sphere used to visualize the point and normal when hovering over a mesh. grid_snap : tuple[float, float, float] Grid snap spacing for each axis (x, y, z). Any negative value disables snapping for that axis. diff --git a/genesis/options/vis.py b/genesis/options/vis.py index 46934e8672..f3a0cacf6d 100644 --- a/genesis/options/vis.py +++ b/genesis/options/vis.py @@ -3,7 +3,7 @@ import genesis as gs from .options import Options -from .viewer_interactions import ViewerDefaultControls, ViewerInteraction +from .viewer_plugins import DefaultControlsPlugin, ViewerPlugin class ViewerOptions(Options): @@ -34,7 +34,7 @@ class ViewerOptions(Options): The up vector of the camera's extrinsic pose. camera_fov : float The field of view (in degrees) of the camera. - viewer_plugin : ViewerPluginOptions + viewer_plugin : ViewerPlugin Viewer plugin that adds interactive functionality to the viewer. """ @@ -46,7 +46,7 @@ class ViewerOptions(Options): camera_lookat: tuple = (0.0, 0.0, 0.5) camera_up: tuple = (0.0, 0.0, 1.0) camera_fov: float = 40 - viewer_plugin: ViewerInteraction = ViewerDefaultControls() + viewer_plugin: ViewerPlugin = DefaultControlsPlugin() class VisOptions(Options): diff --git a/genesis/vis/viewer.py b/genesis/vis/viewer.py index ce2da98c63..bf3444d0d7 100644 --- a/genesis/vis/viewer.py +++ b/genesis/vis/viewer.py @@ -10,6 +10,7 @@ import genesis as gs import genesis.utils.geom as gu from genesis.ext import pyrender +from genesis.ext.pyrender.interaction import KeyAction, Keybind from genesis.repr_base import RBC from genesis.utils.misc import redirect_libc_stderr, tensor_to_array from genesis.utils.tools import Rate @@ -259,7 +260,7 @@ def update_following(self): else: self.set_camera_pose(pos=camera_pos, lookat=self._follow_lookat) - def register_keybinds(self, keybinds: tuple[pyrender.interaction.keybindings.Keybind]) -> None: + def register_keybinds(self, keybinds: tuple[Keybind]) -> None: """ Register a callback function to be called when a key is pressed. @@ -270,6 +271,35 @@ def register_keybinds(self, keybinds: tuple[pyrender.interaction.keybindings.Key """ self._pyrender_viewer.register_keybinds(keybinds) + def remap_keybind( + self, + keybind_name: str, + new_key_code: int, + new_modifiers: int | None = None, + new_key_action: KeyAction = KeyAction.PRESS, + ) -> None: + """ + Remap an existing keybind to a new key combination. + Use `None` for any parameter you do not wish to change. + + Parameters + ---------- + keybind_name : str + The name of the keybind to remap. + new_key_code : int + The new key code from pyglet. + new_modifiers : int | None, optional + The new modifier keys pressed. + new_key_action : KeyAction, optional + The new type of key action (press, hold, release). + """ + self._pyrender_viewer._keybindings.rebind( + keybind_name, + new_key_code, + new_modifiers, + new_key_action, + ) + # ------------------------------------------------------------------------------------ # ----------------------------------- properties ------------------------------------- # ------------------------------------------------------------------------------------ From 8e5449148c13d090c4d0e0ef28097bf48741e78e Mon Sep 17 00:00:00 2001 From: Trinity Chung Date: Sat, 10 Jan 2026 23:55:52 -0500 Subject: [PATCH 4/5] add point mesh selector example --- examples/viewer_plugin/mesh_point_selector.py | 47 +++++++++++++++++++ .../mouse_spring.py | 0 genesis/ext/pyrender/interaction/__init__.py | 2 +- .../pyrender/interaction/plugins/__init__.py | 2 +- ...esh_selector.py => mesh_point_selector.py} | 0 5 files changed, 49 insertions(+), 2 deletions(-) create mode 100644 examples/viewer_plugin/mesh_point_selector.py rename examples/{interactive_viewer => viewer_plugin}/mouse_spring.py (100%) rename genesis/ext/pyrender/interaction/plugins/{mesh_selector.py => mesh_point_selector.py} (100%) diff --git a/examples/viewer_plugin/mesh_point_selector.py b/examples/viewer_plugin/mesh_point_selector.py new file mode 100644 index 0000000000..140b58bfd6 --- /dev/null +++ b/examples/viewer_plugin/mesh_point_selector.py @@ -0,0 +1,47 @@ +from huggingface_hub import snapshot_download + +import genesis as gs + +if __name__ == "__main__": + + gs.init(backend=gs.gpu) + + scene = gs.Scene( + sim_options=gs.options.SimOptions( + gravity=(0.0, 0.0, 0.0), + ), + viewer_options=gs.options.ViewerOptions( + camera_pos=(0.5, 0.2, 1.0), + camera_lookat=(0.0, 0.0, 1.0), + camera_fov=40, + viewer_plugin=gs.options.viewer_plugins.MeshPointSelectorPlugin( + sphere_radius=0.004, + grid_snap=(-1.0, 0.01, 0.01), + ), + ), + profiling_options=gs.options.ProfilingOptions( + show_FPS=False, + ), + show_viewer=True, + ) + + asset_path = snapshot_download( + repo_id="Genesis-Intelligence/assets", + allow_patterns="allegro_hand/*", + repo_type="dataset", + ) + hand = scene.add_entity( + morph=gs.morphs.URDF( + file=f"{asset_path}/allegro_hand/allegro_hand_right_glb.urdf", + collision=True, + pos=(0.0, 0.0, 1.0), + euler=(0.0, 0.0, 0.0), + fixed=True, + merge_fixed_links=False, + ), + ) + + scene.build() + + while True: + scene.step() diff --git a/examples/interactive_viewer/mouse_spring.py b/examples/viewer_plugin/mouse_spring.py similarity index 100% rename from examples/interactive_viewer/mouse_spring.py rename to examples/viewer_plugin/mouse_spring.py diff --git a/genesis/ext/pyrender/interaction/__init__.py b/genesis/ext/pyrender/interaction/__init__.py index 24ffef5aca..ea9fb562fd 100644 --- a/genesis/ext/pyrender/interaction/__init__.py +++ b/genesis/ext/pyrender/interaction/__init__.py @@ -1,6 +1,6 @@ from .keybindings import KeyAction, Keybind, Keybindings, get_keycode_string from .plugins.default_controls import DefaultControls -from .plugins.mesh_selector import MeshPointSelectorPlugin +from .plugins.mesh_point_selector import MeshPointSelectorPlugin from .plugins.mouse_spring import MouseSpringPlugin from .viewer_plugin import ( EVENT_HANDLE_STATE, diff --git a/genesis/ext/pyrender/interaction/plugins/__init__.py b/genesis/ext/pyrender/interaction/plugins/__init__.py index b5cb306ef0..a30d1981fa 100644 --- a/genesis/ext/pyrender/interaction/plugins/__init__.py +++ b/genesis/ext/pyrender/interaction/plugins/__init__.py @@ -1,6 +1,6 @@ from .default_controls import DefaultControls from .help_text import HelpTextPlugin -from .mesh_selector import MeshPointSelectorPlugin +from .mesh_point_selector import MeshPointSelectorPlugin from .mouse_spring import MouseSpringPlugin __all__ = [ diff --git a/genesis/ext/pyrender/interaction/plugins/mesh_selector.py b/genesis/ext/pyrender/interaction/plugins/mesh_point_selector.py similarity index 100% rename from genesis/ext/pyrender/interaction/plugins/mesh_selector.py rename to genesis/ext/pyrender/interaction/plugins/mesh_point_selector.py From ba92fd6991dc3f1dfd82414391bf3d9ae6396d70 Mon Sep 17 00:00:00 2001 From: Trinity Chung Date: Mon, 12 Jan 2026 12:20:35 -0500 Subject: [PATCH 5/5] rufff format --- examples/viewer_plugin/mesh_point_selector.py | 1 - examples/viewer_plugin/mouse_spring.py | 1 - .../ext/pyrender/interaction/keybindings.py | 33 +++++--- .../interaction/plugins/default_controls.py | 63 ++++++++------- .../pyrender/interaction/plugins/help_text.py | 69 +++++++++-------- .../plugins/mesh_point_selector.py | 77 ++++++++++--------- .../interaction/plugins/mouse_spring.py | 19 +++-- .../pyrender/interaction/utils/raycaster.py | 76 +++++++++--------- .../ext/pyrender/interaction/viewer_plugin.py | 18 +++-- genesis/ext/pyrender/viewer.py | 17 ++-- genesis/options/sensors/raycaster.py | 1 - 11 files changed, 201 insertions(+), 174 deletions(-) diff --git a/examples/viewer_plugin/mesh_point_selector.py b/examples/viewer_plugin/mesh_point_selector.py index 140b58bfd6..7a08785cab 100644 --- a/examples/viewer_plugin/mesh_point_selector.py +++ b/examples/viewer_plugin/mesh_point_selector.py @@ -3,7 +3,6 @@ import genesis as gs if __name__ == "__main__": - gs.init(backend=gs.gpu) scene = gs.Scene( diff --git a/examples/viewer_plugin/mouse_spring.py b/examples/viewer_plugin/mouse_spring.py index d88647e131..e0f2545b4b 100644 --- a/examples/viewer_plugin/mouse_spring.py +++ b/examples/viewer_plugin/mouse_spring.py @@ -3,7 +3,6 @@ import genesis as gs if __name__ == "__main__": - gs.init(backend=gs.gpu) scene = gs.Scene( diff --git a/genesis/ext/pyrender/interaction/keybindings.py b/genesis/ext/pyrender/interaction/keybindings.py index 8bb8ce4e7e..1d5fac867c 100644 --- a/genesis/ext/pyrender/interaction/keybindings.py +++ b/genesis/ext/pyrender/interaction/keybindings.py @@ -13,14 +13,16 @@ "equal": "=", } + class KeyAction(IntEnum): PRESS = 0 HOLD = 1 RELEASE = 2 + def get_key_hash(key_code: int, modifiers: int | None, action: KeyAction) -> int: """Generate a unique hash for a key combination. - + Parameters ---------- key_code : int @@ -37,13 +39,16 @@ def get_key_hash(key_code: int, modifiers: int | None, action: KeyAction) -> int """ return hash((key_code, modifiers, action)) + def get_keycode_string(key_code: int) -> str: from pyglet.window.key import symbol_string + symbol = symbol_string(key_code).lower() if symbol in KEY_STRING_TO_CHAR: return KEY_STRING_TO_CHAR[symbol] return symbol + class Keybind(NamedTuple): key_code: int name: str = "" @@ -57,11 +62,11 @@ def key_hash(self) -> int: """Generate a unique hash for the keybind based on key code and modifiers.""" return get_key_hash(self.key_code, self.modifiers, self.key_action) -class Keybindings: +class Keybindings: def __init__(self, keybinds: tuple[Keybind] = ()): self._keybinds_map: dict[int, Keybind] = {kb.key_hash(): kb for kb in keybinds} - + def register(self, keybind: Keybind) -> None: if keybind.key_hash() in self._keybinds_map: existing_kb = self._keybinds_map[keybind.key_hash()] @@ -69,8 +74,14 @@ def register(self, keybind: Keybind) -> None: f"Key '{get_keycode_string(keybind.key_code)}' is already assigned to '{existing_kb.name}'." ) self._keybinds_map[keybind.key_hash()] = keybind - - def rebind(self, name: str, new_key_code: int | None, new_modifiers: int | None = None, new_key_action: KeyAction | None = None) -> None: + + def rebind( + self, + name: str, + new_key_code: int | None, + new_modifiers: int | None = None, + new_key_action: KeyAction | None = None, + ) -> None: for kb in self._keybinds_map.values(): if kb.name == name: new_kb = Keybind( @@ -86,17 +97,17 @@ def rebind(self, name: str, new_key_code: int | None, new_modifiers: int | None self._keybinds_map[new_kb.key_hash()] = new_kb return raise ValueError(f"No keybind found with name '{name}'.") - + def get(self, key: int, modifiers: int, key_action: KeyAction) -> Keybind | None: key_hash = get_key_hash(key, modifiers, key_action) if key_hash in self._keybinds_map: return self._keybinds_map[key_hash] - + # Try ignoring modifiers (for keybinds where modifiers=None) key_hash_no_mods = get_key_hash(key, None, key_action) if key_hash_no_mods in self._keybinds_map: return self._keybinds_map[key_hash_no_mods] - + return None def get_by_name(self, name: str) -> Keybind | None: @@ -104,13 +115,13 @@ def get_by_name(self, name: str) -> Keybind | None: if kb.name == name: return kb return None - + @property def keys(self) -> tuple[str]: """Return a list of all registered keys as ASCII characters.""" return tuple(get_keycode_string(kb.key_code) for kb in self._keybinds_map.values()) - + @property def keybinds(self) -> tuple[Keybind]: """Return a tuple of all registered Keybinds.""" - return tuple(self._keybinds_map.values()) \ No newline at end of file + return tuple(self._keybinds_map.values()) diff --git a/genesis/ext/pyrender/interaction/plugins/default_controls.py b/genesis/ext/pyrender/interaction/plugins/default_controls.py index 48fbb91479..62c43ff0d1 100644 --- a/genesis/ext/pyrender/interaction/plugins/default_controls.py +++ b/genesis/ext/pyrender/interaction/plugins/default_controls.py @@ -16,11 +16,12 @@ INSTR_KEYBIND_NAME = "toggle_instructions" + @register_viewer_plugin(DefaultControlsOptions) class DefaultControls(HelpTextPlugin): """ Default keyboard controls for the Genesis viewer. - + This plugin handles the standard viewer keyboard shortcuts for recording, changing render modes, etc. """ @@ -34,21 +35,25 @@ def __init__( ): super().__init__(viewer, options, camera, scene, viewport_size) - self.viewer.register_keybinds(( - Keybind(key_code=pyglet.window.key.R, name="record_video", callback_func=self._toggle_record_video), - Keybind(key_code=pyglet.window.key.S, name="save_image", callback_func=self._save_image), - Keybind(key_code=pyglet.window.key.Z, name="reset_camera", callback_func=self._reset_camera), - Keybind(key_code=pyglet.window.key.A, name="camera_rotation", callback_func=self._toggle_camera_rotation), - Keybind(key_code=pyglet.window.key.H, name="shadow", callback_func=self._toggle_shadow), - Keybind(key_code=pyglet.window.key.F, name="face_normals", callback_func=self._toggle_face_normals), - Keybind(key_code=pyglet.window.key.V, name="vertex_normals", callback_func=self._toggle_vertex_normals), - Keybind(key_code=pyglet.window.key.W, name="world_frame", callback_func=self._toggle_world_frame), - Keybind(key_code=pyglet.window.key.L, name="link_frame", callback_func=self._toggle_link_frame), - Keybind(key_code=pyglet.window.key.D, name="wireframe", callback_func=self._toggle_wireframe), - Keybind(key_code=pyglet.window.key.C, name="camera_frustum", callback_func=self._toggle_camera_frustum), - Keybind(key_code=pyglet.window.key.P, name="reload_shader", callback_func=self._reload_shader), - Keybind(key_code=pyglet.window.key.F11, name="fullscreen_mode", callback_func=self._toggle_fullscreen), - )) + self.viewer.register_keybinds( + ( + Keybind(key_code=pyglet.window.key.R, name="record_video", callback_func=self._toggle_record_video), + Keybind(key_code=pyglet.window.key.S, name="save_image", callback_func=self._save_image), + Keybind(key_code=pyglet.window.key.Z, name="reset_camera", callback_func=self._reset_camera), + Keybind( + key_code=pyglet.window.key.A, name="camera_rotation", callback_func=self._toggle_camera_rotation + ), + Keybind(key_code=pyglet.window.key.H, name="shadow", callback_func=self._toggle_shadow), + Keybind(key_code=pyglet.window.key.F, name="face_normals", callback_func=self._toggle_face_normals), + Keybind(key_code=pyglet.window.key.V, name="vertex_normals", callback_func=self._toggle_vertex_normals), + Keybind(key_code=pyglet.window.key.W, name="world_frame", callback_func=self._toggle_world_frame), + Keybind(key_code=pyglet.window.key.L, name="link_frame", callback_func=self._toggle_link_frame), + Keybind(key_code=pyglet.window.key.D, name="wireframe", callback_func=self._toggle_wireframe), + Keybind(key_code=pyglet.window.key.C, name="camera_frustum", callback_func=self._toggle_camera_frustum), + Keybind(key_code=pyglet.window.key.P, name="reload_shader", callback_func=self._reload_shader), + Keybind(key_code=pyglet.window.key.F11, name="fullscreen_mode", callback_func=self._toggle_fullscreen), + ) + ) def _toggle_camera_rotation(self): self.viewer.viewer_flags["rotate"] = not self.viewer.viewer_flags["rotate"] @@ -56,7 +61,7 @@ def _toggle_camera_rotation(self): self.set_message_text("Rotation On") else: self.set_message_text("Rotation Off") - + def _toggle_fullscreen(self): self.viewer.viewer_flags["fullscreen"] = not self.viewer.viewer_flags["fullscreen"] self.viewer.set_fullscreen(self.viewer.viewer_flags["fullscreen"]) @@ -65,14 +70,14 @@ def _toggle_fullscreen(self): self.set_message_text("Fullscreen On") else: self.set_message_text("Fullscreen Off") - + def _toggle_shadow(self): self.viewer.render_flags["shadows"] = not self.viewer.render_flags["shadows"] if self.viewer.render_flags["shadows"]: self.set_message_text("Shadows On") else: self.set_message_text("Shadows Off") - + def _toggle_world_frame(self): if not self.viewer.gs_context.world_frame_shown: self.viewer.gs_context.on_world_frame() @@ -80,7 +85,7 @@ def _toggle_world_frame(self): else: self.viewer.gs_context.off_world_frame() self.set_message_text("World Frame Off") - + def _toggle_link_frame(self): if not self.viewer.gs_context.link_frame_shown: self.viewer.gs_context.on_link_frame() @@ -88,7 +93,7 @@ def _toggle_link_frame(self): else: self.viewer.gs_context.off_link_frame() self.set_message_text("Link Frame Off") - + def _toggle_camera_frustum(self): if not self.viewer.gs_context.camera_frustum_shown: self.viewer.gs_context.on_camera_frustum() @@ -96,21 +101,21 @@ def _toggle_camera_frustum(self): else: self.viewer.gs_context.off_camera_frustum() self.set_message_text("Camera Frustum Off") - + def _toggle_face_normals(self): self.viewer.render_flags["face_normals"] = not self.viewer.render_flags["face_normals"] if self.viewer.render_flags["face_normals"]: self.set_message_text("Face Normals On") else: self.set_message_text("Face Normals Off") - + def _toggle_vertex_normals(self): self.viewer.render_flags["vertex_normals"] = not self.viewer.render_flags["vertex_normals"] if self.viewer.render_flags["vertex_normals"]: self.set_message_text("Vert Normals On") else: self.set_message_text("Vert Normals Off") - + def _toggle_record_video(self): if self.viewer.viewer_flags["record"]: self.viewer.save_video() @@ -126,10 +131,10 @@ def _toggle_record_video(self): ) self.viewer.set_caption("{} (RECORDING)".format(self.viewer.viewer_flags["window_title"])) self.viewer.viewer_flags["record"] = not self.viewer.viewer_flags["record"] - + def _save_image(self): self.viewer._save_image() - + def _toggle_wireframe(self): if self.viewer.render_flags["flip_wireframe"]: self.viewer.render_flags["flip_wireframe"] = False @@ -151,9 +156,9 @@ def _toggle_wireframe(self): self.viewer.render_flags["all_wireframe"] = False self.viewer.render_flags["all_solid"] = False self.set_message_text("Flip Wireframe") - + def _reset_camera(self): self.viewer._reset_view() - + def _reload_shader(self): - self.viewer._renderer.reload_program() \ No newline at end of file + self.viewer._renderer.reload_program() diff --git a/genesis/ext/pyrender/interaction/plugins/help_text.py b/genesis/ext/pyrender/interaction/plugins/help_text.py index ff891f8fbd..aca6f5974b 100644 --- a/genesis/ext/pyrender/interaction/plugins/help_text.py +++ b/genesis/ext/pyrender/interaction/plugins/help_text.py @@ -13,10 +13,10 @@ if TYPE_CHECKING: from genesis.engine.scene import Scene from genesis.ext.pyrender.node import Node - from genesis.options.viewer_plugins import ViewerPlugin as ViewerPluginOptions INSTR_KEYBIND_NAME = "toggle_instructions" + @register_viewer_plugin(HelpTextPluginOptions) class HelpTextPlugin(ViewerPlugin): """ @@ -26,19 +26,24 @@ class HelpTextPlugin(ViewerPlugin): def __init__( self, viewer, - options: "ViewerPluginOptions", + options: HelpTextPluginOptions, camera: "Node" = None, scene: "Scene" = None, viewport_size: tuple[int, int] = None, ): super().__init__(viewer, options, camera, scene, viewport_size) - self.viewer.register_keybinds(( - Keybind(key_code=pyglet.window.key.I, name=INSTR_KEYBIND_NAME, callback_func=self._toggle_instructions), - )) - self._collapse_instructions = True - self._instr_texts: tuple[list[str], list[str]] = ([], []) - self._update_instr_texts() + if self.options.display_instructions: + self.viewer.register_keybinds( + ( + Keybind( + key_code=pyglet.window.key.I, name=INSTR_KEYBIND_NAME, callback_func=self._toggle_instructions + ), + ) + ) + self._collapse_instructions = True + self._instr_texts: tuple[list[str], list[str]] = ([], []) + self._update_instr_texts() self._message_text = None self._ticks_till_fade = 2.0 / 3.0 * self.viewer.viewer_flags["refresh_rate"] @@ -47,18 +52,21 @@ def __init__( def _update_instr_texts(self): self.instr_key_str = get_keycode_string(self.viewer._keybindings.get_by_name(INSTR_KEYBIND_NAME).key_code) kb_texts = [ - f"{'[' + get_keycode_string(kb.key_code):>{7}}]: " + - kb.name.replace('_', ' ') for kb in self.viewer._keybindings.keybinds if kb.name != INSTR_KEYBIND_NAME + f"{'[' + get_keycode_string(kb.key_code):>{7}}]: " + kb.name.replace("_", " ") + for kb in self.viewer._keybindings.keybinds + if kb.name != INSTR_KEYBIND_NAME ] self._instr_texts = ( [f"> [{self.instr_key_str}]: show keyboard instructions"], - [f"< [{self.instr_key_str}]: hide keyboard instructions"] + kb_texts + [f"< [{self.instr_key_str}]: hide keyboard instructions"] + kb_texts, ) def _toggle_instructions(self): + if not self.options.display_instructions: + raise RuntimeError("Instructions display is disabled by options.") self._collapse_instructions = not self._collapse_instructions self._update_instr_texts() - + def set_message_text(self, text: str): self._message_text = text self._message_opac = 1.0 + self._ticks_till_fade @@ -79,26 +87,25 @@ def on_draw(self): self._message_opac -= 1.0 else: self._message_opac *= 0.90 - + if self._message_opac < 0.05: self._message_opac = 1.0 + self._ticks_till_fade self._message_text = None - - if self._collapse_instructions: - self.viewer._renderer.render_texts( - self._instr_texts[0], - TEXT_PADDING, - self.viewer.viewport_size[1] - TEXT_PADDING, - font_pt=self.options.font_size, - color=np.array([1.0, 1.0, 1.0, 0.85]), - ) - else: - self.viewer._renderer.render_texts( - self._instr_texts[1], - TEXT_PADDING, - self.viewer.viewport_size[1] - TEXT_PADDING, - font_pt=self.options.font_size, - color=np.array([1.0, 1.0, 1.0, 0.85]), - ) - \ No newline at end of file + if self.options.display_instructions: + if self._collapse_instructions: + self.viewer._renderer.render_texts( + self._instr_texts[0], + TEXT_PADDING, + self.viewer.viewport_size[1] - TEXT_PADDING, + font_pt=self.options.font_size, + color=np.array([1.0, 1.0, 1.0, 0.85]), + ) + else: + self.viewer._renderer.render_texts( + self._instr_texts[1], + TEXT_PADDING, + self.viewer.viewport_size[1] - TEXT_PADDING, + font_pt=self.options.font_size, + color=np.array([1.0, 1.0, 1.0, 0.85]), + ) diff --git a/genesis/ext/pyrender/interaction/plugins/mesh_point_selector.py b/genesis/ext/pyrender/interaction/plugins/mesh_point_selector.py index 55bc4b1bb8..3d3302287d 100644 --- a/genesis/ext/pyrender/interaction/plugins/mesh_point_selector.py +++ b/genesis/ext/pyrender/interaction/plugins/mesh_point_selector.py @@ -19,7 +19,7 @@ class SelectedPoint(NamedTuple): """ Represents a selected point on a rigid mesh surface. - + Attributes ---------- link : RigidLink @@ -29,6 +29,7 @@ class SelectedPoint(NamedTuple): local_normal : Vec3 The surface normal at the point in the link's local coordinate frame. """ + link: "RigidLink" local_position: Vec3 local_normal: Vec3 @@ -60,24 +61,24 @@ def __init__( def _snap_to_grid(self, position: Vec3) -> Vec3: """ Snap a position to the grid based on grid_snap settings. - + Parameters ---------- position : Vec3 The position to snap. - + Returns ------- Vec3 The snapped position. """ snap_x, snap_y, snap_z = self.options.grid_snap - + # Snap each axis if the snap value is non-negative x = round(position.x / snap_x) * snap_x if snap_x >= 0 else position.x y = round(position.y / snap_y) * snap_y if snap_y >= 0 else position.y z = round(position.z / snap_z) * snap_z if snap_z >= 0 else position.z - + return Vec3.from_xyz(x, y, z) @override @@ -97,7 +98,7 @@ def on_mouse_press(self, x: int, y: int, button: int, modifiers: int) -> EVENT_H link = ray_hit.geom.link world_pos = ray_hit.position world_normal = ray_hit.normal - + pose: Pose = Pose.from_link(link) local_pos = pose.inverse_transform_point(world_pos) local_normal = pose.inverse_transform_direction(world_normal) @@ -105,11 +106,7 @@ def on_mouse_press(self, x: int, y: int, button: int, modifiers: int) -> EVENT_H # Apply grid snapping to local position local_pos = self._snap_to_grid(local_pos) - selected_point = SelectedPoint( - link=link, - local_position=local_pos, - local_normal=local_normal - ) + selected_point = SelectedPoint(link=link, local_position=local_pos, local_normal=local_normal) self.selected_points.append(selected_point) return EVENT_HANDLED @@ -170,36 +167,42 @@ def on_close(self) -> None: if not self.selected_points: print("[MeshPointSelectorPlugin] No points selected.") return - + output_file = self.options.output_file try: - with open(output_file, 'w', newline='') as csvfile: + with open(output_file, "w", newline="") as csvfile: writer = csv.writer(csvfile) - - writer.writerow([ - 'point_idx', - 'link_idx', - 'local_pos_x', - 'local_pos_y', - 'local_pos_z', - 'local_normal_x', - 'local_normal_y', - 'local_normal_z' - ]) - + + writer.writerow( + [ + "point_idx", + "link_idx", + "local_pos_x", + "local_pos_y", + "local_pos_z", + "local_normal_x", + "local_normal_y", + "local_normal_z", + ] + ) + for i, point in enumerate(self.selected_points, 1): - writer.writerow([ - i, - point.link.idx, - point.local_position.x, - point.local_position.y, - point.local_position.z, - point.local_normal.x, - point.local_normal.y, - point.local_normal.z, - ]) - - gs.logger.info(f"[MeshPointSelectorPlugin] Wrote {len(self.selected_points)} selected points to '{output_file}'") + writer.writerow( + [ + i, + point.link.idx, + point.local_position.x, + point.local_position.y, + point.local_position.z, + point.local_normal.x, + point.local_normal.y, + point.local_normal.z, + ] + ) + + gs.logger.info( + f"[MeshPointSelectorPlugin] Wrote {len(self.selected_points)} selected points to '{output_file}'" + ) except Exception as e: gs.logger.error(f"[MeshPointSelectorPlugin] Error writing to '{output_file}': {e}") diff --git a/genesis/ext/pyrender/interaction/plugins/mouse_spring.py b/genesis/ext/pyrender/interaction/plugins/mouse_spring.py index 883e793f4c..17638dbdcf 100644 --- a/genesis/ext/pyrender/interaction/plugins/mouse_spring.py +++ b/genesis/ext/pyrender/interaction/plugins/mouse_spring.py @@ -19,6 +19,7 @@ MOUSE_SPRING_POSITION_CORRECTION_FACTOR = 1.0 MOUSE_SPRING_VELOCITY_CORRECTION_FACTOR = 1.0 + class MouseSpring: def __init__(self) -> None: self.held_link: "RigidLink | None" = None @@ -55,7 +56,9 @@ def apply_force(self, control_point: Vec3, delta_time: float) -> None: link_T_principal: Pose = Pose(Vec3.from_arraylike(link.inertial_pos), Quat.from_arraylike(link.inertial_quat)) world_T_principal: Pose = link_pose * link_T_principal - arm_in_principal: Vec3 = link_T_principal.inverse_transform_point(self.held_point_in_local) # for non-spherical inertia + arm_in_principal: Vec3 = link_T_principal.inverse_transform_point( + self.held_point_in_local + ) # for non-spherical inertia arm_in_world: Vec3 = world_T_principal.rot * arm_in_principal # for spherical inertia pos_err_v: Vec3 = control_point - held_point_in_world @@ -69,7 +72,7 @@ def apply_force(self, control_point: Vec3, delta_time: float) -> None: total_impulse: Vec3 = Vec3.zero() total_torque_impulse: Vec3 = Vec3.zero() - for i in range(3*4): + for i in range(3 * 4): body_point_vel: Vec3 = lin_vel + ang_vel.cross(arm_in_world) vel_err_v: Vec3 = Vec3.zero() - body_point_vel @@ -93,8 +96,8 @@ def apply_force(self, control_point: Vec3, delta_time: float) -> None: total_torque = total_torque_impulse * inv_dt force_tensor: torch.Tensor = total_force.as_tensor()[None] torque_tensor: torch.Tensor = total_torque.as_tensor()[None] - link.solver.apply_links_external_force(force_tensor, (link.idx,), ref='link_com', local=False) - link.solver.apply_links_external_torque(torque_tensor, (link.idx,), ref='link_com', local=False) + link.solver.apply_links_external_force(force_tensor, (link.idx,), ref="link_com", local=False) + link.solver.apply_links_external_torque(torque_tensor, (link.idx,), ref="link_com", local=False) @property def is_attached(self) -> bool: @@ -145,9 +148,10 @@ def on_mouse_drag(self, x: int, y: int, dx: int, dy: int, buttons: int, modifier @override def on_mouse_press(self, x: int, y: int, button: int, modifiers: int) -> EVENT_HANDLE_STATE: super().on_mouse_press(x, y, button, modifiers) - if button == 1: # left mouse button - - ray_hit = self.raycaster.cast_ray(self._screen_position_to_ray(x, y).origin.v, self._screen_position_to_ray(x, y).direction.v) + if button == 1: # left mouse button + ray_hit = self.raycaster.cast_ray( + self._screen_position_to_ray(x, y).origin.v, self._screen_position_to_ray(x, y).direction.v + ) with self.lock: if ray_hit.geom: self.picked_link = ray_hit.geom.link @@ -229,7 +233,6 @@ def on_draw(self) -> None: if closest_hit.geom: self._draw_entity_unrotated_obb(closest_hit.geom) - def _get_box_obb(self, box_entity: "RigidEntity") -> OBB: box: gs.morphs.Box = box_entity.morph pose = Pose.from_link(box_entity.links[0]) diff --git a/genesis/ext/pyrender/interaction/utils/raycaster.py b/genesis/ext/pyrender/interaction/utils/raycaster.py index 6c8d31bb1f..e6c95a7751 100644 --- a/genesis/ext/pyrender/interaction/utils/raycaster.py +++ b/genesis/ext/pyrender/interaction/utils/raycaster.py @@ -18,7 +18,6 @@ NO_HIT_DISTANCE = -1.0 - @ti.kernel def kernel_cast_single_ray_for_viewer( fixed_verts_state: ti.template(), @@ -31,13 +30,15 @@ def kernel_cast_single_ray_for_viewer( ray_direction: ti.types.ndarray(ndim=1), # [3] max_range: ti.f32, envs_idx: ti.types.ndarray(ndim=1), # [n_envs] - result: ti.types.ndarray(ndim=1), # [9]: [distance, geom_idx, hit_x, hit_y, hit_z, normal_x, normal_y, normal_z, env_idx] + result: ti.types.ndarray( + ndim=1 + ), # [9]: [distance, geom_idx, hit_x, hit_y, hit_z, normal_x, normal_y, normal_z, env_idx] ): """ Taichi kernel for casting a single ray for viewer interaction. - + This loops over all environments in envs_idx and returns the closest hit. - + Returns: result[0]: distance to hit point (NO_HIT_DISTANCE if no hit) result[1]: geom_idx of hit geometry @@ -50,11 +51,11 @@ def kernel_cast_single_ray_for_viewer( result[8]: env_idx of hit environment """ n_triangles = faces_info.verts_idx.shape[0] - + # Setup ray ray_start_world = ti.math.vec3(ray_start[0], ray_start[1], ray_start[2]) ray_direction_world = ti.math.vec3(ray_direction[0], ray_direction[1], ray_direction[2]) - + # Initialize result with no hit result[0] = -1.0 # NO_HIT_DISTANCE result[1] = -1.0 # no geom @@ -65,40 +66,40 @@ def kernel_cast_single_ray_for_viewer( result[6] = 0.0 # normal y result[7] = 0.0 # normal z result[8] = -1.0 # no env - + global_closest_distance = max_range global_hit_face = -1 global_hit_env_idx = -1 global_hit_normal = ti.math.vec3(0.0, 0.0, 0.0) - + # Loop over all environments in envs_idx for i_b in range(envs_idx.shape[0]): rendered_env_idx = ti.cast(envs_idx[i_b], ti.i32) - + hit_face = -1 closest_distance = global_closest_distance hit_normal = ti.math.vec3(0.0, 0.0, 0.0) - + # Stack for non-recursive BVH traversal node_stack = ti.Vector.zero(ti.i32, STACK_SIZE) node_stack[0] = 0 # Start at root node stack_idx = 1 - + while stack_idx > 0: stack_idx -= 1 node_idx = node_stack[stack_idx] - + node = bvh_nodes[i_b, node_idx] - + # Check if ray hits the node's bounding box aabb_t = ray_aabb_intersection(ray_start_world, ray_direction_world, node.bound.min, node.bound.max) - + if aabb_t >= 0.0 and aabb_t < closest_distance: if node.left == -1: # Leaf node # Get original triangle/face index sorted_leaf_idx = node_idx - (n_triangles - 1) i_f = ti.cast(bvh_morton_codes[0, sorted_leaf_idx][1], ti.i32) - + # Get triangle vertices tri_vertices = ti.Matrix.zero(gs.ti_float, 3, 3) for i in ti.static(range(3)): @@ -109,10 +110,10 @@ def kernel_cast_single_ray_for_viewer( else: tri_vertices[:, i] = free_verts_state.pos[i_fv, rendered_env_idx] v0, v1, v2 = tri_vertices[:, 0], tri_vertices[:, 1], tri_vertices[:, 2] - + # Perform ray-triangle intersection hit_result = ray_triangle_intersection(ray_start_world, ray_direction_world, v0, v1, v2) - + if hit_result.w > 0.0 and hit_result.x < closest_distance and hit_result.x >= 0.0: closest_distance = hit_result.x hit_face = i_f @@ -126,14 +127,14 @@ def kernel_cast_single_ray_for_viewer( node_stack[stack_idx] = node.left node_stack[stack_idx + 1] = node.right stack_idx += 2 - + # Update global closest if this environment had a closer hit if hit_face >= 0 and closest_distance < global_closest_distance: global_closest_distance = closest_distance global_hit_face = hit_face global_hit_env_idx = rendered_env_idx global_hit_normal = hit_normal - + # Store result if global_hit_face >= 0: result[0] = global_closest_distance # distance (positive value indicates hit) @@ -152,11 +153,10 @@ def kernel_cast_single_ray_for_viewer( result[8] = gs.ti_float(global_hit_env_idx) - class ViewerRaycaster: """ BVH-accelerated raycaster for viewer interaction plugins. - + This class manages a BVH structure built from the scene's rigid geometry and provides efficient single-ray casting for interactive applications. Only considers environments specified in rendered_envs_idx. @@ -165,7 +165,7 @@ class ViewerRaycaster: def __init__(self, scene: "Scene"): """ Initialize the ViewerRaycaster. - + Parameters ---------- scene : Scene @@ -181,30 +181,30 @@ def __init__(self, scene: "Scene"): # Build the BVH structure for rendered environments. n_faces = self.solver.faces_info.geom_idx.shape[0] - + if n_faces == 0: gs.logger.warning("No faces found in scene, viewer raycasting will not work.") self.aabb = None self.bvh = None return - + self.aabb = AABB(n_batches=len(self.rendered_envs_idx), n_aabbs=n_faces) self.bvh = LBVH( self.aabb, max_n_query_result_per_aabb=0, # Not used for ray queries n_radix_sort_groups=min(64, n_faces), ) - + self.update_bvh() - + def update_bvh(self): """Update the BVH structure with current geometry state.""" if self.bvh is None: return - + # Update vertex positions from genesis.engine.solvers.rigid.rigid_solver_decomp import kernel_update_all_verts - + kernel_update_all_verts( geoms_info=self.solver.geoms_info, geoms_state=self.solver.geoms_state, @@ -213,7 +213,7 @@ def update_bvh(self): fixed_verts_state=self.solver.fixed_verts_state, static_rigid_sim_config=self.solver._static_rigid_sim_config, ) - + # Update AABBs for each rendered environment kernel_update_aabbs( free_verts_state=self.solver.free_verts_state, @@ -222,10 +222,10 @@ def update_bvh(self): faces_info=self.solver.faces_info, aabb_state=self.aabb, ) - + # Rebuild BVH self.bvh.build() - + def cast_ray( self, ray_origin: np.ndarray, @@ -234,7 +234,7 @@ def cast_ray( ) -> RayHit: """ Cast a single ray against all rendered environments and return the closest hit. - + Parameters ---------- ray_origin : np.ndarray, shape (3,) @@ -243,7 +243,7 @@ def cast_ray( The direction vector of the ray (will be normalized). max_range : float, optional Maximum distance to check for intersections. Default is 1000.0. - + Returns ------- RayHit @@ -251,11 +251,11 @@ def cast_ray( If no hit, returns RayHit.no_hit(). """ ray_direction = ray_direction / (np.linalg.norm(ray_direction) + gs.EPS) - + ray_start_np = np.asarray(ray_origin, dtype=gs.np_float) ray_dir_np = np.asarray(ray_direction, dtype=gs.np_float) result_np = np.zeros(9, dtype=gs.np_float) - + kernel_cast_single_ray_for_viewer( fixed_verts_state=self.solver.fixed_verts_state, free_verts_state=self.solver.free_verts_state, @@ -269,14 +269,14 @@ def cast_ray( envs_idx=self.rendered_envs_idx, result=result_np, ) - + distance = float(result_np[0]) if distance < NO_HIT_DISTANCE + gs.EPS: # NO_HIT_DISTANCE return RayHit.no_hit() - + geom_idx = int(result_np[1]) position = Vec3(result_np[2:5]) normal = Vec3(result_np[5:8]) geom = self.solver.geoms[geom_idx] - + return RayHit(distance=distance, position=position, normal=normal, geom=geom) diff --git a/genesis/ext/pyrender/interaction/viewer_plugin.py b/genesis/ext/pyrender/interaction/viewer_plugin.py index e6708e2dfd..41cbacc30a 100644 --- a/genesis/ext/pyrender/interaction/viewer_plugin.py +++ b/genesis/ext/pyrender/interaction/viewer_plugin.py @@ -20,31 +20,35 @@ def register_viewer_plugin(options_cls: Type["ViewerPluginOptions"]): """ Decorator to register a viewer plugin class with its corresponding options class. - + Parameters ---------- options_cls : Type[ViewerPluginOptions] The options class that configures this viewer plugin. - + Returns ------- Callable The decorator function that registers the plugin class. - + Example ------- @register_viewer_plugin(ViewerInteractionOptions) class ViewerInteraction(ViewerInteractionBase): ... """ + def _impl(plugin_cls: Type["ViewerPlugin"]): VIEWER_PLUGIN_MAP[options_cls] = plugin_cls return plugin_cls + return _impl + # Note: Viewer window is based on pyglet.window.Window, mouse events are defined in pyglet.window.BaseWindow -class ViewerPlugin(): + +class ViewerPlugin: """ Base class for handling pyglet.window.Window events. """ @@ -59,8 +63,8 @@ def __init__( ): self.viewer = viewer self.options: "ViewerPluginOptions" = options - self.camera: 'Node' = camera - self.scene: 'Scene' = scene + self.camera: "Node" = camera + self.scene: "Scene" = scene self.viewport_size: tuple[int, int] = viewport_size self.camera_yfov: float = camera.camera.yfov @@ -123,4 +127,4 @@ def _get_camera_ray(self) -> Ray: mtx = self.camera.matrix position = Vec3.from_array(mtx[:3, 3]) forward = Vec3.from_array(-mtx[:3, 2]) - return Ray(position, forward) \ No newline at end of file + return Ray(position, forward) diff --git a/genesis/ext/pyrender/viewer.py b/genesis/ext/pyrender/viewer.py index a2cd20ba90..623a978037 100644 --- a/genesis/ext/pyrender/viewer.py +++ b/genesis/ext/pyrender/viewer.py @@ -339,7 +339,7 @@ def __init__( # Note: context.scene is genesis.engine.scene.Scene # Note: context._scene is genesis.ext.pyrender.scene.Scene - + # Setup viewer interaction if plugin_options is None: plugin_options = gs.options.viewer_plugins.ViewerDefaultControls() @@ -350,9 +350,7 @@ def __init__( f"Viewer plugin type {type(plugin_options).__name__} is not registered. " f"Available plugins: {list(VIEWER_PLUGIN_MAP.keys())}" ) - self.interaction_plugin = plugin_cls( - self, plugin_options, self._camera_node, context.scene, viewport_size - ) + self.interaction_plugin = plugin_cls(self, plugin_options, self._camera_node, context.scene, viewport_size) ####################################################################### # Initialize OpenGL context and renderer @@ -485,7 +483,7 @@ def viewer_flags(self): @viewer_flags.setter def viewer_flags(self, value): self._viewer_flags = value - + def register_keybinds(self, keybinds: tuple[Keybind]) -> None: """ Add a key handler to call a function when the given key is pressed. @@ -497,7 +495,6 @@ def register_keybinds(self, keybinds: tuple[Keybind]) -> None: """ for keybind in keybinds: self._keybindings.register(keybind) - def close(self): """Close the viewer. @@ -790,9 +787,9 @@ def on_mouse_scroll(self, x, y, dx, dy) -> EVENT_HANDLE_STATE: ymag = max(c.ymag * sf, 1e-8 * c.ymag / c.xmag) c.xmag = xmag c.ymag = ymag - + return EVENT_HANDLED - + def _call_keybind_callback(self, symbol: int, modifiers: int, action: KeyAction) -> None: """Call registered keybind callbacks for the given key event.""" keybind: Keybind = self._keybindings.get(symbol, modifiers, action) @@ -1128,7 +1125,7 @@ def start(self, auto_refresh=True): # The viewer can be considered as fully initialized at this point if not self._initialized_event.is_set(): self._initialized_event.set() - + if auto_refresh: while self._is_active: try: @@ -1181,7 +1178,7 @@ def refresh(self): def update_on_sim_step(self): # Call HOLD callbacks for all currently held keys - for (symbol, modifiers) in list(self._held_keys.keys()): + for symbol, modifiers in list(self._held_keys.keys()): self._call_keybind_callback(symbol, modifiers, KeyAction.HOLD) self.interaction_plugin.update_on_sim_step() diff --git a/genesis/options/sensors/raycaster.py b/genesis/options/sensors/raycaster.py index 0a42e07d99..1e9af7b7ae 100644 --- a/genesis/options/sensors/raycaster.py +++ b/genesis/options/sensors/raycaster.py @@ -63,7 +63,6 @@ def _sanitize_rays_to_tensor(rays: Sequence[float]) -> torch.Tensor: class RaycastCustomPattern(RaycastPattern): - def __init__(self, ray_dirs: Sequence[float], ray_starts: Sequence[float]): self._ray_dirs = _sanitize_rays_to_tensor(ray_dirs) self._ray_starts = _sanitize_rays_to_tensor(ray_starts)