diff --git a/.clang-format b/.clang-format
index 6b72b54..3f016df 100644
--- a/.clang-format
+++ b/.clang-format
@@ -165,4 +165,3 @@ TabWidth: 8
UseCRLF: false
UseTab: Never
...
-
diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml
new file mode 100644
index 0000000..8c66bd8
--- /dev/null
+++ b/.pre-commit-config.yaml
@@ -0,0 +1,39 @@
+# See https://pre-commit.com for more information
+# See https://pre-commit.com/hooks.html for more hooks
+repos:
+ - repo: https://github.com/pre-commit/pre-commit-hooks
+ rev: v4.4.0
+ hooks:
+ - id: check-symlinks
+ - id: destroyed-symlinks
+ - id: trailing-whitespace
+ - id: end-of-file-fixer
+ - id: check-yaml
+ - id: check-toml
+ - id: check-ast
+ - id: check-added-large-files
+ args: ["--maxkb=2000"]
+ - id: check-merge-conflict
+ - id: check-executables-have-shebangs
+ - id: check-shebang-scripts-are-executable
+ - id: detect-private-key
+ - id: debug-statements
+ - repo: https://github.com/codespell-project/codespell
+ rev: v2.2.2
+ hooks:
+ - id: codespell
+ args:
+ - --skip=*.ipynb
+ - -L ths
+ - repo: https://github.com/python/black
+ rev: 23.1.0
+ hooks:
+ - id: black
+ - repo: https://github.com/charliermarsh/ruff-pre-commit
+ # Ruff version.
+ rev: 'v0.4.10'
+ hooks:
+ - id: ruff
+ args:
+ - --fix
+ - --unsafe-fixes
diff --git a/README.md b/README.md
index 5b288f4..f5bdf09 100644
--- a/README.md
+++ b/README.md
@@ -1,6 +1,6 @@
# PyTorch Robot Kinematics
- Parallel and differentiable forward kinematics (FK), Jacobian calculation, and damped least squares inverse kinematics (IK)
-- Load robot description from URDF, SDF, and MJCF formats
+- Load robot description from URDF, SDF, and MJCF formats
- SDF queries batched across configurations and points via [pytorch-volumetric](https://github.com/UM-ARM-Lab/pytorch_volumetric)
# Installation
@@ -183,7 +183,7 @@ ret = chain.forward_kinematics(th)
## Jacobian calculation
The Jacobian (in the kinematics context) is a matrix describing how the end effector changes with respect to joint value changes
(where  is the twist, or stacked velocity and angular velocity):
-
+
For `SerialChain` we provide a differentiable and parallelizable method for computing the Jacobian with respect to the base frame.
```python
@@ -223,7 +223,7 @@ The Jacobian can be used to do inverse kinematics. See [IK survey](https://www.m
for a survey of ways to do so. Note that IK may be better performed through other means (but doing it through the Jacobian can give an end-to-end differentiable method).
## Inverse Kinematics (IK)
-Inverse kinematics is available via damped least squares (iterative steps with Jacobian pseudo-inverse damped to avoid oscillation near singularlities).
+Inverse kinematics is available via damped least squares (iterative steps with Jacobian pseudo-inverse damped to avoid oscillation near singularlities).
Compared to other IK libraries, these are the typical advantages over them:
- not ROS dependent (many IK libraries need the robot description on the ROS parameter server)
- batched in both goal specification and retries from different starting configurations
diff --git a/pyproject.toml b/pyproject.toml
index f9b05eb..6a5f842 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -77,3 +77,77 @@ test = [
# They will be installed as dependencies during the build, which can take a while the first time.
requires = ["setuptools>=60.0.0", "wheel"]
build-backend= "setuptools.build_meta"
+
+
+[tool.ruff]
+src = ["src", "tests",]
+show-fixes = true
+# Same as Black.
+line-length = 127
+# Assume Python 3.8.
+target-version = "py38"
+
+[tool.ruff.lint]
+extend-select = ["C4", "SIM", "TCH"]
+
+ignore = [
+ "A001", "A002",
+ "ARG001", "ARG002",
+ "B007", "B008", "B019", "B023", "B028", "B904", "B026",
+ "E501", "ERA001", "E741", "E722",
+ "FBT001", "FBT002", "FBT003",
+ "N802", "N803", "N806", "N812",
+ "PGH003", "PGH004", "PLR2004",
+ "S101", "S202", "S301", "S310", "S311", "S320",
+ "UP006",
+ "T201", "T203"
+]
+
+select = [
+ "A",
+ "ARG",
+ "B",
+ "C4",
+ "C90",
+ "E",
+ "ERA",
+ "F",
+ "FBT",
+ "ICN",
+ "I",
+ "ISC",
+ "N",
+ "NPY",
+ "PD",
+ "PGH",
+ "PIE",
+ "PLE",
+ "PLR",
+ "Q",
+ "RUF",
+ "S",
+ "SIM",
+ "T",
+ "UP",
+ "W",
+]
+
+[tool.ruff.lint.mccabe]
+# Unlike Flake8, default to a complexity level of 10.
+max-complexity = 30
+
+[tool.ruff.lint.isort]
+lines-after-imports = 2
+extra-standard-library = ["typing_extensions"]
+
+[tool.ruff.lint.pylint]
+max-statements = 100
+max-branches=25
+max-args=20
+
+[tool.ruff.lint.per-file-ignores]
+"tests/*.py" = ["PLR2004", "N802", "N801", "SIM115", "E501", "ERA001"]
+
+
+[tool.black]
+line-length = 127
diff --git a/src/pytorch_kinematics/cfg.py b/src/pytorch_kinematics/cfg.py
index 3d54276..918192d 100644
--- a/src/pytorch_kinematics/cfg.py
+++ b/src/pytorch_kinematics/cfg.py
@@ -1,4 +1,5 @@
-import os
+from pathlib import Path
-ROOT_DIR = os.path.abspath(os.path.join(os.path.dirname(__file__), '../../'))
-TEST_DIR = os.path.join(ROOT_DIR, 'tests')
+
+ROOT_DIR = Path(__file__).resolve().parent.parent.parent
+TEST_DIR = ROOT_DIR / "tests"
diff --git a/src/pytorch_kinematics/chain.py b/src/pytorch_kinematics/chain.py
index 0b21114..22280d2 100644
--- a/src/pytorch_kinematics/chain.py
+++ b/src/pytorch_kinematics/chain.py
@@ -1,17 +1,31 @@
+import copy
from functools import lru_cache
-from typing import Optional, Sequence
+from typing import (
+ Dict,
+ Iterable,
+ List,
+ Optional,
+ Sequence,
+ Tuple,
+ Union,
+)
-import copy
import numpy as np
import torch
import pytorch_kinematics.transforms as tf
from pytorch_kinematics import jacobian
-from pytorch_kinematics.frame import Frame, Link, Joint
-from pytorch_kinematics.transforms.rotation_conversions import axis_and_angle_to_matrix_44, axis_and_d_to_pris_matrix
+from pytorch_kinematics.frame import Frame, Joint, Link
+from pytorch_kinematics.transforms.rotation_conversions import (
+ axis_and_angle_to_matrix_44,
+ axis_and_d_to_pris_matrix,
+)
+
+
+SUPPORTED_INPUTS = Union[torch.Tensor, np.ndarray, List, Dict]
-def get_n_joints(th):
+def get_n_joints(th: SUPPORTED_INPUTS) -> int:
"""
Args:
@@ -20,16 +34,16 @@ def get_n_joints(th):
Returns: The number of joints in the input
"""
- if isinstance(th, torch.Tensor) or isinstance(th, np.ndarray):
+ if isinstance(th, (np.ndarray, torch.Tensor)):
return th.shape[-1]
- elif isinstance(th, list) or isinstance(th, dict):
+ elif isinstance(th, (dict, list)):
return len(th)
else:
raise NotImplementedError(f"Unsupported type {type(th)}")
-def get_batch_size(th):
- if isinstance(th, torch.Tensor) or isinstance(th, np.ndarray):
+def get_batch_size(th: SUPPORTED_INPUTS) -> int:
+ if isinstance(th, (np.ndarray, torch.Tensor)):
return th.shape[0]
elif isinstance(th, dict):
elem_shape = get_dict_elem_shape(th)
@@ -41,7 +55,11 @@ def get_batch_size(th):
raise NotImplementedError(f"Unsupported type {type(th)}")
-def ensure_2d_tensor(th, dtype, device):
+def ensure_2d_tensor(
+ th: Union[torch.Tensor, np.ndarray],
+ dtype: torch.dtype,
+ device: torch.device,
+) -> Tuple[torch.Tensor, int]:
if not torch.is_tensor(th):
th = torch.tensor(th, dtype=dtype, device=device)
if len(th.shape) <= 1:
@@ -52,11 +70,9 @@ def ensure_2d_tensor(th, dtype, device):
return th, N
-def get_dict_elem_shape(th_dict):
- elem = th_dict[list(th_dict.keys())[0]]
- if isinstance(elem, np.ndarray):
- return elem.shape
- elif isinstance(elem, torch.Tensor):
+def get_dict_elem_shape(th_dict: Dict[str, Union[np.ndarray, torch.Tensor]]) -> Tuple[int, ...]:
+ elem = th_dict[next(iter(th_dict.keys()))]
+ if isinstance(elem, (np.ndarray, torch.Tensor)):
return elem.shape
else:
return ()
@@ -69,7 +85,12 @@ class Chain:
having a physical link and a number of child frames each connected via some joint.
"""
- def __init__(self, root_frame, dtype=torch.float32, device="cpu"):
+ def __init__(
+ self,
+ root_frame: Frame,
+ dtype: torch.dtype = torch.float32,
+ device: torch.device = torch.device("cpu"),
+ ) -> None:
self._root = root_frame
self.dtype = dtype
self.device = device
@@ -85,7 +106,7 @@ def __init__(self, root_frame, dtype=torch.float32, device="cpu"):
# parents_indices and joint_indices all use this indexing scheme.
# The root frame will be index 0 and the first frame of the root frame's children will be index 1,
# then the child of that frame will be index 2, etc. In other words, it's a depth-first ordering.
- self.parents_indices = [] # list of indices from 0 (root) to the given frame
+ self.parents_indices: List[torch.Tensor] = [] # list of indices from 0 (root) to the given frame
self.joint_indices = []
self.n_joints = len(self.get_joint_parameter_names())
self.axes = torch.zeros([self.n_joints, 3], dtype=self.dtype, device=self.device)
@@ -107,7 +128,7 @@ def __init__(self, root_frame, dtype=torch.float32, device="cpu"):
else:
self.parents_indices.append(self.parents_indices[parent_idx] + [idx])
- is_fixed = root.joint.joint_type == 'fixed'
+ is_fixed = root.joint.joint_type == "fixed"
if root.link.offset is None:
self.link_offsets.append(None)
@@ -139,7 +160,11 @@ def __init__(self, root_frame, dtype=torch.float32, device="cpu"):
# We need to use a dict because torch.compile doesn't list lists of tensors
self.parents_indices = [torch.tensor(p, dtype=torch.long, device=self.device) for p in self.parents_indices]
- def to(self, dtype=None, device=None):
+ def to(
+ self,
+ dtype: Optional[torch.dtype] = None,
+ device: Optional[torch.device] = None,
+ ) -> "Chain":
if dtype is not None:
self.dtype = dtype
if device is not None:
@@ -152,142 +177,149 @@ def to(self, dtype=None, device=None):
self.joint_indices = self.joint_indices.to(dtype=torch.long, device=self.device)
self.axes = self.axes.to(dtype=self.dtype, device=self.device)
self.link_offsets = [l if l is None else l.to(dtype=self.dtype, device=self.device) for l in self.link_offsets]
- self.joint_offsets = [j if j is None else j.to(dtype=self.dtype, device=self.device) for j in
- self.joint_offsets]
+ self.joint_offsets = [j if j is None else j.to(dtype=self.dtype, device=self.device) for j in self.joint_offsets]
self.low = self.low.to(dtype=self.dtype, device=self.device)
self.high = self.high.to(dtype=self.dtype, device=self.device)
return self
- def __str__(self):
+ def __str__(self) -> str:
return str(self._root)
@staticmethod
- def _find_frame_recursive(name, frame: Frame) -> Optional[Frame]:
+ def _find_frame_recursive(name: str, frame: Frame) -> Optional[Frame]:
for child in frame.children:
if child.name == name:
return child
ret = Chain._find_frame_recursive(name, child)
- if not ret is None:
+ if ret is not None:
return ret
return None
- def find_frame(self, name) -> Optional[Frame]:
+ def find_frame(self, name: str) -> Optional[Frame]:
if self._root.name == name:
return self._root
return self._find_frame_recursive(name, self._root)
@staticmethod
- def _find_link_recursive(name, frame) -> Optional[Link]:
+ def _find_link_recursive(name: str, frame: Frame) -> Optional[Link]:
for child in frame.children:
if child.link.name == name:
return child.link
ret = Chain._find_link_recursive(name, child)
- if not ret is None:
+ if ret is not None:
return ret
return None
@staticmethod
- def _get_joints(frame, exclude_fixed=True):
- joints = []
+ def _get_joints(frame: Frame, exclude_fixed: bool = True) -> List[Joint]:
+ joints: List[Joint] = []
if exclude_fixed and frame.joint.joint_type != "fixed":
joints.append(frame.joint)
for child in frame.children:
joints.extend(Chain._get_joints(child))
return joints
- def get_joints(self, exclude_fixed=True):
+ def get_joints(self, exclude_fixed: bool = True) -> List[Joint]:
joints = self._get_joints(self._root, exclude_fixed=exclude_fixed)
return joints
- @lru_cache()
- def get_joint_parameter_names(self, exclude_fixed=True):
- names = []
+ @lru_cache
+ def get_joint_parameter_names(self, exclude_fixed: bool = True) -> List[str]:
+ names: List[str] = []
for j in self.get_joints(exclude_fixed=exclude_fixed):
- if exclude_fixed and j.joint_type == 'fixed':
+ if exclude_fixed and j.joint_type == "fixed":
continue
names.append(j.name)
return names
@staticmethod
- def _find_joint_recursive(name, frame):
+ def _find_joint_recursive(name: str, frame: Frame) -> Optional[Joint]:
for child in frame.children:
if child.joint.name == name:
return child.joint
ret = Chain._find_joint_recursive(name, child)
- if not ret is None:
+ if ret is not None:
return ret
return None
- def find_link(self, name) -> Optional[Link]:
+ def find_link(self, name: str) -> Optional[Link]:
if self._root.link.name == name:
return self._root.link
return self._find_link_recursive(name, self._root)
- def find_joint(self, name):
+ def find_joint(self, name: str) -> Optional[Joint]:
if self._root.joint.name == name:
return self._root.joint
return self._find_joint_recursive(name, self._root)
@staticmethod
- def _get_joint_parent_frame_names(frame, exclude_fixed=True):
- joint_names = []
+ def _get_joint_parent_frame_names(frame: Frame, exclude_fixed: bool = True) -> List[str]:
+ joint_names: List[str] = []
if not (exclude_fixed and frame.joint.joint_type == "fixed"):
joint_names.append(frame.name)
for child in frame.children:
joint_names.extend(Chain._get_joint_parent_frame_names(child, exclude_fixed))
return joint_names
- def get_joint_parent_frame_names(self, exclude_fixed=True):
+ def get_joint_parent_frame_names(self, exclude_fixed: bool = True) -> List[str]:
names = self._get_joint_parent_frame_names(self._root, exclude_fixed)
return sorted(set(names), key=names.index)
@staticmethod
- def _get_frame_names(frame: Frame, exclude_fixed=True) -> Sequence[str]:
- names = []
+ def _get_frame_names(frame: Frame, exclude_fixed: bool = True) -> Sequence[str]:
+ names: List[str] = []
if not (exclude_fixed and frame.joint.joint_type == "fixed"):
names.append(frame.name)
for child in frame.children:
names.extend(Chain._get_frame_names(child, exclude_fixed))
return names
- def get_frame_names(self, exclude_fixed=True):
+ def get_frame_names(self, exclude_fixed: bool = True) -> List[str]:
names = self._get_frame_names(self._root, exclude_fixed)
return sorted(set(names), key=names.index)
@staticmethod
- def _get_links(frame):
- links = [frame.link]
+ def _get_links(frame: Frame) -> List[Link]:
+ links: List[Link] = [frame.link]
for child in frame.children:
links.extend(Chain._get_links(child))
return links
- def get_links(self):
+ def get_links(self) -> List[Link]:
links = self._get_links(self._root)
return links
@staticmethod
- def _get_link_names(frame):
- link_names = [frame.link.name]
+ def _get_link_names(frame: Frame) -> List[str]:
+ link_names: List[str] = [frame.link.name]
for child in frame.children:
link_names.extend(Chain._get_link_names(child))
return link_names
- def get_link_names(self):
+ def get_link_names(self) -> List[str]:
names = self._get_link_names(self._root)
return sorted(set(names), key=names.index)
@lru_cache
- def get_frame_indices(self, *frame_names):
- return torch.tensor([self.frame_to_idx[n] for n in frame_names], dtype=torch.long, device=self.device)
-
- def print_tree(self, do_print=True):
+ def get_frame_indices(self, *frame_names: str) -> torch.Tensor:
+ return torch.tensor(
+ [self.frame_to_idx[n] for n in frame_names],
+ dtype=torch.long,
+ device=self.device,
+ )
+
+ def print_tree(self, do_print: bool = True) -> str:
tree = str(self._root)
if do_print:
print(tree)
return tree
- def forward_kinematics(self, th, frame_indices: Optional = None):
+ def forward_kinematics(
+ self,
+ th: SUPPORTED_INPUTS,
+ frame_indices: Optional[List[int]] = None,
+ ) -> Dict[str, tf.Transform3d]:
"""
Compute forward kinematics for the given joint values.
@@ -315,7 +347,7 @@ def forward_kinematics(self, th, frame_indices: Optional = None):
rev_jnt_transform = axis_and_angle_to_matrix_44(axes_expanded, th)
pris_jnt_transform = axis_and_d_to_pris_matrix(axes_expanded, th)
- frame_transforms = {}
+ frame_transforms: Dict[int, torch.Tensor] = {}
b = th.shape[0]
for frame_idx in frame_indices:
frame_transform = torch.eye(4).to(th).unsqueeze(0).repeat(b, 1, 1)
@@ -346,19 +378,18 @@ def forward_kinematics(self, th, frame_indices: Optional = None):
frame_transforms[frame_idx.item()] = frame_transform
- frame_names_and_transform3ds = {self.idx_to_frame[frame_idx]: tf.Transform3d(matrix=transform) for
- frame_idx, transform in frame_transforms.items()}
+ frame_names_and_transform3ds: Dict[str, tf.Transform3d] = {
+ self.idx_to_frame[frame_idx]: tf.Transform3d(matrix=transform) for frame_idx, transform in frame_transforms.items()
+ }
return frame_names_and_transform3ds
- def ensure_tensor(self, th):
+ def ensure_tensor(self, th: SUPPORTED_INPUTS) -> torch.Tensor:
"""
Converts a number of possible types into a tensor. The order of the tensor is determined by the order
of self.get_joint_parameter_names(). th must contain all joints in the entire chain.
"""
- if isinstance(th, np.ndarray):
- th = torch.tensor(th, device=self.device, dtype=self.dtype)
- elif isinstance(th, list):
+ if isinstance(th, (list, np.ndarray)):
th = torch.tensor(th, device=self.device, dtype=self.dtype)
elif isinstance(th, dict):
# convert dict to a flat, complete, tensor of all joints values. Missing joints are filled with zeros.
@@ -371,16 +402,16 @@ def ensure_tensor(self, th):
th[..., jnt_idx] = joint_position
if torch.any(torch.isnan(th)):
msg = "Missing values for the following joints:\n"
- for joint_name, th_i in zip(self.get_joint_parameter_names(), th):
+ for joint_name, _ in zip(self.get_joint_parameter_names(), th):
msg += joint_name + "\n"
raise ValueError(msg)
return th
- def get_all_frame_indices(self):
+ def get_all_frame_indices(self) -> torch.Tensor:
frame_indices = self.get_frame_indices(*self.get_frame_names(exclude_fixed=False))
return frame_indices
- def clamp(self, th):
+ def clamp(self, th: SUPPORTED_INPUTS) -> torch.Tensor:
"""
Args:
@@ -389,21 +420,21 @@ def clamp(self, th):
Returns: Always a tensor in the order of self.get_joint_parameter_names(), possibly batched.
"""
- th = self.ensure_tensor(th)
- return torch.clamp(th, self.low, self.high)
+ th_tensor = self.ensure_tensor(th)
+ return torch.clamp(th_tensor, self.low, self.high)
- def get_joint_limits(self):
+ def get_joint_limits(self) -> Tuple[List[float], List[float]]:
return self._get_joint_limits("limits")
- def get_joint_velocity_limits(self):
+ def get_joint_velocity_limits(self) -> Tuple[List[float], List[float]]:
return self._get_joint_limits("velocity_limits")
- def get_joint_effort_limits(self):
+ def get_joint_effort_limits(self) -> Tuple[List[float], List[float]]:
return self._get_joint_limits("effort_limits")
- def _get_joint_limits(self, param_name):
- low = []
- high = []
+ def _get_joint_limits(self, param_name: str) -> Tuple[List[float], List[float]]:
+ low: List[float] = []
+ high: List[float] = []
for joint in self.get_joints():
val = getattr(joint, param_name)
if val is None:
@@ -418,20 +449,24 @@ def _get_joint_limits(self, param_name):
return low, high
@staticmethod
- def _get_joints_and_child_links(frame):
+ def _get_joints_and_child_links(
+ frame: Frame,
+ ) -> Iterable[Tuple[Joint, List[Link]]]:
joint = frame.joint
- me_and_my_children = [frame.link]
+ me_and_my_children: List[Link] = [frame.link]
for child in frame.children:
recursive_child_links = yield from Chain._get_joints_and_child_links(child)
me_and_my_children.extend(recursive_child_links)
- if joint is not None and joint.joint_type != 'fixed':
+ if joint is not None and joint.joint_type != "fixed":
yield joint, me_and_my_children
return me_and_my_children
- def get_joints_and_child_links(self):
+ def get_joints_and_child_links(
+ self,
+ ) -> Iterable[Tuple[Joint, List[Link]]]:
yield from Chain._get_joints_and_child_links(self._root)
@@ -441,17 +476,23 @@ class SerialChain(Chain):
Serial chains can be generated from subsets of a Chain.
"""
- def __init__(self, chain, end_frame_name, root_frame_name="", **kwargs):
+ def __init__(
+ self,
+ chain: Chain,
+ end_frame_name: str,
+ root_frame_name: str = "",
+ **kwargs,
+ ) -> None:
root_frame = chain._root if root_frame_name == "" else chain.find_frame(root_frame_name)
if root_frame is None:
- raise ValueError("Invalid root frame name %s." % root_frame_name)
+ raise ValueError(f"Invalid root frame name {root_frame_name}.")
chain = Chain(root_frame, **kwargs)
# make a copy of those frames that includes only the chain up to the end effector
end_frame_idx = chain.get_frame_indices(end_frame_name)
ancestors = chain.parents_indices[end_frame_idx]
- frames = []
+ frames: List[Frame] = []
# first pass create copies of the ancestor nodes
for idx in ancestors:
this_frame_name = chain.idx_to_frame[idx.item()]
@@ -466,13 +507,22 @@ def __init__(self, chain, end_frame_name, root_frame_name="", **kwargs):
self._serial_frames = frames
super().__init__(frames[0], **kwargs)
- def jacobian(self, th, locations=None, **kwargs):
+ def jacobian(
+ self,
+ th: SUPPORTED_INPUTS,
+ locations: Optional[torch.Tensor] = None,
+ **kwargs,
+ ) -> torch.Tensor:
if locations is not None:
locations = tf.Transform3d(pos=locations)
return jacobian.calc_jacobian(self, th, tool=locations, **kwargs)
- def forward_kinematics(self, th, end_only: bool = True):
- """ Like the base class, except `th` only needs to contain the joints in the SerialChain, not all joints. """
+ def forward_kinematics(
+ self,
+ th: SUPPORTED_INPUTS,
+ end_only: bool = True,
+ ) -> Union[tf.Transform3d, Dict[str, tf.Transform3d]]:
+ """Like the base class, except `th` only needs to contain the joints in the SerialChain, not all joints."""
frame_indices, th = self.convert_serial_inputs_to_chain_inputs(th, end_only)
mat = super().forward_kinematics(th, frame_indices)
@@ -482,34 +532,33 @@ def forward_kinematics(self, th, end_only: bool = True):
else:
return mat
- def convert_serial_inputs_to_chain_inputs(self, th, end_only: bool):
+ def convert_serial_inputs_to_chain_inputs(
+ self,
+ th: SUPPORTED_INPUTS,
+ end_only: bool,
+ ) -> Tuple[Optional[torch.Tensor], torch.Tensor]:
# th = self.ensure_tensor(th)
th_b = get_batch_size(th)
th_n_joints = get_n_joints(th)
if isinstance(th, list):
th = torch.tensor(th, device=self.device, dtype=self.dtype)
- if end_only:
- frame_indices = self.get_frame_indices(self._serial_frames[-1].name)
- else:
- # pass through default behavior for frame indices being None, which is currently
- # to return all frames.
- frame_indices = None
+ # pass through default behavior for frame indices being None, which is currently
+ # to return all frames.
+ frame_indices = self.get_frame_indices(self._serial_frames[-1].name) if end_only else None
+
if th_n_joints < self.n_joints:
# if th is only a partial list of joints, assume it's a list of joints for only the serial chain.
partial_th = th
- nonfixed_serial_frames = list(filter(lambda f: f.joint.joint_type != 'fixed', self._serial_frames))
+ nonfixed_serial_frames = list(filter(lambda f: f.joint.joint_type != "fixed", self._serial_frames))
if th_n_joints != len(nonfixed_serial_frames):
- raise ValueError(f'Expected {len(nonfixed_serial_frames)} joint values, got {th_n_joints}.')
+ raise ValueError(f"Expected {len(nonfixed_serial_frames)} joint values, got {th_n_joints}.")
th = torch.zeros([th_b, self.n_joints], device=self.device, dtype=self.dtype)
for i, frame in enumerate(nonfixed_serial_frames):
joint_name = frame.joint.name
- if isinstance(partial_th, dict):
- partial_th_i = partial_th[joint_name]
- else:
- partial_th_i = partial_th[..., i]
+ partial_th_i = partial_th[joint_name] if isinstance(partial_th, dict) else partial_th[..., i]
k = self.frame_to_idx[frame.name]
jnt_idx = self.joint_indices[k]
- if frame.joint.joint_type != 'fixed':
+ if frame.joint.joint_type != "fixed":
th[..., jnt_idx] = partial_th_i
return frame_indices, th
diff --git a/src/pytorch_kinematics/frame.py b/src/pytorch_kinematics/frame.py
index 0c35d84..83289e2 100644
--- a/src/pytorch_kinematics/frame.py
+++ b/src/pytorch_kinematics/frame.py
@@ -1,114 +1,133 @@
+from typing import Any, ClassVar, Iterable, List, Optional
+
import torch
import pytorch_kinematics.transforms as tf
from pytorch_kinematics.transforms import axis_and_angle_to_matrix_33
-class Visual(object):
- TYPES = ['box', 'cylinder', 'sphere', 'capsule', 'mesh']
+class Visual:
+ TYPES: ClassVar = ["box", "cylinder", "sphere", "capsule", "mesh"]
- def __init__(self, offset=None, geom_type=None, geom_param=None):
+ def __init__(
+ self,
+ offset: Optional[tf.Transform3d] = None,
+ geom_type: Optional[str] = None,
+ geom_param: Any = None,
+ ) -> None:
if offset is None:
- self.offset = None
+ self.offset: Optional[tf.Transform3d] = None
else:
self.offset = offset
- self.geom_type = geom_type
- self.geom_param = geom_param
+ self.geom_type: Optional[str] = geom_type
+ self.geom_param: Any = geom_param
- def __repr__(self):
- return "Visual(offset={0}, geom_type='{1}', geom_param={2})".format(self.offset,
- self.geom_type,
- self.geom_param)
+ def __repr__(self) -> str:
+ return f"Visual(offset={self.offset}, geom_type='{self.geom_type}', geom_param={self.geom_param})"
-class Link(object):
- def __init__(self, name=None, offset=None, visuals=()):
+class Link:
+ def __init__(
+ self,
+ name: Optional[str] = None,
+ offset: Optional[tf.Transform3d] = None,
+ visuals: Iterable[Visual] = (),
+ ) -> None:
if offset is None:
- self.offset = None
+ self.offset: Optional[tf.Transform3d] = None
else:
self.offset = offset
- self.name = name
- self.visuals = visuals
+ self.name: Optional[str] = name
+ self.visuals: Iterable[Visual] = visuals
- def to(self, *args, **kwargs):
+ def to(self, *args, **kwargs) -> "Link":
if self.offset is not None:
self.offset = self.offset.to(*args, **kwargs)
return self
- def __repr__(self):
- return "Link(name='{0}', offset={1}, visuals={2})".format(self.name,
- self.offset,
- self.visuals)
-
-
-class Joint(object):
- TYPES = ['fixed', 'revolute', 'prismatic']
-
- def __init__(self, name=None, offset=None, joint_type='fixed', axis=(0.0, 0.0, 1.0),
- dtype=torch.float32, device="cpu", limits=None,
- velocity_limits=None, effort_limits=None):
+ def __repr__(self) -> str:
+ return f"Link(name='{self.name}', offset={self.offset}, visuals={self.visuals})"
+
+
+class Joint:
+ TYPES: ClassVar = ["fixed", "revolute", "prismatic"]
+
+ def __init__(
+ self,
+ name: Optional[str] = None,
+ offset: Optional[tf.Transform3d] = None,
+ joint_type: str = "fixed",
+ axis: Optional[torch.Tensor] = (0.0, 0.0, 1.0),
+ dtype: torch.dtype = torch.float32,
+ device: str = "cpu",
+ limits: Optional[torch.Tensor] = None,
+ velocity_limits: Optional[torch.Tensor] = None,
+ effort_limits: Optional[torch.Tensor] = None,
+ ) -> None:
if offset is None:
- self.offset = None
+ self.offset: Optional[tf.Transform3d] = None
else:
self.offset = offset
- self.name = name
+ self.name: Optional[str] = name
if joint_type not in self.TYPES:
- raise RuntimeError("joint specified as {} type not, but we only support {}".format(joint_type, self.TYPES))
- self.joint_type = joint_type
+ raise RuntimeError(f"joint specified as {joint_type} type not, but we only support {self.TYPES}")
+ self.joint_type: str = joint_type
if axis is None:
- self.axis = torch.tensor([0.0, 0.0, 1.0], dtype=dtype, device=device)
+ self.axis: torch.Tensor = torch.tensor([0.0, 0.0, 1.0], dtype=dtype, device=device)
+ elif torch.is_tensor(axis):
+ self.axis = axis.clone().detach().to(dtype=dtype, device=device)
else:
- if torch.is_tensor(axis):
- self.axis = axis.clone().detach().to(dtype=dtype, device=device)
- else:
- self.axis = torch.tensor(axis, dtype=dtype, device=device)
+ self.axis = torch.tensor(axis, dtype=dtype, device=device)
# normalize axis to have norm 1 (needed for correct representation scaling with theta)
self.axis = self.axis / self.axis.norm()
- self.limits = limits
- self.velocity_limits = velocity_limits
- self.effort_limits = effort_limits
+ self.limits: Optional[torch.Tensor] = limits
+ self.velocity_limits: Optional[torch.Tensor] = velocity_limits
+ self.effort_limits: Optional[torch.Tensor] = effort_limits
- def to(self, *args, **kwargs):
+ def to(self, *args, **kwargs) -> "Joint":
self.axis = self.axis.to(*args, **kwargs)
if self.offset is not None:
self.offset = self.offset.to(*args, **kwargs)
return self
- def clamp(self, joint_position):
+ def clamp(self, joint_position: torch.Tensor) -> torch.Tensor:
if self.limits is None:
return joint_position
else:
return torch.clamp(joint_position, self.limits[0], self.limits[1])
- def __repr__(self):
- return "Joint(name='{0}', offset={1}, joint_type='{2}', axis={3})".format(self.name,
- self.offset,
- self.joint_type,
- self.axis)
+ def __repr__(self) -> str:
+ return f"Joint(name='{self.name}', offset={self.offset}, joint_type='{self.joint_type}', axis={self.axis})"
# prefix components:
-space = ' '
-branch = '│ '
+space: str = " "
+branch: str = "│ "
# pointers:
-tee = '├── '
-last = '└── '
-
-class Frame(object):
- def __init__(self, name=None, link=None, joint=None, children=None):
- self.name = 'None' if name is None else name
- self.link = link if link is not None else Link()
- self.joint = joint if joint is not None else Joint()
+tee: str = "├── "
+last: str = "└── "
+
+
+class Frame:
+ def __init__(
+ self,
+ name: Optional[str] = None,
+ link: Optional[Link] = None,
+ joint: Optional[Joint] = None,
+ children: Optional[List["Frame"]] = None,
+ ) -> None:
+ self.name: str = "None" if name is None else name
+ self.link: Link = link if link is not None else Link()
+ self.joint: Joint = joint if joint is not None else Joint()
if children is None:
- self.children = []
+ self.children: List[Frame] = []
+ else:
+ self.children = children
- def __str__(self, prefix='', root=True):
+ def __str__(self, prefix: str = "", root: bool = True) -> str:
pointers = [tee] * (len(self.children) - 1) + [last]
- if root:
- ret = prefix + self.name + "\n"
- else:
- ret = ""
+ ret = prefix + self.name + "\n" if root else ""
for pointer, child in zip(pointers, self.children):
ret += prefix + pointer + child.name + "\n"
if child.children:
@@ -117,31 +136,35 @@ def __str__(self, prefix='', root=True):
ret += child.__str__(prefix=prefix + extension, root=False)
return ret
- def to(self, *args, **kwargs):
+ def to(self, *args, **kwargs) -> "Frame":
self.joint = self.joint.to(*args, **kwargs)
self.link = self.link.to(*args, **kwargs)
self.children = [c.to(*args, **kwargs) for c in self.children]
return self
- def add_child(self, child):
+ def add_child(self, child: "Frame") -> None:
self.children.append(child)
- def is_end(self):
- return (len(self.children) == 0)
+ def is_end(self) -> bool:
+ return len(self.children) == 0
- def get_transform(self, theta):
+ def get_transform(self, theta: torch.Tensor) -> tf.Transform3d:
dtype = self.joint.axis.dtype
d = self.joint.axis.device
- if self.joint.joint_type == 'revolute':
+ if self.joint.joint_type == "revolute":
rot = axis_and_angle_to_matrix_33(self.joint.axis, theta)
t = tf.Transform3d(rot=rot, dtype=dtype, device=d)
- elif self.joint.joint_type == 'prismatic':
+ elif self.joint.joint_type == "prismatic":
pos = theta.unsqueeze(1) * self.joint.axis
t = tf.Transform3d(pos=pos, dtype=dtype, device=d)
- elif self.joint.joint_type == 'fixed':
- t = tf.Transform3d(default_batch_size=theta.shape[0], dtype=dtype, device=d)
+ elif self.joint.joint_type == "fixed":
+ t = tf.Transform3d(
+ default_batch_size=theta.shape[0],
+ dtype=dtype,
+ device=d,
+ )
else:
- raise ValueError("Unsupported joint type %s." % self.joint.joint_type)
+ raise ValueError(f"Unsupported joint type {self.joint.joint_type}.")
if self.joint.offset is None:
return t
else:
diff --git a/src/pytorch_kinematics/ik.py b/src/pytorch_kinematics/ik.py
index 356cffc..7eca8cf 100644
--- a/src/pytorch_kinematics/ik.py
+++ b/src/pytorch_kinematics/ik.py
@@ -1,48 +1,63 @@
-from pytorch_kinematics.chain import SerialChain
-from pytorch_kinematics.transforms import Transform3d
-from pytorch_kinematics.transforms import rotation_conversions
-from typing import NamedTuple, Union, Optional, Callable
-import typing
-import torch
import inspect
-from matplotlib import pyplot as plt, cm as cm
+from typing import Callable, List, Optional, Tuple, Type, Union
+
+import torch
+from matplotlib import cm as cm
+from matplotlib import pyplot as plt
+
+from pytorch_kinematics.chain import SerialChain
+from pytorch_kinematics.transforms import Transform3d, rotation_conversions
class IKSolution:
- def __init__(self, dof, num_problems, num_retries, pos_tolerance, rot_tolerance, device="cpu"):
- self.iterations = 0
- self.device = device
- self.num_problems = num_problems
- self.num_retries = num_retries
- self.dof = dof
- self.pos_tolerance = pos_tolerance
- self.rot_tolerance = rot_tolerance
+ def __init__(
+ self,
+ dof: int,
+ num_problems: int,
+ num_retries: int,
+ pos_tolerance: float,
+ rot_tolerance: float,
+ device: str = "cpu",
+ ) -> None:
+ self.iterations: int = 0
+ self.device: str = device
+ self.num_problems: int = num_problems
+ self.num_retries: int = num_retries
+ self.dof: int = dof
+ self.pos_tolerance: float = pos_tolerance
+ self.rot_tolerance: float = rot_tolerance
M = num_problems
# N x DOF tensor of joint angles; if converged[i] is False, then solutions[i] is undefined
- self.solutions = torch.zeros((M, self.num_retries, self.dof), device=self.device)
- self.remaining = torch.ones(M, dtype=torch.bool, device=self.device)
+ self.solutions: torch.Tensor = torch.zeros((M, self.num_retries, self.dof), device=self.device)
+ self.remaining: torch.Tensor = torch.ones(M, dtype=torch.bool, device=self.device)
# M is the total number of problems
# N is the total number of attempts
# M x N tensor of position and rotation errors
- self.err_pos = torch.zeros((M, self.num_retries), device=self.device)
- self.err_rot = torch.zeros_like(self.err_pos)
+ self.err_pos: torch.Tensor = torch.zeros((M, self.num_retries), device=self.device)
+ self.err_rot: torch.Tensor = torch.zeros_like(self.err_pos)
# M x N boolean values indicating whether the solution converged (a solution could be found)
- self.converged_pos = torch.zeros((M, self.num_retries), dtype=torch.bool, device=self.device)
- self.converged_rot = torch.zeros_like(self.converged_pos)
- self.converged = torch.zeros_like(self.converged_pos)
+ self.converged_pos: torch.Tensor = torch.zeros((M, self.num_retries), dtype=torch.bool, device=self.device)
+ self.converged_rot: torch.Tensor = torch.zeros_like(self.converged_pos)
+ self.converged: torch.Tensor = torch.zeros_like(self.converged_pos)
# M whether any position and rotation converged for that problem
- self.converged_pos_any = torch.zeros_like(self.remaining)
- self.converged_rot_any = torch.zeros_like(self.remaining)
- self.converged_any = torch.zeros_like(self.remaining)
+ self.converged_pos_any: torch.Tensor = torch.zeros_like(self.remaining)
+ self.converged_rot_any: torch.Tensor = torch.zeros_like(self.remaining)
+ self.converged_any: torch.Tensor = torch.zeros_like(self.remaining)
- def update_remaining_with_keep_mask(self, keep: torch.tensor):
+ def update_remaining_with_keep_mask(self, keep: torch.Tensor) -> torch.Tensor:
self.remaining = self.remaining & keep
return self.remaining
- def update(self, q: torch.tensor, err: torch.tensor, use_keep_mask=True, keep_mask=None):
+ def update(
+ self,
+ q: torch.Tensor,
+ err: torch.Tensor,
+ use_keep_mask: bool = True,
+ keep_mask: Optional[torch.Tensor] = None,
+ ) -> torch.Tensor:
err = err.reshape(-1, self.num_retries, 6)
err_pos = err[..., :3].norm(dim=-1)
err_rot = err[..., 3:].norm(dim=-1)
@@ -74,25 +89,58 @@ def update(self, q: torch.tensor, err: torch.tensor, use_keep_mask=True, keep_ma
# helper config sampling method
def gaussian_around_config(config: torch.Tensor, std: float) -> Callable[[int], torch.Tensor]:
- def config_sampling_method(num_configs):
- return torch.randn(num_configs, config.shape[0], dtype=config.dtype, device=config.device) * std + config
+ def config_sampling_method(num_configs: int) -> torch.Tensor:
+ return (
+ torch.randn(
+ num_configs,
+ config.shape[0],
+ dtype=config.dtype,
+ device=config.device,
+ )
+ * std
+ + config
+ )
return config_sampling_method
class LineSearch:
- def do_line_search(self, chain, q, dq, target_pos, target_wxyz, initial_dx, problem_remaining=None):
+ def do_line_search(
+ self,
+ chain: SerialChain,
+ q: torch.Tensor,
+ dq: torch.Tensor,
+ target_pos: torch.Tensor,
+ target_wxyz: torch.Tensor,
+ initial_dx: torch.Tensor,
+ problem_remaining: Optional[torch.Tensor] = None,
+ ) -> Tuple[torch.Tensor, torch.Tensor]:
raise NotImplementedError()
class BacktrackingLineSearch(LineSearch):
- def __init__(self, max_lr=1.0, decrease_factor=0.5, max_iterations=5, sufficient_decrease=0.01):
- self.initial_lr = max_lr
- self.decrease_factor = decrease_factor
- self.max_iterations = max_iterations
- self.sufficient_decrease = sufficient_decrease
-
- def do_line_search(self, chain, q, dq, target_pos, target_wxyz, initial_dx, problem_remaining=None):
+ def __init__(
+ self,
+ max_lr: float = 1.0,
+ decrease_factor: float = 0.5,
+ max_iterations: int = 5,
+ sufficient_decrease: float = 0.01,
+ ) -> None:
+ self.initial_lr: float = max_lr
+ self.decrease_factor: float = decrease_factor
+ self.max_iterations: int = max_iterations
+ self.sufficient_decrease: float = sufficient_decrease
+
+ def do_line_search(
+ self,
+ chain: SerialChain,
+ q: torch.Tensor,
+ dq: torch.Tensor,
+ target_pos: torch.Tensor,
+ target_wxyz: torch.Tensor,
+ initial_dx: torch.Tensor,
+ problem_remaining: Optional[torch.Tensor] = None,
+ ) -> Tuple[torch.Tensor, torch.Tensor]:
N = target_pos.shape[0]
NM = q.shape[0]
M = NM // N
@@ -104,7 +152,7 @@ def do_line_search(self, chain, q, dq, target_pos, target_wxyz, initial_dx, prob
# don't care about the ones that are no longer remaining
remaining[~problem_remaining] = False
remaining = remaining.reshape(-1)
- for i in range(self.max_iterations):
+ for _ in range(self.max_iterations):
if not remaining.any():
break
# try stepping with this learning rate
@@ -130,19 +178,25 @@ def do_line_search(self, chain, q, dq, target_pos, target_wxyz, initial_dx, prob
class InverseKinematics:
"""Jacobian follower based inverse kinematics solver"""
- def __init__(self, serial_chain: SerialChain,
- pos_tolerance: float = 1e-3, rot_tolerance: float = 1e-2,
- retry_configs: Optional[torch.Tensor] = None, num_retries: Optional[int] = None,
- joint_limits: Optional[torch.Tensor] = None,
- config_sampling_method: Union[str, Callable[[int], torch.Tensor]] = "uniform",
- max_iterations: int = 50,
- lr: float = 0.2, line_search: Optional[LineSearch] = None,
- regularlization: float = 1e-9,
- debug=False,
- early_stopping_any_converged=False,
- early_stopping_no_improvement="any", early_stopping_no_improvement_patience=2,
- optimizer_method: Union[str, typing.Type[torch.optim.Optimizer]] = "sgd"
- ):
+ def __init__(
+ self,
+ serial_chain: SerialChain,
+ pos_tolerance: float = 1e-3,
+ rot_tolerance: float = 1e-2,
+ retry_configs: Optional[torch.Tensor] = None,
+ num_retries: Optional[int] = None,
+ joint_limits: Optional[torch.Tensor] = None,
+ config_sampling_method: Union[str, Callable[[int], torch.Tensor]] = "uniform",
+ max_iterations: int = 50,
+ lr: float = 0.2,
+ line_search: Optional[LineSearch] = None,
+ regularization: float = 1e-9,
+ debug: bool = False,
+ early_stopping_any_converged: bool = False,
+ early_stopping_no_improvement: Union[None, str, float] = "any",
+ early_stopping_no_improvement_patience: int = 2,
+ optimizer_method: Union[str, Type[torch.optim.Optimizer]] = "sgd",
+ ) -> None:
"""
:param serial_chain:
:param pos_tolerance: position tolerance in meters
@@ -154,7 +208,7 @@ def __init__(self, serial_chain: SerialChain,
:param max_iterations: maximum number of iterations to run
:param lr: learning rate
:param line_search: LineSearch object to use for line search
- :param regularlization: regularization term to add to the Jacobian
+ :param regularization: regularization term to add to the Jacobian
:param debug: whether to print debug information
:param early_stopping_any_converged: whether to stop when any of the retries for a problem converged
:param early_stopping_no_improvement: {None, "all", "any", ratio} whether to stop when no improvement is made
@@ -167,45 +221,45 @@ def __init__(self, serial_chain: SerialChain,
considering it no improvement
:param optimizer_method: either a string or a torch.optim.Optimizer class
"""
- self.chain = serial_chain
- self.dtype = serial_chain.dtype
- self.device = serial_chain.device
+ self.chain: SerialChain = serial_chain
+ self.dtype: torch.dtype = serial_chain.dtype
+ self.device: torch.device = serial_chain.device
joint_names = self.chain.get_joint_parameter_names(exclude_fixed=True)
- self.dof = len(joint_names)
- self.debug = debug
- self.early_stopping_any_converged = early_stopping_any_converged
- self.early_stopping_no_improvement = early_stopping_no_improvement
- self.early_stopping_no_improvement_patience = early_stopping_no_improvement_patience
-
- self.max_iterations = max_iterations
- self.lr = lr
- self.regularlization = regularlization
- self.optimizer_method = optimizer_method
- self.line_search = line_search
-
- self.err = None
- self.err_all = None
- self.err_min = None
- self.no_improve_counter = None
-
- self.pos_tolerance = pos_tolerance
- self.rot_tolerance = rot_tolerance
- self.initial_config = retry_configs
+ self.dof: int = len(joint_names)
+ self.debug: bool = debug
+ self.early_stopping_any_converged: bool = early_stopping_any_converged
+ self.early_stopping_no_improvement: Union[None, str, float] = early_stopping_no_improvement
+ self.early_stopping_no_improvement_patience: int = early_stopping_no_improvement_patience
+
+ self.max_iterations: int = max_iterations
+ self.lr: float = lr
+ self.regularization: float = regularization
+ self.optimizer_method: Union[str, Type[torch.optim.Optimizer]] = optimizer_method
+ self.line_search: Optional[LineSearch] = line_search
+
+ self.err: Optional[torch.Tensor] = None
+ self.err_all: Optional[torch.Tensor] = None
+ self.err_min: Optional[torch.Tensor] = None
+ self.no_improve_counter: Optional[torch.Tensor] = None
+
+ self.pos_tolerance: float = pos_tolerance
+ self.rot_tolerance: float = rot_tolerance
+ self.initial_config: Optional[torch.Tensor] = retry_configs
if retry_configs is None and num_retries is None:
raise ValueError("either initial_configs or num_retries must be specified")
# sample initial configs instead
- self.config_sampling_method = config_sampling_method
- self.joint_limits = joint_limits
+ self.config_sampling_method: Union[str, Callable[[int], torch.Tensor]] = config_sampling_method
+ self.joint_limits: Optional[torch.Tensor] = joint_limits
if retry_configs is None:
- self.initial_config = self.sample_configs(num_retries)
- else:
- if retry_configs.shape[1] != self.dof:
- raise ValueError("initial_configs must have shape (N, %d)" % self.dof)
+ self.initial_config = self.sample_configs(num_retries) # type: ignore[arg-type]
+ elif retry_configs.shape[1] != self.dof:
+ raise ValueError("initial_configs must have shape (N, %d)" % self.dof)
# could give a batch of initial configs
- self.num_retries = self.initial_config.shape[-2]
+ assert self.initial_config is not None
+ self.num_retries: int = self.initial_config.shape[-2]
- def clear(self):
+ def clear(self) -> None:
self.err = None
self.err_all = None
self.err_min = None
@@ -216,14 +270,16 @@ def sample_configs(self, num_configs: int) -> torch.Tensor:
# bound by joint_limits
if self.joint_limits is None:
raise ValueError("joint_limits must be specified if config_sampling_method is uniform")
- return torch.rand(num_configs, self.dof, device=self.device) * (
- self.joint_limits[:, 1] - self.joint_limits[:, 0]) + self.joint_limits[:, 0]
+ return (
+ torch.rand(num_configs, self.dof, device=self.device) * (self.joint_limits[:, 1] - self.joint_limits[:, 0])
+ + self.joint_limits[:, 0]
+ )
elif self.config_sampling_method == "gaussian":
return torch.randn(num_configs, self.dof, device=self.device)
elif callable(self.config_sampling_method):
return self.config_sampling_method(num_configs)
else:
- raise ValueError("invalid config_sampling_method %s" % self.config_sampling_method)
+ raise ValueError(f"invalid config_sampling_method {self.config_sampling_method}")
def solve(self, target_poses: Transform3d) -> IKSolution:
"""
@@ -234,11 +290,15 @@ def solve(self, target_poses: Transform3d) -> IKSolution:
raise NotImplementedError()
-def delta_pose(m: torch.tensor, target_pos, target_wxyz):
+def delta_pose(
+ m: torch.Tensor,
+ target_pos: torch.Tensor,
+ target_wxyz: torch.Tensor,
+) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
"""
Determine the error in position and rotation between the given poses and the target poses
- :param m: (N x M x 4 x 4) tensor of homogenous transforms
+ :param m: (N x M x 4 x 4) tensor of homogeneous transforms
:param target_pos:
:param target_wxyz: target orientation represented in unit quaternion
:return: (N*M, 6, 1) tensor of delta pose (dx, dy, dz, droll, dpitch, dyaw)
@@ -249,8 +309,9 @@ def delta_pose(m: torch.tensor, target_pos, target_wxyz):
# quaternion that rotates from the current orientation to the desired orientation
# inverse for unit quaternion is the conjugate
- diff_wxyz = rotation_conversions.quaternion_multiply(target_wxyz.unsqueeze(1),
- rotation_conversions.quaternion_invert(cur_wxyz))
+ diff_wxyz = rotation_conversions.quaternion_multiply(
+ target_wxyz.unsqueeze(1), rotation_conversions.quaternion_invert(cur_wxyz)
+ )
# angular velocity vector needed to correct the orientation
# if time is considered, should divide by \delta t, but doing it iteratively we can choose delta t to be 1
diff_axis_angle = rotation_conversions.quaternion_to_axis_angle(diff_wxyz)
@@ -261,14 +322,14 @@ def delta_pose(m: torch.tensor, target_pos, target_wxyz):
return dx, pos_diff, rot_diff
-def apply_mask(mask, *args):
+def apply_mask(mask: torch.Tensor, *args: torch.Tensor) -> List[torch.Tensor]:
return [a[mask] for a in args]
class PseudoInverseIK(InverseKinematics):
- def compute_dq(self, J, dx):
+ def compute_dq(self, J: torch.Tensor, dx: torch.Tensor) -> torch.Tensor:
# lambda^2*I (lambda^2 is regularization)
- reg = self.regularlization * torch.eye(6, device=self.device, dtype=self.dtype)
+ reg = self.regularization * torch.eye(6, device=self.device, dtype=self.dtype)
# JJ^T + lambda^2*I (lambda^2 is regularization)
tmpA = J @ J.transpose(1, 2) + reg
@@ -291,7 +352,14 @@ def solve(self, target_poses: Transform3d) -> IKSolution:
# convert target rot to desired rotation about x,y,z
target_wxyz = rotation_conversions.matrix_to_quaternion(target[:, :3, :3])
- sol = IKSolution(self.dof, M, self.num_retries, self.pos_tolerance, self.rot_tolerance, device=self.device)
+ sol = IKSolution(
+ self.dof,
+ M,
+ self.num_retries,
+ self.pos_tolerance,
+ self.rot_tolerance,
+ device=self.device,
+ )
q = self.initial_config
if q.numel() == M * self.dof * self.num_retries:
@@ -303,13 +371,14 @@ def solve(self, target_poses: Transform3d) -> IKSolution:
q = q.unsqueeze(0).repeat(M * self.num_retries, 1)
else:
raise ValueError(
- f"initial_config must have shape ({M}, {self.num_retries}, {self.dof}) or ({self.num_retries}, {self.dof})")
+ f"initial_config must have shape ({M}, {self.num_retries}, {self.dof}) or ({self.num_retries}, {self.dof})"
+ )
# for logging, let's keep track of the joint angles at each iteration
if self.debug:
- pos_errors = []
- rot_errors = []
+ pos_errors: List[torch.Tensor] = []
+ rot_errors: List[torch.Tensor] = []
- optimizer = None
+ optimizer: Optional[torch.optim.Optimizer] = None
if inspect.isclass(self.optimizer_method) and issubclass(self.optimizer_method, torch.optim.Optimizer):
q.requires_grad = True
optimizer = torch.optim.Adam([q], lr=self.lr)
@@ -339,8 +408,15 @@ def solve(self, target_poses: Transform3d) -> IKSolution:
else:
with torch.no_grad():
if self.line_search is not None:
- lr, improvement = self.line_search.do_line_search(self.chain, q, dq, target_pos, target_wxyz,
- dx, problem_remaining=sol.remaining)
+ lr, improvement = self.line_search.do_line_search(
+ self.chain,
+ q,
+ dq,
+ target_pos,
+ target_wxyz,
+ dx,
+ problem_remaining=sol.remaining,
+ )
lr = lr.unsqueeze(1)
else:
lr = self.lr
@@ -354,28 +430,27 @@ def solve(self, target_poses: Transform3d) -> IKSolution:
if self.early_stopping_no_improvement is not None:
if self.no_improve_counter is None:
self.no_improve_counter = torch.zeros_like(self.err)
+ elif self.err_min is None:
+ self.err_min = self.err.clone()
else:
- if self.err_min is None:
- self.err_min = self.err.clone()
- else:
- improved = self.err < self.err_min
- self.err_min[improved] = self.err[improved]
-
- self.no_improve_counter[improved] = 0
- self.no_improve_counter[~improved] += 1
-
- # those that haven't improved
- could_improve = self.no_improve_counter <= self.early_stopping_no_improvement_patience
- # consider problems, and only throw out those whose all retries cannot be improved
- could_improve = could_improve.reshape(-1, self.num_retries)
- if self.early_stopping_no_improvement == "all":
- could_improve = could_improve.all(dim=1)
- elif self.early_stopping_no_improvement == "any":
- could_improve = could_improve.any(dim=1)
- elif isinstance(self.early_stopping_no_improvement, float):
- ratio_improved = could_improve.sum(dim=1) / self.num_retries
- could_improve = ratio_improved > self.early_stopping_no_improvement
- sol.update_remaining_with_keep_mask(could_improve)
+ improved = self.err < self.err_min
+ self.err_min[improved] = self.err[improved]
+
+ self.no_improve_counter[improved] = 0
+ self.no_improve_counter[~improved] += 1
+
+ # those that haven't improved
+ could_improve = self.no_improve_counter <= self.early_stopping_no_improvement_patience
+ # consider problems, and only throw out those whose all retries cannot be improved
+ could_improve = could_improve.reshape(-1, self.num_retries)
+ if self.early_stopping_no_improvement == "all":
+ could_improve = could_improve.all(dim=1)
+ elif self.early_stopping_no_improvement == "any":
+ could_improve = could_improve.any(dim=1)
+ elif isinstance(self.early_stopping_no_improvement, float):
+ ratio_improved = could_improve.sum(dim=1) / self.num_retries
+ could_improve = ratio_improved > self.early_stopping_no_improvement
+ sol.update_remaining_with_keep_mask(could_improve)
if self.debug:
pos_errors.append(pos_diff.reshape(-1, 3).norm(dim=1))
@@ -386,7 +461,7 @@ def solve(self, target_poses: Transform3d) -> IKSolution:
fig, ax = plt.subplots(ncols=2, figsize=(10, 5))
pos_e = torch.stack(pos_errors, dim=0).cpu()
rot_e = torch.stack(rot_errors, dim=0).cpu()
- ax[0].set_ylim(0, 1.)
+ ax[0].set_ylim(0, 1.0)
# ignore nan
ignore = torch.isnan(rot_e)
axis_max = rot_e[~ignore].max().item()
@@ -406,21 +481,23 @@ def solve(self, target_poses: Transform3d) -> IKSolution:
plt.show()
if i == self.max_iterations - 1:
+ # self.err_all is set in the last loop above
+ assert self.err_all is not None
sol.update(q, self.err_all, use_keep_mask=False)
return sol
class PseudoInverseIKWithSVD(PseudoInverseIK):
# generally slower, but allows for selective damping if needed
- def compute_dq(self, J, dx):
- # reg = self.regularlization * torch.eye(6, device=self.device, dtype=self.dtype)
+ def compute_dq(self, J: torch.Tensor, dx: torch.Tensor) -> torch.Tensor:
+ # reg = self.regularization * torch.eye(6, device=self.device, dtype=self.dtype)
U, D, Vh = torch.linalg.svd(J)
m = D.shape[1]
# tmpA = U @ (D @ D.transpose(1, 2) + reg) @ U.transpose(1, 2)
# singular_val = torch.diagonal(D)
- denom = D ** 2 + self.regularlization
+ denom = D**2 + self.regularization
prod = D / denom
# J^T (JJ^T + lambda^2I)^-1 = V @ (D @ D^T + lambda^2I)^-1 @ U^T = sum_i (d_i / (d_i^2 + lambda^2) v_i @ u_i^T)
# should be equivalent to damped least squares
diff --git a/src/pytorch_kinematics/jacobian.py b/src/pytorch_kinematics/jacobian.py
index 4ff3b32..68dd51a 100644
--- a/src/pytorch_kinematics/jacobian.py
+++ b/src/pytorch_kinematics/jacobian.py
@@ -1,9 +1,20 @@
+from typing import TYPE_CHECKING, Optional, Sequence, Tuple, Union
+
import torch
from pytorch_kinematics import transforms
-def calc_jacobian(serial_chain, th, tool=None, ret_eef_pose=False):
+if TYPE_CHECKING:
+ from pytorch_kinematics.chain import SerialChain
+
+
+def calc_jacobian(
+ serial_chain: "SerialChain",
+ th: Union[torch.Tensor, Sequence[float]],
+ tool: Optional[transforms.Transform3d] = None,
+ ret_eef_pose: bool = False,
+) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
"""
Return robot Jacobian J in base frame (N,6,DOF) where dot{x} = J dot{q}
The first 3 rows relate the translational velocities and the
@@ -17,46 +28,47 @@ def calc_jacobian(serial_chain, th, tool=None, ret_eef_pose=False):
if not torch.is_tensor(th):
th = torch.tensor(th, dtype=serial_chain.dtype, device=serial_chain.device)
if len(th.shape) <= 1:
- N = 1
+ N: int = 1
th = th.reshape(1, -1)
else:
N = th.shape[0]
- ndof = th.shape[1]
+ ndof: int = th.shape[1]
- j_eef = torch.zeros((N, 6, ndof), dtype=serial_chain.dtype, device=serial_chain.device)
+ j_eef: torch.Tensor = torch.zeros((N, 6, ndof), dtype=serial_chain.dtype, device=serial_chain.device)
if tool is None:
- cur_transform = transforms.Transform3d(device=serial_chain.device,
- dtype=serial_chain.dtype).get_matrix().repeat(N, 1, 1)
+ cur_transform: torch.Tensor = (
+ transforms.Transform3d(device=serial_chain.device, dtype=serial_chain.dtype).get_matrix().repeat(N, 1, 1)
+ )
else:
if tool.dtype != serial_chain.dtype or tool.device != serial_chain.device:
tool = tool.to(device=serial_chain.device, copy=True, dtype=serial_chain.dtype)
cur_transform = tool.get_matrix()
- cnt = 0
+ cnt: int = 0
for f in reversed(serial_chain._serial_frames):
if f.joint.joint_type == "revolute":
cnt += 1
# cur_transform transforms a point in eef frame into a point in joint frame, i.e. p_joint = curr_transform @ p_eef
- axis_in_eef = cur_transform[:, :3, :3].transpose(1, 2) @ f.joint.axis
- eef2joint_pos_in_joint = cur_transform[:, :3, 3].unsqueeze(2)
- joint2eef_rot = cur_transform[:, :3, :3].transpose(1, 2) # transpose of rotation is inverse
- eef2joint_pos_in_eef = joint2eef_rot @ eef2joint_pos_in_joint
- position_jacobian = torch.cross(axis_in_eef, eef2joint_pos_in_eef.squeeze(2), dim=1)
+ axis_in_eef: torch.Tensor = cur_transform[:, :3, :3].transpose(1, 2) @ f.joint.axis
+ eef2joint_pos_in_joint: torch.Tensor = cur_transform[:, :3, 3].unsqueeze(2)
+ joint2eef_rot: torch.Tensor = cur_transform[:, :3, :3].transpose(1, 2) # transpose of rotation is inverse
+ eef2joint_pos_in_eef: torch.Tensor = joint2eef_rot @ eef2joint_pos_in_joint
+ position_jacobian: torch.Tensor = torch.cross(axis_in_eef, eef2joint_pos_in_eef.squeeze(2), dim=1)
j_eef[:, :, -cnt] = torch.cat((position_jacobian, axis_in_eef), dim=-1)
elif f.joint.joint_type == "prismatic":
cnt += 1
j_eef[:, :3, -cnt] = (f.joint.axis.repeat(N, 1, 1) @ cur_transform[:, :3, :3])[:, 0, :]
- cur_frame_transform = f.get_transform(th[:, -cnt]).get_matrix()
+ cur_frame_transform: torch.Tensor = f.get_transform(th[:, -cnt]).get_matrix()
cur_transform = cur_frame_transform @ cur_transform
# currently j_eef is Jacobian in end-effector frame, convert to base/world frame
- pose = serial_chain.forward_kinematics(th).get_matrix()
- rotation = pose[:, :3, :3]
- j_tr = torch.zeros((N, 6, 6), dtype=serial_chain.dtype, device=serial_chain.device)
+ pose: torch.Tensor = serial_chain.forward_kinematics(th).get_matrix()
+ rotation: torch.Tensor = pose[:, :3, :3]
+ j_tr: torch.Tensor = torch.zeros((N, 6, 6), dtype=serial_chain.dtype, device=serial_chain.device)
j_tr[:, :3, :3] = rotation
j_tr[:, 3:, 3:] = rotation
- j_w = j_tr @ j_eef
+ j_w: torch.Tensor = j_tr @ j_eef
if ret_eef_pose:
return j_w, pose
return j_w
diff --git a/src/pytorch_kinematics/mjcf.py b/src/pytorch_kinematics/mjcf.py
index 3dc1857..19e3c0c 100644
--- a/src/pytorch_kinematics/mjcf.py
+++ b/src/pytorch_kinematics/mjcf.py
@@ -1,109 +1,168 @@
-from typing import Union, Optional, Dict
+from typing import Dict, List, Optional, Union
import mujoco
from mujoco._structs import _MjModelBodyViews as MjModelBodyViews
import pytorch_kinematics.transforms as tf
-from . import chain
-from . import frame
+
+from . import chain, frame
+
# Converts from MuJoCo joint types to pytorch_kinematics joint types
-JOINT_TYPE_MAP = {
- mujoco.mjtJoint.mjJNT_HINGE: 'revolute',
- mujoco.mjtJoint.mjJNT_SLIDE: "prismatic"
+JOINT_TYPE_MAP: Dict[int, str] = {
+ mujoco.mjtJoint.mjJNT_HINGE: "revolute",
+ mujoco.mjtJoint.mjJNT_SLIDE: "prismatic",
}
-def body_to_geoms(m: mujoco.MjModel, body: MjModelBodyViews):
- # Find all geoms which have body as parent
- visuals = []
+def body_to_geoms(m: mujoco.MjModel, body: MjModelBodyViews) -> List[frame.Visual]:
+ """
+ Collect Visual objects for all geoms attached to the given MuJoCo body.
+ """
+ visuals: List[frame.Visual] = []
for geom_id in range(m.ngeom):
geom = m.geom(geom_id)
if geom.bodyid == body.id:
- visuals.append(frame.Visual(offset=tf.Transform3d(rot=geom.quat, pos=geom.pos), geom_type=geom.type,
- geom_param=geom.size))
+ visuals.append(
+ frame.Visual(
+ offset=tf.Transform3d(rot=geom.quat, pos=geom.pos),
+ geom_type=geom.type,
+ geom_param=geom.size,
+ )
+ )
return visuals
-def _build_chain_recurse(m, parent_frame, parent_body):
+def _build_chain_recurse(
+ m: mujoco.MjModel,
+ parent_frame: frame.Frame,
+ parent_body: MjModelBodyViews,
+) -> None:
+ """
+ Recursively attach children frames/links/joints to a pytorch_kinematics Frame
+ based on MuJoCo's body tree.
+ """
parent_frame.link.visuals = body_to_geoms(m, parent_body)
+
# iterate through all bodies that are children of parent_body
for body_id in range(m.nbody):
body = m.body(body_id)
if body.parentid == parent_body.id and body_id != parent_body.id:
n_joints = body.jntnum
if n_joints > 1:
- raise ValueError("composite joints not supported (could implement this if needed)")
+ raise ValueError("composite joints not supported")
+
if n_joints == 1:
- # Find the joint for this body, again assuming there's only one joint per body.
+ # single joint case
joint = m.joint(body.jntadr[0])
joint_offset = tf.Transform3d(pos=joint.pos)
- child_joint = frame.Joint(joint.name, offset=joint_offset, axis=joint.axis,
- joint_type=JOINT_TYPE_MAP[joint.type[0]],
- limits=(joint.range[0], joint.range[1]))
+ child_joint = frame.Joint(
+ name=joint.name,
+ offset=joint_offset,
+ axis=joint.axis,
+ joint_type=JOINT_TYPE_MAP[joint.type[0]],
+ limits=(joint.range[0], joint.range[1]),
+ )
else:
+ # fixed joint
child_joint = frame.Joint(body.name + "_fixed_joint")
- child_link = frame.Link(body.name, offset=tf.Transform3d(rot=body.quat, pos=body.pos))
- child_frame = frame.Frame(name=body.name, link=child_link, joint=child_joint)
- parent_frame.children = parent_frame.children + [child_frame, ]
+
+ child_link = frame.Link(
+ body.name,
+ offset=tf.Transform3d(rot=body.quat, pos=body.pos),
+ )
+ child_frame = frame.Frame(
+ name=body.name,
+ link=child_link,
+ joint=child_joint,
+ )
+
+ parent_frame.children = [*parent_frame.children, child_frame]
_build_chain_recurse(m, child_frame, body)
- # iterate through all sites that are children of parent_body
+ # iterate through all MuJoCo sites attached to this body
for site_id in range(m.nsite):
site = m.site(site_id)
if site.bodyid == parent_body.id:
- site_link = frame.Link(site.name, offset=tf.Transform3d(rot=site.quat, pos=site.pos))
- site_frame = frame.Frame(name=site.name, link=site_link)
- parent_frame.children = parent_frame.children + [site_frame, ]
-
-
-def build_chain_from_mjcf(data, body: Union[None, str, int] = None, assets:Optional[Dict[str,bytes]]=None):
+ site_link = frame.Link(
+ site.name,
+ offset=tf.Transform3d(rot=site.quat, pos=site.pos),
+ )
+ site_frame = frame.Frame(
+ name=site.name,
+ link=site_link,
+ joint=frame.Joint(), # sites are fixed
+ )
+ parent_frame.children = [*parent_frame.children, site_frame]
+
+
+def build_chain_from_mjcf(
+ data: str,
+ body: Union[None, str, int] = None,
+ assets: Optional[Dict[str, bytes]] = None,
+) -> chain.Chain:
"""
- Build a Chain object from MJCF data.
+ Build a pytorch-kinematics Chain object from MJCF XML string.
Parameters
----------
data : str
MJCF string data.
body : str or int, optional
- The name or index of the body to use as the root of the chain. If None, body idx=0 is used.
+ The name or index of the body to use as root of the chain.
+ If None, body idx=0 is used (MuJoCo worldbody root).
+ assets : dict of name → file bytes, optional
+ MJCF asset dictionary.
Returns
-------
chain.Chain
- Chain object created from MJCF.
+ The constructed robot chain.
"""
m = mujoco.MjModel.from_xml_string(data, assets=assets)
- if body is None:
- root_body = m.body(0)
- else:
- root_body = m.body(body)
- root_frame = frame.Frame(root_body.name,
- link=frame.Link(root_body.name,
- offset=tf.Transform3d(rot=root_body.quat, pos=root_body.pos)),
- joint=frame.Joint())
+
+ # Select root body
+ root_body = m.body(0) if body is None else m.body(body)
+
+ root_frame = frame.Frame(
+ name=root_body.name,
+ link=frame.Link(
+ root_body.name,
+ offset=tf.Transform3d(rot=root_body.quat, pos=root_body.pos),
+ ),
+ joint=frame.Joint(),
+ )
+
_build_chain_recurse(m, root_frame, root_body)
return chain.Chain(root_frame)
-def build_serial_chain_from_mjcf(data, end_link_name, root_link_name=""):
+def build_serial_chain_from_mjcf(
+ data: str,
+ end_link_name: str,
+ root_link_name: str = "",
+) -> chain.SerialChain:
"""
- Build a SerialChain object from MJCF data.
+ Build a SerialChain from MJCF XML.
Parameters
----------
data : str
- MJCF string data.
+ MJCF robot description (XML).
end_link_name : str
- The name of the link that is the end effector.
+ Name of the end-effector link.
root_link_name : str, optional
- The name of the root link.
+ Name of the root link. Default MJCF root if empty.
Returns
-------
chain.SerialChain
- SerialChain object created from MJCF.
+ The resulting SerialChain.
"""
mjcf_chain = build_chain_from_mjcf(data)
- serial_chain = chain.SerialChain(mjcf_chain, end_link_name, "" if root_link_name == "" else root_link_name)
+ serial_chain = chain.SerialChain(
+ mjcf_chain,
+ end_link_name,
+ "" if root_link_name == "" else root_link_name,
+ )
return serial_chain
diff --git a/src/pytorch_kinematics/sdf.py b/src/pytorch_kinematics/sdf.py
index 32ac177..526c8e8 100644
--- a/src/pytorch_kinematics/sdf.py
+++ b/src/pytorch_kinematics/sdf.py
@@ -1,24 +1,34 @@
-import torch
import math
-from .urdf_parser_py.sdf import SDF, Mesh, Cylinder, Box, Sphere
-from . import frame
-from . import chain
+from typing import Any, Dict, List, Optional, Sequence
+
+import torch
+
import pytorch_kinematics.transforms as tf
-JOINT_TYPE_MAP = {'revolute': 'revolute',
- 'prismatic': 'prismatic',
- 'fixed': 'fixed'}
+from . import chain, frame
+from .urdf_parser_py.sdf import SDF, Box, Cylinder, Mesh, Sphere
-def _convert_transform(pose):
+JOINT_TYPE_MAP: Dict[str, str] = {
+ "revolute": "revolute",
+ "prismatic": "prismatic",
+ "fixed": "fixed",
+}
+
+
+def _convert_transform(pose: Optional[Sequence[float]]) -> tf.Transform3d:
if pose is None:
return tf.Transform3d()
else:
- return tf.Transform3d(rot=tf.euler_angles_to_matrix(torch.tensor(pose[3:]), "ZYX"), pos=pose[:3])
+ # pose: [x, y, z, roll, pitch, yaw] (assumed)
+ return tf.Transform3d(
+ rot=tf.euler_angles_to_matrix(torch.tensor(pose[3:], dtype=torch.float32), "ZYX"),
+ pos=pose[:3],
+ )
-def _convert_visuals(visuals):
- vlist = []
+def _convert_visuals(visuals: Sequence[Any]) -> List[frame.Visual]:
+ vlist: List[frame.Visual] = []
for v in visuals:
v_tf = _convert_transform(v.pose)
if isinstance(v.geometry, Mesh):
@@ -27,7 +37,13 @@ def _convert_visuals(visuals):
elif isinstance(v.geometry, Cylinder):
g_type = "cylinder"
v_tf = v_tf.compose(
- tf.Transform3d(rot=tf.euler_angles_to_matrix(torch.tensor([0.5 * math.pi, 0, 0]), "ZYX")))
+ tf.Transform3d(
+ rot=tf.euler_angles_to_matrix(
+ torch.tensor([0.5 * math.pi, 0.0, 0.0], dtype=torch.float32),
+ "ZYX",
+ )
+ )
+ )
g_param = (v.geometry.radius, v.geometry.length)
elif isinstance(v.geometry, Box):
g_type = "box"
@@ -42,8 +58,12 @@ def _convert_visuals(visuals):
return vlist
-def _build_chain_recurse(root_frame, lmap, joints):
- children = []
+def _build_chain_recurse(
+ root_frame: frame.Frame,
+ lmap: Dict[str, Any],
+ joints: Sequence[Any],
+) -> List[frame.Frame]:
+ children: List[frame.Frame] = []
for j in joints:
if j.parent == root_frame.link.name:
child_frame = frame.Frame(j.child)
@@ -55,16 +75,24 @@ def _build_chain_recurse(root_frame, lmap, joints):
limits = (j.axis.limit.lower, j.axis.limit.upper)
except AttributeError:
limits = None
- child_frame.joint = frame.Joint(j.name, offset=t_p.inverse().compose(t_c),
- joint_type=JOINT_TYPE_MAP[j.type], axis=j.axis.xyz, limits=limits)
- child_frame.link = frame.Link(link_c.name, offset=tf.Transform3d(),
- visuals=_convert_visuals(link_c.visuals))
+ child_frame.joint = frame.Joint(
+ j.name,
+ offset=t_p.inverse().compose(t_c),
+ joint_type=JOINT_TYPE_MAP[j.type],
+ axis=j.axis.xyz,
+ limits=limits,
+ )
+ child_frame.link = frame.Link(
+ link_c.name,
+ offset=tf.Transform3d(),
+ visuals=_convert_visuals(link_c.visuals),
+ )
child_frame.children = _build_chain_recurse(child_frame, lmap, joints)
children.append(child_frame)
return children
-def build_chain_from_sdf(data):
+def build_chain_from_sdf(data: str) -> chain.Chain:
"""
Build a Chain object from SDF data.
@@ -78,12 +106,12 @@ def build_chain_from_sdf(data):
chain.Chain
Chain object created from SDF.
"""
- sdf = SDF.from_xml_string(data)
+ sdf: SDF = SDF.from_xml_string(data)
robot = sdf.model
- lmap = robot.link_map
- joints = robot.joints
+ lmap: Dict[str, Any] = robot.link_map
+ joints: Sequence[Any] = robot.joints
n_joints = len(joints)
- has_root = [True for _ in range(len(joints))]
+ has_root: List[bool] = [True for _ in range(len(joints))]
for i in range(n_joints):
for j in range(i + 1, n_joints):
if joints[i].parent == joints[j].child:
@@ -94,15 +122,27 @@ def build_chain_from_sdf(data):
if has_root[i]:
root_link = lmap[joints[i].parent]
break
+
root_frame = frame.Frame(root_link.name)
root_frame.joint = frame.Joint(offset=_convert_transform(root_link.pose))
- root_frame.link = frame.Link(root_link.name, tf.Transform3d(),
- _convert_visuals(root_link.visuals))
+ root_frame.link = frame.Link(
+ root_link.name,
+ tf.Transform3d(),
+ _convert_visuals(root_link.visuals),
+ )
root_frame.children = _build_chain_recurse(root_frame, lmap, joints)
return chain.Chain(root_frame)
-def build_serial_chain_from_sdf(data, end_link_name, root_link_name=""):
+def build_serial_chain_from_sdf(
+ data: str,
+ end_link_name: str,
+ root_link_name: str = "",
+) -> chain.SerialChain:
mjcf_chain = build_chain_from_sdf(data)
- serial_chain = chain.SerialChain(mjcf_chain, end_link_name, "" if root_link_name == "" else root_link_name)
+ serial_chain = chain.SerialChain(
+ mjcf_chain,
+ end_link_name,
+ "" if root_link_name == "" else root_link_name,
+ )
return serial_chain
diff --git a/src/pytorch_kinematics/transforms/math.py b/src/pytorch_kinematics/transforms/math.py
index 9a315b7..07cbee2 100644
--- a/src/pytorch_kinematics/transforms/math.py
+++ b/src/pytorch_kinematics/transforms/math.py
@@ -78,8 +78,8 @@ def quaternion_slerp(q1: torch.Tensor, q2: torch.Tensor, t: Union[float, torch.t
def acos_linear_extrapolation(
- x: torch.Tensor,
- bound: Union[float, Tuple[float, float]] = 1.0 - 1e-4,
+ x: torch.Tensor,
+ bound: Union[float, Tuple[float, float]] = 1.0 - 1e-4,
) -> torch.Tensor:
"""
Implements `arccos(x)` which is linearly extrapolated outside `x`'s original
diff --git a/src/pytorch_kinematics/transforms/perturbation.py b/src/pytorch_kinematics/transforms/perturbation.py
index cccbbce..1a2b00e 100644
--- a/src/pytorch_kinematics/transforms/perturbation.py
+++ b/src/pytorch_kinematics/transforms/perturbation.py
@@ -1,26 +1,47 @@
+from typing import Optional
+
import torch
-from pytorch_kinematics.transforms.rotation_conversions import axis_and_angle_to_matrix_33
+
+from pytorch_kinematics.transforms.rotation_conversions import (
+ axis_and_angle_to_matrix_33,
+)
-def sample_perturbations(T, num_perturbations, radian_sigma, translation_sigma, axis_of_rotation=None,
- translation_perpendicular_to_axis_of_rotation=True):
+def sample_perturbations(
+ T: torch.Tensor,
+ num_perturbations: int,
+ radian_sigma: float,
+ translation_sigma: float,
+ axis_of_rotation: Optional[torch.Tensor] = None,
+ translation_perpendicular_to_axis_of_rotation: bool = True,
+) -> torch.Tensor:
"""
Sample perturbations around the given transform. The translation and rotation are sampled independently from
- 0 mean gaussians. The angular perturbations' directions are uniformly sampled from the unit sphere while its
- magnitude is sampled from a gaussian.
- :param T: given transform to perturb around
- :param num_perturbations: number of perturbations to sample
- :param radian_sigma: standard deviation of the gaussian angular perturbation in radians
- :param translation_sigma: standard deviation of the gaussian translation perturbation in meters / T units
- :param axis_of_rotation: if not None, the axis of rotation to sample the perturbations around
- :param translation_perpendicular_to_axis_of_rotation: if True and the axis_of_rotation is not None, the translation
- perturbations will be perpendicular to the axis of rotation
- :return: perturbed transforms; may not include the original transform
+ 0-mean Gaussians. Rotational perturbations are sampled via axis-angle with random directions unless an axis is given.
+ Parameters
+ ----------
+ T : torch.Tensor
+ Input transform of shape (..., 4, 4). Only the last two dims are used.
+ num_perturbations : int
+ Number of perturbations to sample.
+ radian_sigma : float
+ Stddev of Gaussian angular perturbation (radians).
+ translation_sigma : float
+ Stddev of Gaussian translation perturbation (meters).
+ axis_of_rotation : torch.Tensor, optional
+ If supplied, perturb around this axis (shape (3,) or (N,3)).
+ translation_perpendicular_to_axis_of_rotation : bool
+ If True, translation perturbations are forced perpendicular to axis_of_rotation.
+ Returns
+ -------
+ torch.Tensor
+ Perturbed transforms of shape (num_perturbations, 4, 4).
"""
dtype = T.dtype
device = T.device
perturbed = torch.eye(4, dtype=dtype, device=device).repeat(num_perturbations, 1, 1)
+ # Gaussian translation perturbations
delta_t = torch.randn((num_perturbations, 3), dtype=dtype, device=device) * translation_sigma
# consider sampling from the Bingham distribution
theta = torch.randn(num_perturbations, dtype=dtype, device=device) * radian_sigma
diff --git a/src/pytorch_kinematics/transforms/rotation_conversions.py b/src/pytorch_kinematics/transforms/rotation_conversions.py
index 310f697..f751ab0 100644
--- a/src/pytorch_kinematics/transforms/rotation_conversions.py
+++ b/src/pytorch_kinematics/transforms/rotation_conversions.py
@@ -7,6 +7,7 @@
import torch
import torch.nn.functional as F
+
"""
The transformation matrices returned from the functions in this file assume
the points on which the transformation will be applied are column vectors.
@@ -115,9 +116,7 @@ def matrix_to_quaternion(matrix):
raise ValueError(f"Invalid rotation matrix shape {matrix.shape}.")
batch_dim = matrix.shape[:-2]
- m00, m01, m02, m10, m11, m12, m20, m21, m22 = torch.unbind(
- matrix.reshape(batch_dim + (9,)), dim=-1
- )
+ m00, m01, m02, m10, m11, m12, m20, m21, m22 = torch.unbind(matrix.reshape((*batch_dim, 9)), dim=-1)
q_abs = _sqrt_positive_part(
torch.stack(
@@ -150,9 +149,7 @@ def matrix_to_quaternion(matrix):
# if not for numerical problems, quat_candidates[i] should be same (up to a sign),
# forall i; we pick the best-conditioned one (with the largest denominator)
- return quat_candidates[
- F.one_hot(q_abs.argmax(dim=-1), num_classes=4) > 0.5, : # pyre-ignore[16]
- ].reshape(batch_dim + (4,))
+ return quat_candidates[F.one_hot(q_abs.argmax(dim=-1), num_classes=4) > 0.5, :].reshape((*batch_dim, 4)) # pyre-ignore[16]
def _axis_angle_rotation(axis: str, angle):
@@ -180,7 +177,7 @@ def _axis_angle_rotation(axis: str, angle):
if axis == "Z":
R_flat = (cos, -sin, zero, sin, cos, zero, zero, zero, one)
- return torch.stack(R_flat, -1).reshape(angle.shape + (3, 3))
+ return torch.stack(R_flat, -1).reshape((*angle.shape, 3, 3))
def euler_angles_to_matrix(euler_angles, convention: str):
@@ -208,9 +205,7 @@ def euler_angles_to_matrix(euler_angles, convention: str):
return functools.reduce(torch.matmul, matrices)
-def _angle_from_tan(
- axis: str, other_axis: str, data, horizontal: bool, tait_bryan: bool
-):
+def _angle_from_tan(axis: str, other_axis: str, data, horizontal: bool, tait_bryan: bool):
"""
Extract the first or third Euler angle from the two members of
the matrix which are positive constant times its sine and cosine.
@@ -274,27 +269,19 @@ def matrix_to_euler_angles(matrix, convention: str):
i2 = _index_from_letter(convention[2])
tait_bryan = i0 != i2
if tait_bryan:
- central_angle = torch.asin(
- matrix[..., i0, i2] * (-1.0 if i0 - i2 in [-1, 2] else 1.0)
- )
+ central_angle = torch.asin(matrix[..., i0, i2] * (-1.0 if i0 - i2 in [-1, 2] else 1.0))
else:
central_angle = torch.acos(matrix[..., i0, i0])
o = (
- _angle_from_tan(
- convention[0], convention[1], matrix[..., i2], False, tait_bryan
- ),
+ _angle_from_tan(convention[0], convention[1], matrix[..., i2], False, tait_bryan),
central_angle,
- _angle_from_tan(
- convention[2], convention[1], matrix[..., i0, :], True, tait_bryan
- ),
+ _angle_from_tan(convention[2], convention[1], matrix[..., i0, :], True, tait_bryan),
)
return torch.stack(o, -1)
-def random_quaternions(
- n: int, dtype: Optional[torch.dtype] = None, device=None, requires_grad=False
-):
+def random_quaternions(n: int, dtype: Optional[torch.dtype] = None, device=None, requires_grad=False):
"""
Generate random quaternions representing rotations,
i.e. versors with nonnegative real part.
@@ -316,9 +303,7 @@ def random_quaternions(
return o
-def random_rotations(
- n: int, dtype: Optional[torch.dtype] = None, device=None, requires_grad=False
-):
+def random_rotations(n: int, dtype: Optional[torch.dtype] = None, device=None, requires_grad=False):
"""
Generate random rotations as 3x3 rotation matrices.
@@ -333,15 +318,11 @@ def random_rotations(
Returns:
Rotation matrices as tensor of shape (n, 3, 3).
"""
- quaternions = random_quaternions(
- n, dtype=dtype, device=device, requires_grad=requires_grad
- )
+ quaternions = random_quaternions(n, dtype=dtype, device=device, requires_grad=requires_grad)
return quaternion_to_matrix(quaternions)
-def random_rotation(
- dtype: Optional[torch.dtype] = None, device=None, requires_grad=False
-):
+def random_rotation(dtype: Optional[torch.dtype] = None, device=None, requires_grad=False):
"""
Generate a single random 3x3 rotation matrix.
@@ -397,7 +378,7 @@ def quaternion_raw_multiply(a, b):
def quaternion_multiply(a, b):
"""
Multiply two quaternions representing rotations, returning the quaternion
- representing their composition, i.e. the versor with nonnegative real part.
+ representing their composition, i.e. the versor with nonnegative real part.
Usual torch rules for broadcasting apply.
Args:
@@ -466,7 +447,10 @@ def axis_and_d_to_pris_matrix(axis, d):
mat33 = torch.eye(3).to(axis).expand(*batch_axes, 3, 3)
pos = axis * d.unsqueeze(-1)
mat44 = torch.cat((mat33, pos.unsqueeze(-1)), -1)
- mat44 = torch.cat((mat44, torch.tensor([0.0, 0.0, 0.0, 1.0]).expand(*batch_axes, 1, 4).to(axis)), -2)
+ mat44 = torch.cat(
+ (mat44, torch.tensor([0.0, 0.0, 0.0, 1.0]).expand(*batch_axes, 1, 4).to(axis)),
+ -2,
+ )
return mat44
@@ -485,7 +469,10 @@ def axis_and_angle_to_matrix_44(axis, theta):
rot = axis_and_angle_to_matrix_33(axis, theta)
batch_shape = axis.shape[:-1]
mat44 = torch.cat((rot, torch.zeros(*batch_shape, 3, 1).to(axis)), -1)
- mat44 = torch.cat((mat44, torch.tensor([0.0, 0.0, 0.0, 1.0]).expand(*batch_shape, 1, 4).to(axis)), -2)
+ mat44 = torch.cat(
+ (mat44, torch.tensor([0.0, 0.0, 0.0, 1.0]).expand(*batch_shape, 1, 4).to(axis)),
+ -2,
+ )
return mat44
@@ -515,9 +502,14 @@ def axis_and_angle_to_matrix_33(axis, theta):
r20 = kz * kx * one_minus_c - ky * s
r21 = kz * ky * one_minus_c + kx * s
r22 = c + kz * kz * one_minus_c
- rot = torch.stack([torch.stack([r00, r01, r02], -1),
- torch.stack([r10, r11, r12], -1),
- torch.stack([r20, r21, r22], -1)], -2)
+ rot = torch.stack(
+ [
+ torch.stack([r00, r01, r02], -1),
+ torch.stack([r10, r11, r12], -1),
+ torch.stack([r20, r21, r22], -1),
+ ],
+ -2,
+ )
return rot
@@ -536,8 +528,11 @@ def axis_angle_to_matrix(axis_angle):
Returns:
Rotation matrices as tensor of shape (..., 3, 3).
"""
- warn('This is deprecated because it is slow. Use axis_and_angle_to_matrix_33 instead.',
- DeprecationWarning, stacklevel=2)
+ warn(
+ "This is deprecated because it is slow. Use axis_and_angle_to_matrix_33 instead.",
+ DeprecationWarning,
+ stacklevel=2,
+ )
return quaternion_to_matrix(axis_angle_to_quaternion(axis_angle))
@@ -575,17 +570,11 @@ def axis_angle_to_quaternion(axis_angle):
eps = 1e-6
small_angles = angles.abs() < eps
sin_half_angles_over_angles = torch.empty_like(angles)
- sin_half_angles_over_angles[~small_angles] = (
- torch.sin(half_angles[~small_angles]) / angles[~small_angles]
- )
+ sin_half_angles_over_angles[~small_angles] = torch.sin(half_angles[~small_angles]) / angles[~small_angles]
# for x small, sin(x/2) is about x/2 - (x/2)^3/6
# so sin(x/2)/x is about 1/2 - (x*x)/48
- sin_half_angles_over_angles[small_angles] = (
- 0.5 - (angles[small_angles] * angles[small_angles]) / 48
- )
- quaternions = torch.cat(
- [torch.cos(half_angles), axis_angle * sin_half_angles_over_angles], dim=-1
- )
+ sin_half_angles_over_angles[small_angles] = 0.5 - (angles[small_angles] * angles[small_angles]) / 48
+ quaternions = torch.cat([torch.cos(half_angles), axis_angle * sin_half_angles_over_angles], dim=-1)
return quaternions
@@ -609,14 +598,10 @@ def quaternion_to_axis_angle(quaternions):
eps = 1e-6
small_angles = angles.abs() < eps
sin_half_angles_over_angles = torch.empty_like(angles)
- sin_half_angles_over_angles[~small_angles] = (
- torch.sin(half_angles[~small_angles]) / angles[~small_angles]
- )
+ sin_half_angles_over_angles[~small_angles] = torch.sin(half_angles[~small_angles]) / angles[~small_angles]
# for x small, sin(x/2) is about x/2 - (x/2)^3/6
# so sin(x/2)/x is about 1/2 - (x*x)/48
- sin_half_angles_over_angles[small_angles] = (
- 0.5 - (angles[small_angles] * angles[small_angles]) / 48
- )
+ sin_half_angles_over_angles[small_angles] = 0.5 - (angles[small_angles] * angles[small_angles]) / 48
return quaternions[..., 1:] / sin_half_angles_over_angles
@@ -699,36 +684,36 @@ def pos_rot_to_matrix(pos, rot):
# map axes strings to/from tuples of inner axis, parity, repetition, frame
_AXES2TUPLE = {
- 'sxyz': (0, 0, 0, 0),
- 'sxyx': (0, 0, 1, 0),
- 'sxzy': (0, 1, 0, 0),
- 'sxzx': (0, 1, 1, 0),
- 'syzx': (1, 0, 0, 0),
- 'syzy': (1, 0, 1, 0),
- 'syxz': (1, 1, 0, 0),
- 'syxy': (1, 1, 1, 0),
- 'szxy': (2, 0, 0, 0),
- 'szxz': (2, 0, 1, 0),
- 'szyx': (2, 1, 0, 0),
- 'szyz': (2, 1, 1, 0),
- 'rzyx': (0, 0, 0, 1),
- 'rxyx': (0, 0, 1, 1),
- 'ryzx': (0, 1, 0, 1),
- 'rxzx': (0, 1, 1, 1),
- 'rxzy': (1, 0, 0, 1),
- 'ryzy': (1, 0, 1, 1),
- 'rzxy': (1, 1, 0, 1),
- 'ryxy': (1, 1, 1, 1),
- 'ryxz': (2, 0, 0, 1),
- 'rzxz': (2, 0, 1, 1),
- 'rxyz': (2, 1, 0, 1),
- 'rzyz': (2, 1, 1, 1),
+ "sxyz": (0, 0, 0, 0),
+ "sxyx": (0, 0, 1, 0),
+ "sxzy": (0, 1, 0, 0),
+ "sxzx": (0, 1, 1, 0),
+ "syzx": (1, 0, 0, 0),
+ "syzy": (1, 0, 1, 0),
+ "syxz": (1, 1, 0, 0),
+ "syxy": (1, 1, 1, 0),
+ "szxy": (2, 0, 0, 0),
+ "szxz": (2, 0, 1, 0),
+ "szyx": (2, 1, 0, 0),
+ "szyz": (2, 1, 1, 0),
+ "rzyx": (0, 0, 0, 1),
+ "rxyx": (0, 0, 1, 1),
+ "ryzx": (0, 1, 0, 1),
+ "rxzx": (0, 1, 1, 1),
+ "rxzy": (1, 0, 0, 1),
+ "ryzy": (1, 0, 1, 1),
+ "rzxy": (1, 1, 0, 1),
+ "ryxy": (1, 1, 1, 1),
+ "ryxz": (2, 0, 0, 1),
+ "rzxz": (2, 0, 1, 1),
+ "rxyz": (2, 1, 0, 1),
+ "rzyz": (2, 1, 1, 1),
}
_TUPLE2AXES = {v: k for k, v in _AXES2TUPLE.items()}
-def quaternion_from_euler(rpy, axes='sxyz'):
+def quaternion_from_euler(rpy, axes="sxyz"):
"""
Return quaternion from Euler angles and axis sequence.
Taken from https://github.com/cgohlke/transformations/blob/master/transformations/transformations.py#L1238
@@ -740,7 +725,7 @@ def quaternion_from_euler(rpy, axes='sxyz'):
try:
firstaxis, parity, repetition, frame = _AXES2TUPLE[axes.lower()]
except (AttributeError, KeyError):
- _TUPLE2AXES[axes] # noqa: validation
+ _TUPLE2AXES[axes]
firstaxis, parity, repetition, frame = axes
ai, aj, ak = torch.unbind(rpy, -1)
diff --git a/src/pytorch_kinematics/transforms/so3.py b/src/pytorch_kinematics/transforms/so3.py
index 1eab95c..df572a4 100644
--- a/src/pytorch_kinematics/transforms/so3.py
+++ b/src/pytorch_kinematics/transforms/so3.py
@@ -1,15 +1,20 @@
# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
-import warnings
from typing import Tuple
+
import torch
from pytorch_kinematics.transforms.math import acos_linear_extrapolation
-HAT_INV_SKEW_SYMMETRIC_TOL = 1e-5
+
+HAT_INV_SKEW_SYMMETRIC_TOL: float = 1e-5
-def so3_relative_angle(R1, R2, cos_angle: bool = False):
+def so3_relative_angle(
+ R1: torch.Tensor,
+ R2: torch.Tensor,
+ cos_angle: bool = False,
+) -> torch.Tensor:
"""
Calculates the relative angle (in radians) between pairs of
rotation matrices `R1` and `R2` with `angle = acos(0.5 * (Trace(R1 R2^T)-1))`
@@ -38,10 +43,10 @@ def so3_relative_angle(R1, R2, cos_angle: bool = False):
def so3_rotation_angle(
- R: torch.Tensor,
- eps: float = 1e-4,
- cos_angle: bool = False,
- cos_bound: float = 1e-4,
+ R: torch.Tensor,
+ eps: float = 1e-4,
+ cos_angle: bool = False,
+ cos_bound: float = 1e-4,
) -> torch.Tensor:
"""
Calculates angles (in radians) of a batch of rotation matrices `R` with
@@ -67,14 +72,13 @@ def so3_rotation_angle(
ValueError if `R` is of incorrect shape.
ValueError if `R` has an unexpected trace.
"""
-
N, dim1, dim2 = R.shape
if dim1 != 3 or dim2 != 3:
raise ValueError("Input has to be a batch of 3x3 Tensors.")
- rot_trace = R[:, 0, 0] + R[:, 1, 1] + R[:, 2, 2]
+ rot_trace: torch.Tensor = R[:, 0, 0] + R[:, 1, 1] + R[:, 2, 2]
- if ((rot_trace < -1.0 - eps) + (rot_trace > 3.0 + eps)).any():
+ if ((rot_trace < -1.0 - eps) | (rot_trace > 3.0 + eps)).any():
raise ValueError("A matrix has trace outside valid range [-1-eps,3+eps].")
# phi ... rotation angle
@@ -82,14 +86,17 @@ def so3_rotation_angle(
if cos_angle:
return phi_cos
- else:
- if cos_bound > 0.0:
- return acos_linear_extrapolation(phi_cos, 1.0 - cos_bound)
- else:
- return torch.acos(phi_cos)
+ if cos_bound > 0.0:
+ return acos_linear_extrapolation(phi_cos, 1.0 - cos_bound)
+
+ return torch.acos(phi_cos)
-def so3_exp_map(log_rot: torch.Tensor, eps: float = 0.0001) -> torch.Tensor:
+
+def so3_exp_map(
+ log_rot: torch.Tensor,
+ eps: float = 1e-4,
+) -> torch.Tensor:
"""
Convert a batch of logarithmic representations of rotation matrices `log_rot`
to a batch of 3x3 rotation matrices using Rodrigues formula [1].
@@ -102,16 +109,14 @@ def so3_exp_map(log_rot: torch.Tensor, eps: float = 0.0001) -> torch.Tensor:
log_rot: Batch of vectors of shape `(minibatch, 3)`.
eps: A float constant handling the conversion singularity.
Returns:
- Batch of rotation matrices of shape `(minibatch, 3, 3)`.
- Raises:
- ValueError if `log_rot` is of incorrect shape.
- [1] https://en.wikipedia.org/wiki/Rodrigues%27_rotation_formula
+ (N, 3, 3) rotation matrices
"""
return _so3_exp_map(log_rot, eps=eps)[0]
def _so3_exp_map(
- log_rot: torch.Tensor, eps: float = 0.0001
+ log_rot: torch.Tensor,
+ eps: float = 1e-4,
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
"""
A helper function that computes the so3 exponential map and,
@@ -122,26 +127,31 @@ def _so3_exp_map(
if dim != 3:
raise ValueError("Input tensor shape has to be Nx3.")
- nrms = (log_rot * log_rot).sum(1)
+ nrms: torch.Tensor = (log_rot * log_rot).sum(1)
# phis ... rotation angles
- rot_angles = torch.clamp(nrms, eps).sqrt()
+ rot_angles: torch.Tensor = torch.clamp(nrms, eps).sqrt()
rot_angles_inv = 1.0 / rot_angles
- fac1 = rot_angles_inv * rot_angles.sin()
- fac2 = rot_angles_inv * rot_angles_inv * (1.0 - rot_angles.cos())
- skews = hat(log_rot)
- skews_square = torch.bmm(skews, skews)
+
+ fac1: torch.Tensor = rot_angles_inv * rot_angles.sin()
+ fac2: torch.Tensor = rot_angles_inv * rot_angles_inv * (1.0 - rot_angles.cos())
+
+ skews: torch.Tensor = hat(log_rot)
+ skews_square: torch.Tensor = torch.bmm(skews, skews)
R = (
# pyre-fixme[16]: `float` has no attribute `__getitem__`.
- fac1[:, None, None] * skews
- + fac2[:, None, None] * skews_square
- + torch.eye(3, dtype=log_rot.dtype, device=log_rot.device)[None]
+ fac1[:, None, None] * skews
+ + fac2[:, None, None] * skews_square
+ + torch.eye(3, dtype=log_rot.dtype, device=log_rot.device)[None]
)
return R, rot_angles, skews, skews_square
-def so3_log_map(R, eps: float = 0.0001):
+def so3_log_map(
+ R: torch.Tensor,
+ eps: float = 1e-4,
+) -> torch.Tensor:
"""
Convert a batch of 3x3 rotation matrices `R`
to a batch of 3-dimensional matrix logarithms of rotation matrices
@@ -160,27 +170,22 @@ def so3_log_map(R, eps: float = 0.0001):
ValueError if `R` is of incorrect shape.
ValueError if `R` has an unexpected trace.
"""
-
N, dim1, dim2 = R.shape
if dim1 != 3 or dim2 != 3:
raise ValueError("Input has to be a batch of 3x3 Tensors.")
- phi = so3_rotation_angle(R)
-
- phi_sin = phi.sin()
+ phi: torch.Tensor = so3_rotation_angle(R)
+ phi_sin: torch.Tensor = phi.sin()
- phi_denom = (
- torch.clamp(phi_sin.abs(), eps) * phi_sin.sign()
- + (phi_sin == 0).type_as(phi) * eps
- )
+ phi_denom = torch.clamp(phi_sin.abs(), eps) * phi_sin.sign() + (phi_sin == 0).type_as(phi) * eps
log_rot_hat = (phi / (2.0 * phi_denom))[:, None, None] * (R - R.permute(0, 2, 1))
- log_rot = hat_inv(log_rot_hat)
+ log_rot: torch.Tensor = hat_inv(log_rot_hat)
return log_rot
-def hat_inv(h):
+def hat_inv(h: torch.Tensor) -> torch.Tensor:
"""
Compute the inverse Hat operator [1] of a batch of 3x3 matrices.
@@ -209,12 +214,10 @@ def hat_inv(h):
y = h[:, 0, 2]
z = h[:, 1, 0]
- v = torch.stack((x, y, z), dim=1)
-
- return v
+ return torch.stack((x, y, z), dim=1)
-def hat(v):
+def hat(v: torch.Tensor) -> torch.Tensor:
"""
Compute the Hat operator [1] of a batch of 3D vectors.
@@ -233,13 +236,11 @@ def hat(v):
[1] https://en.wikipedia.org/wiki/Hat_operator
"""
-
N, dim = v.shape
if dim != 3:
raise ValueError("Input vectors have to be 3-dimensional.")
h = v.new_zeros(N, 3, 3)
-
x, y, z = v.unbind(1)
h[:, 0, 1] = -z
diff --git a/src/pytorch_kinematics/transforms/transform3d.py b/src/pytorch_kinematics/transforms/transform3d.py
index bad7949..59c542f 100644
--- a/src/pytorch_kinematics/transforms/transform3d.py
+++ b/src/pytorch_kinematics/transforms/transform3d.py
@@ -3,14 +3,20 @@
import math
import typing
import warnings
-from typing import Optional
+from typing import Any, Iterable, Optional, Union
import torch
+from arm_pytorch_utilities import linalg
-from .rotation_conversions import _axis_angle_rotation, matrix_to_quaternion, quaternion_to_matrix, \
- euler_angles_to_matrix
from pytorch_kinematics.transforms.perturbation import sample_perturbations
-from arm_pytorch_utilities import linalg
+
+from .rotation_conversions import (
+ _axis_angle_rotation,
+ euler_angles_to_matrix,
+ matrix_to_quaternion,
+ quaternion_to_matrix,
+)
+
DEFAULT_EULER_CONVENTION = "XYZ"
@@ -144,14 +150,14 @@ class Transform3d:
"""
def __init__(
- self,
- default_batch_size=1,
- dtype: torch.dtype = torch.float32,
- device='cpu',
- matrix: Optional[torch.Tensor] = None,
- rot: Optional[typing.Iterable] = None,
- pos: Optional[typing.Iterable] = None,
- ):
+ self,
+ default_batch_size: int = 1,
+ dtype: torch.dtype = torch.float32,
+ device: Union[str, torch.device] = "cpu",
+ matrix: Optional[torch.Tensor] = None,
+ rot: Optional[typing.Iterable] = None,
+ pos: Optional[typing.Iterable] = None,
+ ) -> None:
"""
Args:
default_batch_size: A positive integer representing the minibatch size
@@ -175,14 +181,12 @@ def __init__(
matrix argument, if any.
"""
if matrix is None:
- self._matrix = torch.eye(4, dtype=dtype, device=device).unsqueeze(0).repeat(default_batch_size, 1, 1)
+ self._matrix: torch.Tensor = torch.eye(4, dtype=dtype, device=device).unsqueeze(0).repeat(default_batch_size, 1, 1)
else:
if matrix.ndim not in (2, 3):
raise ValueError('"matrix" has to be a 2- or a 3-dimensional tensor.')
if matrix.shape[-2] != 4 or matrix.shape[-1] != 4:
- raise ValueError(
- '"matrix" has to be a tensor of shape (minibatch, 4, 4)'
- )
+ raise ValueError('"matrix" has to be a tensor of shape (minibatch, 4, 4)')
# set the device from matrix
device = matrix.device
self._matrix = matrix.view(-1, 4, 4)
@@ -213,25 +217,25 @@ def __init__(
rot_h = torch.cat((rot, zeros), dim=-2).reshape(-1, 4, 3)
self._matrix = torch.cat((rot_h, self._matrix[:, :, 3].reshape(-1, 4, 1)), dim=-1)
- self._lu = None
- self.device = device
- self.dtype = self._matrix.dtype
+ self._lu: Optional[Any] = None
+ self.device: Union[str, torch.device] = device
+ self.dtype: torch.dtype = self._matrix.dtype
- def __len__(self):
+ def __len__(self) -> int:
return self.get_matrix().shape[0]
- def __getitem__(self, item):
+ def __getitem__(self, item: Any) -> "Transform3d":
return Transform3d(matrix=self.get_matrix()[item])
- def __repr__(self):
+ def __repr__(self) -> str:
m = self.get_matrix()
pos = m[:, :3, 3]
rot = matrix_to_quaternion(m[:, :3, :3])
- return "Transform3d(rot={}, pos={})".format(rot, pos).replace('\n ', '')
+ return f"Transform3d(rot={rot}, pos={pos})".replace("\n ", "")
- def compose(self, *others):
+ def compose(self, *others: "Transform3d") -> "Transform3d":
"""
- Return a new Transform3d with the tranforms to compose stored as
+ Return a new Transform3d with the transforms to compose stored as
an internal list.
Args:
@@ -240,7 +244,6 @@ def compose(self, *others):
Returns:
A new Transform3d with the stored transforms
"""
-
mat = self._matrix
for other in others:
mat = _broadcast_bmm(mat, other.get_matrix())
@@ -248,21 +251,20 @@ def compose(self, *others):
out = Transform3d(device=self.device, dtype=self.dtype, matrix=mat)
return out
- def get_matrix(self):
+ def get_matrix(self) -> torch.Tensor:
"""
Return the Nx4x4 homogeneous transformation matrix represented by this object.
"""
return self._matrix
- def _get_matrix_inverse(self):
+ def _get_matrix_inverse(self) -> torch.Tensor:
"""
Return the inverse of self._matrix.
"""
-
return self._invert_transformation_matrix(self._matrix)
@staticmethod
- def _invert_transformation_matrix(T):
+ def _invert_transformation_matrix(T: torch.Tensor) -> torch.Tensor:
"""
Invert homogeneous transformation matrix.
"""
@@ -273,7 +275,7 @@ def _invert_transformation_matrix(T):
Tinv[:, :3, 3:] = -Tinv[:, :3, :3] @ t.unsqueeze(-1)
return Tinv
- def inverse(self, invert_composed: bool = False):
+ def inverse(self, invert_composed: bool = False) -> "Transform3d":
"""
Returns a new Transform3D object that represents an inverse of the
current transformation.
@@ -287,18 +289,21 @@ def inverse(self, invert_composed: bool = False):
"""
i_matrix = self._get_matrix_inverse()
-
tinv = Transform3d(matrix=i_matrix, device=self.device)
-
return tinv
- def stack(self, *others):
- transforms = [self] + list(others)
+ def stack(self, *others: "Transform3d") -> "Transform3d":
+ transforms = [self, *list(others)]
matrix = torch.cat([t._matrix for t in transforms], dim=0)
out = Transform3d(matrix=matrix, device=self.device, dtype=self.dtype)
return out
- def transform_points(self, points, eps: Optional[float] = None, batch_to_batch=False):
+ def transform_points(
+ self,
+ points: torch.Tensor,
+ eps: Optional[float] = None,
+ batch_to_batch: bool = False,
+ ) -> torch.Tensor:
"""
Use this transform to transform a set of 3D points. Assumes row major
ordering of the input points.
@@ -306,7 +311,7 @@ def transform_points(self, points, eps: Optional[float] = None, batch_to_batch=F
Args:
points: Tensor of shape (P, 3) or (N, P, 3)
eps: If eps!=None, the argument is used to clamp the
- last coordinate before peforming the final division.
+ last coordinate before performing the final division.
The clamping corresponds to:
last_coord := (last_coord.sign() + (last_coord==0)) *
torch.clamp(last_coord.abs(), eps),
@@ -348,7 +353,7 @@ def transform_points(self, points, eps: Optional[float] = None, batch_to_batch=F
return points_out
- def transform_normals(self, normals, batch_to_batch=False):
+ def transform_normals(self, normals: torch.Tensor, batch_to_batch: bool = False) -> torch.Tensor:
"""
Use this transform to transform a set of normal vectors.
@@ -366,10 +371,7 @@ def transform_normals(self, normals, batch_to_batch=False):
raise ValueError(msg % (normals.shape,))
mat = self.inverse().get_matrix()[:, :3, :3]
- if batch_to_batch:
- normals_out = linalg.batch_batch_product(normals, mat)
- else:
- normals_out = _broadcast_bmm(normals, mat)
+ normals_out = linalg.batch_batch_product(normals, mat) if batch_to_batch else _broadcast_bmm(normals, mat)
# This doesn't pass unit tests. TODO investigate further
# if self._lu is None:
@@ -383,7 +385,7 @@ def transform_normals(self, normals, batch_to_batch=False):
return normals_out
- def transform_shape_operator(self, shape_operators):
+ def transform_shape_operator(self, shape_operators: torch.Tensor) -> torch.Tensor:
"""
Use this transform to transform a set of shape_operator (or Weingarten map).
This is the hessian of a signed-distance, i.e. gradient of a normal vector.
@@ -408,19 +410,24 @@ def transform_shape_operator(self, shape_operators):
return shape_operators_out
- def translate(self, *args, **kwargs):
+ def translate(self, *args: Any, **kwargs: Any) -> "Transform3d":
return self.compose(Translate(device=self.device, *args, **kwargs))
- def scale(self, *args, **kwargs):
+ def scale(self, *args: Any, **kwargs: Any) -> "Transform3d":
return self.compose(Scale(device=self.device, *args, **kwargs))
- def rotate(self, *args, **kwargs):
+ def rotate(self, *args: Any, **kwargs: Any) -> "Transform3d":
return self.compose(Rotate(device=self.device, *args, **kwargs))
- def rotate_axis_angle(self, *args, **kwargs):
+ def rotate_axis_angle(self, *args: Any, **kwargs: Any) -> "Transform3d":
return self.compose(RotateAxisAngle(device=self.device, *args, **kwargs))
- def sample_perturbations(self, num_perturbations, radian_sigma, translation_sigma):
+ def sample_perturbations(
+ self,
+ num_perturbations: int,
+ radian_sigma: float,
+ translation_sigma: float,
+ ) -> "Transform3d":
mat = self.get_matrix()
if mat.shape[0] == 1:
mat = mat[0]
@@ -428,7 +435,7 @@ def sample_perturbations(self, num_perturbations, radian_sigma, translation_sigm
out = Transform3d(matrix=all_mats)
return out
- def clone(self):
+ def clone(self) -> "Transform3d":
"""
Deep copy of Transforms object. All internal tensors are cloned
individually.
@@ -442,7 +449,12 @@ def clone(self):
other._matrix = self._matrix.clone()
return other
- def to(self, device, copy: bool = False, dtype=None):
+ def to(
+ self,
+ device: Union[str, torch.device],
+ copy: bool = False,
+ dtype: Optional[torch.dtype] = None,
+ ) -> "Transform3d":
"""
Match functionality of torch.Tensor.to()
If copy = True or the self Tensor is on a different device, the
@@ -467,15 +479,22 @@ def to(self, device, copy: bool = False, dtype=None):
other._matrix = self._matrix.to(device=device, dtype=dtype)
return other
- def cpu(self):
+ def cpu(self) -> "Transform3d":
return self.to(torch.device("cpu"))
- def cuda(self):
+ def cuda(self) -> "Transform3d":
return self.to(torch.device("cuda"))
class Translate(Transform3d):
- def __init__(self, x, y=None, z=None, dtype=torch.float32, device: str = "cpu"):
+ def __init__(
+ self,
+ x: Union[float, torch.Tensor],
+ y: Optional[Union[float, torch.Tensor]] = None,
+ z: Optional[Union[float, torch.Tensor]] = None,
+ dtype: torch.dtype = torch.float32,
+ device: Union[str, torch.device] = "cpu",
+ ) -> None:
"""
Create a new Transform3d representing 3D translations.
@@ -498,7 +517,7 @@ def __init__(self, x, y=None, z=None, dtype=torch.float32, device: str = "cpu"):
mat[:, :3, 3] = xyz
self._matrix = mat
- def _get_matrix_inverse(self):
+ def _get_matrix_inverse(self) -> torch.Tensor:
"""
Return the inverse of self._matrix.
"""
@@ -509,7 +528,14 @@ def _get_matrix_inverse(self):
class Scale(Transform3d):
- def __init__(self, x, y=None, z=None, dtype=torch.float32, device: str = "cpu"):
+ def __init__(
+ self,
+ x: Union[float, torch.Tensor],
+ y: Optional[Union[float, torch.Tensor]] = None,
+ z: Optional[Union[float, torch.Tensor]] = None,
+ dtype: torch.dtype = torch.float32,
+ device: Union[str, torch.device] = "cpu",
+ ) -> None:
"""
A Transform3d representing a scaling operation, with different scale
factors along each coordinate axis.
@@ -538,7 +564,7 @@ def __init__(self, x, y=None, z=None, dtype=torch.float32, device: str = "cpu"):
mat[:, 2, 2] = xyz[:, 2]
self._matrix = mat
- def _get_matrix_inverse(self):
+ def _get_matrix_inverse(self) -> torch.Tensor:
"""
Return the inverse of self._matrix.
"""
@@ -550,8 +576,12 @@ def _get_matrix_inverse(self):
class Rotate(Transform3d):
def __init__(
- self, R, dtype=torch.float32, device: str = "cpu", orthogonal_tol: float = 1e-5
- ):
+ self,
+ R: Union[torch.Tensor, Iterable[Any]],
+ dtype: torch.dtype = torch.float32,
+ device: Union[str, torch.device] = "cpu",
+ orthogonal_tol: float = 1e-5,
+ ) -> None:
"""
Create a new Transform3d representing 3D rotation using a rotation
matrix as the input.
@@ -583,7 +613,7 @@ def __init__(
mat[:, :3, :3] = R
self._matrix = mat
- def _get_matrix_inverse(self):
+ def _get_matrix_inverse(self) -> torch.Tensor:
"""
Return the inverse of self._matrix.
"""
@@ -592,13 +622,13 @@ def _get_matrix_inverse(self):
class RotateAxisAngle(Rotate):
def __init__(
- self,
- angle,
- axis: str = "X",
- degrees: bool = True,
- dtype=torch.float64,
- device: str = "cpu",
- ):
+ self,
+ angle: Union[float, torch.Tensor],
+ axis: str = "X",
+ degrees: bool = True,
+ dtype: torch.dtype = torch.float64,
+ device: Union[str, torch.device] = "cpu",
+ ) -> None:
"""
Create a new Transform3d representing 3D rotation about an axis
by an angle.
@@ -630,7 +660,11 @@ def __init__(
super().__init__(device=device, R=R)
-def _handle_coord(c, dtype, device):
+def _handle_coord(
+ c: Union[float, torch.Tensor],
+ dtype: torch.dtype,
+ device: Union[str, torch.device],
+) -> torch.Tensor:
"""
Helper function for _handle_input.
@@ -647,7 +681,15 @@ def _handle_coord(c, dtype, device):
return c
-def _handle_input(x, y, z, dtype, device, name: str, allow_singleton: bool = False):
+def _handle_input(
+ x: Union[float, torch.Tensor],
+ y: Optional[Union[float, torch.Tensor]],
+ z: Optional[Union[float, torch.Tensor]],
+ dtype: torch.dtype,
+ device: Union[str, torch.device],
+ name: str,
+ allow_singleton: bool = False,
+) -> torch.Tensor:
"""
Helper function to handle parsing logic for building transforms. The output
is always a tensor of shape (N, 3), but there are several types of allowed
@@ -681,7 +723,7 @@ def _handle_input(x, y, z, dtype, device, name: str, allow_singleton: bool = Fal
msg = "Expected tensor of shape (N, 3); got %r (in %s)"
raise ValueError(msg % (x.shape, name))
if y is not None or z is not None:
- msg = "Expected y and z to be None (in %s)" % name
+ msg = f"Expected y and z to be None (in {name})"
raise ValueError(msg)
return x
@@ -697,14 +739,19 @@ def _handle_input(x, y, z, dtype, device, name: str, allow_singleton: bool = Fal
N = max(sizes)
for c in xyz:
if c.shape[0] != 1 and c.shape[0] != N:
- msg = "Got non-broadcastable sizes %r (in %s)" % (sizes, name)
+ msg = f"Got non-broadcastable sizes {sizes!r} (in {name})"
raise ValueError(msg)
xyz = [c.expand(N) for c in xyz]
xyz = torch.stack(xyz, dim=1)
return xyz
-def _handle_angle_input(x, dtype, device: str, name: str):
+def _handle_angle_input(
+ x: Union[float, torch.Tensor],
+ dtype: torch.dtype,
+ device: Union[str, torch.device],
+ name: str,
+) -> torch.Tensor:
"""
Helper function for building a rotation function using angles.
The output is always of shape (N,).
@@ -721,7 +768,7 @@ def _handle_angle_input(x, dtype, device: str, name: str):
return _handle_coord(x, dtype, device)
-def _broadcast_bmm(a, b):
+def _broadcast_bmm(a: torch.Tensor, b: torch.Tensor) -> torch.Tensor:
"""
Batch multiply two matrices and broadcast if necessary.
@@ -730,7 +777,7 @@ def _broadcast_bmm(a, b):
b: torch tensor of shape (N, K, K)
Returns:
- a and b broadcast multipled. The output batch dimension is max(N, M).
+ a and b broadcast multiplied. The output batch dimension is max(N, M).
To broadcast transforms across a batch dimension if M != N then
expect that either M = 1 or N = 1. The tensor with batch dimension 1 is
@@ -749,7 +796,7 @@ def _broadcast_bmm(a, b):
return a.bmm(b)
-def _check_valid_rotation_matrix(R, tol: float = 1e-7):
+def _check_valid_rotation_matrix(R: torch.Tensor, tol: float = 1e-7) -> None:
"""
Determine if R is a valid rotation matrix by checking it satisfies the
following conditions:
@@ -773,4 +820,3 @@ def _check_valid_rotation_matrix(R, tol: float = 1e-7):
if not (orthogonal and no_distortion):
msg = "R is not a valid rotation matrix"
warnings.warn(msg)
- return
diff --git a/src/pytorch_kinematics/urdf.py b/src/pytorch_kinematics/urdf.py
index 85d99fa..2c3dca5 100644
--- a/src/pytorch_kinematics/urdf.py
+++ b/src/pytorch_kinematics/urdf.py
@@ -1,24 +1,33 @@
-from .urdf_parser_py.urdf import URDF, Mesh, Cylinder, Box, Sphere
-from . import frame
-from . import chain
+from typing import Any, Dict, List, Optional, Sequence
+
import torch
+
import pytorch_kinematics.transforms as tf
-JOINT_TYPE_MAP = {'revolute': 'revolute',
- 'continuous': 'revolute',
- 'prismatic': 'prismatic',
- 'fixed': 'fixed'}
+from . import chain, frame
+from .urdf_parser_py.urdf import URDF, Box, Cylinder, Mesh, Sphere
+
+
+JOINT_TYPE_MAP: Dict[str, str] = {
+ "revolute": "revolute",
+ "continuous": "revolute",
+ "prismatic": "prismatic",
+ "fixed": "fixed",
+}
-def _convert_transform(origin):
+def _convert_transform(origin: Optional[Any]) -> tf.Transform3d:
if origin is None:
return tf.Transform3d()
else:
rpy = torch.tensor(origin.rpy, dtype=torch.float32, device="cpu")
- return tf.Transform3d(rot=tf.quaternion_from_euler(rpy, "sxyz"), pos=origin.xyz)
+ return tf.Transform3d(
+ rot=tf.quaternion_from_euler(rpy, "sxyz"),
+ pos=origin.xyz,
+ )
-def _convert_visual(visual):
+def _convert_visual(visual: Optional[Any]) -> frame.Visual:
if visual is None or visual.geometry is None:
return frame.Visual()
else:
@@ -41,8 +50,12 @@ def _convert_visual(visual):
return frame.Visual(v_tf, g_type, g_param)
-def _build_chain_recurse(root_frame, lmap, joints):
- children = []
+def _build_chain_recurse(
+ root_frame: frame.Frame,
+ lmap: Dict[str, Any],
+ joints: Sequence[Any],
+) -> List[frame.Frame]:
+ children: List[frame.Frame] = []
for j in joints:
if j.parent == root_frame.link.name:
try:
@@ -58,19 +71,30 @@ def _build_chain_recurse(root_frame, lmap, joints):
effort_limits = (-j.limit.effort, j.limit.effort)
except AttributeError:
effort_limits = None
+
child_frame = frame.Frame(j.child)
- child_frame.joint = frame.Joint(j.name, offset=_convert_transform(j.origin),
- joint_type=JOINT_TYPE_MAP[j.type], axis=j.axis, limits=limits,
- velocity_limits=velocity_limits, effort_limits=effort_limits)
+ child_frame.joint = frame.Joint(
+ j.name,
+ offset=_convert_transform(j.origin),
+ joint_type=JOINT_TYPE_MAP[j.type],
+ axis=j.axis,
+ limits=limits,
+ velocity_limits=velocity_limits,
+ effort_limits=effort_limits,
+ )
+
link = lmap[j.child]
- child_frame.link = frame.Link(link.name, offset=_convert_transform(link.origin),
- visuals=[_convert_visual(link.visual)])
+ child_frame.link = frame.Link(
+ link.name,
+ offset=_convert_transform(link.origin),
+ visuals=[_convert_visual(link.visual)],
+ )
child_frame.children = _build_chain_recurse(child_frame, lmap, joints)
children.append(child_frame)
return children
-def build_chain_from_urdf(data):
+def build_chain_from_urdf(data: str) -> chain.Chain:
"""
Build a Chain object from URDF data.
@@ -98,14 +122,14 @@ def build_chain_from_urdf(data):
>>> chain = pk.build_chain_from_urdf(data)
>>> print(chain)
link1_frame
- link2_frame
+ link2_frame
"""
- robot = URDF.from_xml_string(data)
- lmap = robot.link_map
- joints = robot.joints
+ robot: URDF = URDF.from_xml_string(data)
+ lmap: Dict[str, Any] = robot.link_map
+ joints: Sequence[Any] = robot.joints
n_joints = len(joints)
- has_root = [True for _ in range(len(joints))]
+ has_root: List[bool] = [True for _ in range(len(joints))]
for i in range(n_joints):
for j in range(i + 1, n_joints):
if joints[i].parent == joints[j].child:
@@ -116,17 +140,28 @@ def build_chain_from_urdf(data):
if has_root[i]:
root_link = lmap[joints[i].parent]
break
+ else:
+ # Fallback (should rarely happen)
+ root_link = lmap[joints[0].parent]
+
root_frame = frame.Frame(root_link.name)
root_frame.joint = frame.Joint()
- root_frame.link = frame.Link(root_link.name, _convert_transform(root_link.origin),
- [_convert_visual(root_link.visual)])
+ root_frame.link = frame.Link(
+ root_link.name,
+ _convert_transform(root_link.origin),
+ [_convert_visual(root_link.visual)],
+ )
root_frame.children = _build_chain_recurse(root_frame, lmap, joints)
return chain.Chain(root_frame)
-def build_serial_chain_from_urdf(data, end_link_name, root_link_name=""):
+def build_serial_chain_from_urdf(
+ data: str,
+ end_link_name: str,
+ root_link_name: str = "",
+) -> chain.SerialChain:
"""
- Build a SerialChain object from urdf data.
+ Build a SerialChain object from URDF data.
Parameters
----------
@@ -143,4 +178,4 @@ def build_serial_chain_from_urdf(data, end_link_name, root_link_name=""):
SerialChain object created from URDF.
"""
urdf_chain = build_chain_from_urdf(data)
- return chain.SerialChain(urdf_chain, end_link_name, root_link_name or '')
+ return chain.SerialChain(urdf_chain, end_link_name, root_link_name or "")
diff --git a/src/pytorch_kinematics/urdf_parser_py/sdf.py b/src/pytorch_kinematics/urdf_parser_py/sdf.py
index 74d8fc7..a52f113 100644
--- a/src/pytorch_kinematics/urdf_parser_py/sdf.py
+++ b/src/pytorch_kinematics/urdf_parser_py/sdf.py
@@ -1,16 +1,19 @@
-from .xml_reflection.basics import *
+from typing import ClassVar
+
from . import xml_reflection as xmlr
+from .xml_reflection.basics import node_add, xml_children
+
# What is the scope of plugins? Model, World, Sensor?
-xmlr.start_namespace('sdf')
+xmlr.start_namespace("sdf")
-name_attribute = xmlr.Attribute('name', str, False)
-pose_element = xmlr.Element('pose', 'vector6', False)
+name_attribute = xmlr.Attribute("name", str, False)
+pose_element = xmlr.Element("pose", "vector6", False)
class Inertia(xmlr.Object):
- KEYS = ['ixx', 'ixy', 'ixz', 'iyy', 'iyz', 'izz']
+ KEYS: ClassVar = ["ixx", "ixy", "ixz", "iyy", "iyz", "izz"]
def __init__(self, ixx=0.0, ixy=0.0, ixz=0.0, iyy=0.0, iyz=0.0, izz=0.0):
self.ixx = ixx
@@ -24,11 +27,11 @@ def to_matrix(self):
return [
[self.ixx, self.ixy, self.ixz],
[self.ixy, self.iyy, self.iyz],
- [self.ixz, self.iyz, self.izz]]
+ [self.ixz, self.iyz, self.izz],
+ ]
-xmlr.reflect(Inertia,
- params=[xmlr.Element(key, float) for key in Inertia.KEYS])
+xmlr.reflect(Inertia, params=[xmlr.Element(key, float) for key in Inertia.KEYS])
# Pretty much copy-paste... Better method?
@@ -42,11 +45,14 @@ def __init__(self, mass=0.0, inertia=None, pose=None):
self.pose = pose
-xmlr.reflect(Inertial, params=[
- xmlr.Element('mass', float),
- xmlr.Element('inertia', Inertia),
- pose_element
-])
+xmlr.reflect(
+ Inertial,
+ params=[
+ xmlr.Element("mass", float),
+ xmlr.Element("inertia", Inertia),
+ pose_element,
+ ],
+)
class Box(xmlr.Object):
@@ -54,9 +60,7 @@ def __init__(self, size=None):
self.size = size
-xmlr.reflect(Box, tag='box', params=[
- xmlr.Element('size', 'vector3')
-])
+xmlr.reflect(Box, tag="box", params=[xmlr.Element("size", "vector3")])
class Cylinder(xmlr.Object):
@@ -65,10 +69,11 @@ def __init__(self, radius=0.0, length=0.0):
self.length = length
-xmlr.reflect(Cylinder, tag='cylinder', params=[
- xmlr.Element('radius', float),
- xmlr.Element('length', float)
-])
+xmlr.reflect(
+ Cylinder,
+ tag="cylinder",
+ params=[xmlr.Element("radius", float), xmlr.Element("length", float)],
+)
class Sphere(xmlr.Object):
@@ -76,9 +81,7 @@ def __init__(self, radius=0.0):
self.radius = radius
-xmlr.reflect(Sphere, tag='sphere', params=[
- xmlr.Element('radius', float)
-])
+xmlr.reflect(Sphere, tag="sphere", params=[xmlr.Element("radius", float)])
class Mesh(xmlr.Object):
@@ -87,24 +90,26 @@ def __init__(self, filename=None, scale=None):
self.scale = scale
-xmlr.reflect(Mesh, tag='mesh', params=[
- xmlr.Element('filename', str),
- xmlr.Element('scale', 'vector3', required=False)
-])
+xmlr.reflect(
+ Mesh,
+ tag="mesh",
+ params=[
+ xmlr.Element("filename", str),
+ xmlr.Element("scale", "vector3", required=False),
+ ],
+)
class GeometricType(xmlr.ValueType):
def __init__(self):
- self.factory = xmlr.FactoryType('geometric', {
- 'box': Box,
- 'cylinder': Cylinder,
- 'sphere': Sphere,
- 'mesh': Mesh
- })
+ self.factory = xmlr.FactoryType(
+ "geometric",
+ {"box": Box, "cylinder": Cylinder, "sphere": Sphere, "mesh": Mesh},
+ )
def from_xml(self, node, path):
children = xml_children(node)
- assert len(children) == 1, 'One element only for geometric'
+ assert len(children) == 1, "One element only for geometric"
return self.factory.from_xml(children[0], path=path)
def write_xml(self, node, obj):
@@ -113,7 +118,7 @@ def write_xml(self, node, obj):
obj.write_xml(child)
-xmlr.add_type('geometric', GeometricType())
+xmlr.add_type("geometric", GeometricType())
class Script(xmlr.Object):
@@ -122,10 +127,11 @@ def __init__(self, uri=None, name=None):
self.name = name
-xmlr.reflect(Script, tag='script', params=[
- xmlr.Element('name', str, False),
- xmlr.Element('uri', str, False)
-])
+xmlr.reflect(
+ Script,
+ tag="script",
+ params=[xmlr.Element("name", str, False), xmlr.Element("uri", str, False)],
+)
class Material(xmlr.Object):
@@ -134,10 +140,11 @@ def __init__(self, name=None, script=None):
self.script = script
-xmlr.reflect(Material, tag='material', params=[
- name_attribute,
- xmlr.Element('script', Script, False)
-])
+xmlr.reflect(
+ Material,
+ tag="material",
+ params=[name_attribute, xmlr.Element("script", Script, False)],
+)
class Visual(xmlr.Object):
@@ -147,12 +154,16 @@ def __init__(self, name=None, geometry=None, pose=None):
self.pose = pose
-xmlr.reflect(Visual, tag='visual', params=[
- name_attribute,
- xmlr.Element('geometry', 'geometric'),
- xmlr.Element('material', Material, False),
- pose_element
-])
+xmlr.reflect(
+ Visual,
+ tag="visual",
+ params=[
+ name_attribute,
+ xmlr.Element("geometry", "geometric"),
+ xmlr.Element("material", Material, False),
+ pose_element,
+ ],
+)
class Collision(xmlr.Object):
@@ -162,11 +173,11 @@ def __init__(self, name=None, geometry=None, pose=None):
self.pose = pose
-xmlr.reflect(Collision, tag='collision', params=[
- name_attribute,
- xmlr.Element('geometry', 'geometric'),
- pose_element
-])
+xmlr.reflect(
+ Collision,
+ tag="collision",
+ params=[name_attribute, xmlr.Element("geometry", "geometric"), pose_element],
+)
class Dynamics(xmlr.Object):
@@ -175,10 +186,14 @@ def __init__(self, damping=None, friction=None):
self.friction = friction
-xmlr.reflect(Dynamics, tag='dynamics', params=[
- xmlr.Element('damping', float, False),
- xmlr.Element('friction', float, False)
-])
+xmlr.reflect(
+ Dynamics,
+ tag="dynamics",
+ params=[
+ xmlr.Element("damping", float, False),
+ xmlr.Element("friction", float, False),
+ ],
+)
class Limit(xmlr.Object):
@@ -187,35 +202,47 @@ def __init__(self, lower=None, upper=None):
self.upper = upper
-xmlr.reflect(Limit, tag='limit', params=[
- xmlr.Element('lower', float, False),
- xmlr.Element('upper', float, False)
-])
+xmlr.reflect(
+ Limit,
+ tag="limit",
+ params=[xmlr.Element("lower", float, False), xmlr.Element("upper", float, False)],
+)
class Axis(xmlr.Object):
- def __init__(self, xyz=None, limit=None, dynamics=None,
- use_parent_model_frame=None):
+ def __init__(self, xyz=None, limit=None, dynamics=None, use_parent_model_frame=None):
self.xyz = xyz
self.limit = limit
self.dynamics = dynamics
self.use_parent_model_frame = use_parent_model_frame
-xmlr.reflect(Axis, tag='axis', params=[
- xmlr.Element('xyz', 'vector3'),
- xmlr.Element('limit', Limit, False),
- xmlr.Element('dynamics', Dynamics, False),
- xmlr.Element('use_parent_model_frame', bool, False)
-])
+xmlr.reflect(
+ Axis,
+ tag="axis",
+ params=[
+ xmlr.Element("xyz", "vector3"),
+ xmlr.Element("limit", Limit, False),
+ xmlr.Element("dynamics", Dynamics, False),
+ xmlr.Element("use_parent_model_frame", bool, False),
+ ],
+)
class Joint(xmlr.Object):
- TYPES = ['unknown', 'revolute', 'gearbox', 'revolute2',
- 'prismatic', 'ball', 'screw', 'universal', 'fixed']
-
- def __init__(self, name=None, parent=None, child=None, joint_type=None,
- axis=None, pose=None):
+ TYPES: ClassVar = [
+ "unknown",
+ "revolute",
+ "gearbox",
+ "revolute2",
+ "prismatic",
+ "ball",
+ "screw",
+ "universal",
+ "fixed",
+ ]
+
+ def __init__(self, name=None, parent=None, child=None, joint_type=None, axis=None, pose=None):
self.aggregate_init()
self.name = name
self.parent = parent
@@ -226,20 +253,26 @@ def __init__(self, name=None, parent=None, child=None, joint_type=None,
# Aliases
@property
- def joint_type(self): return self.type
+ def joint_type(self):
+ return self.type
@joint_type.setter
- def joint_type(self, value): self.type = value
+ def joint_type(self, value):
+ self.type = value
-xmlr.reflect(Joint, tag='joint', params=[
- name_attribute,
- xmlr.Attribute('type', str, False),
- xmlr.Element('axis', Axis),
- xmlr.Element('parent', str),
- xmlr.Element('child', str),
- pose_element
-])
+xmlr.reflect(
+ Joint,
+ tag="joint",
+ params=[
+ name_attribute,
+ xmlr.Attribute("type", str, False),
+ xmlr.Element("axis", Axis),
+ xmlr.Element("parent", str),
+ xmlr.Element("child", str),
+ pose_element,
+ ],
+)
class Link(xmlr.Object):
@@ -253,14 +286,18 @@ def __init__(self, name=None, pose=None, inertial=None, kinematic=False):
self.collisions = []
-xmlr.reflect(Link, tag='link', params=[
- name_attribute,
- xmlr.Element('inertial', Inertial),
- xmlr.Attribute('kinematic', bool, False),
- xmlr.AggregateElement('visual', Visual, var='visuals'),
- xmlr.AggregateElement('collision', Collision, var='collisions'),
- pose_element
-])
+xmlr.reflect(
+ Link,
+ tag="link",
+ params=[
+ name_attribute,
+ xmlr.Element("inertial", Inertial),
+ xmlr.Attribute("kinematic", bool, False),
+ xmlr.AggregateElement("visual", Visual, var="visuals"),
+ xmlr.AggregateElement("collision", Collision, var="collisions"),
+ pose_element,
+ ],
+)
class Model(xmlr.Object):
@@ -279,7 +316,7 @@ def __init__(self, name=None, pose=None):
def add_aggregate(self, typeName, elem):
xmlr.Object.add_aggregate(self, typeName, elem)
- if typeName == 'joint':
+ if typeName == "joint":
joint = elem
self.joint_map[joint.name] = joint
self.parent_map[joint.child] = (joint.name, joint.parent)
@@ -287,23 +324,27 @@ def add_aggregate(self, typeName, elem):
self.child_map[joint.parent].append((joint.name, joint.child))
else:
self.child_map[joint.parent] = [(joint.name, joint.child)]
- elif typeName == 'link':
+ elif typeName == "link":
link = elem
self.link_map[link.name] = link
def add_link(self, link):
- self.add_aggregate('link', link)
+ self.add_aggregate("link", link)
def add_joint(self, joint):
- self.add_aggregate('joint', joint)
+ self.add_aggregate("joint", joint)
-xmlr.reflect(Model, tag='model', params=[
- name_attribute,
- xmlr.AggregateElement('link', Link, var='links'),
- xmlr.AggregateElement('joint', Joint, var='joints'),
- pose_element
-])
+xmlr.reflect(
+ Model,
+ tag="model",
+ params=[
+ name_attribute,
+ xmlr.AggregateElement("link", Link, var="links"),
+ xmlr.AggregateElement("joint", Joint, var="joints"),
+ pose_element,
+ ],
+)
class SDF(xmlr.Object):
@@ -311,9 +352,13 @@ def __init__(self, version=None):
self.version = version
-xmlr.reflect(SDF, tag='sdf', params=[
- xmlr.Attribute('version', str, False),
- xmlr.Element('model', Model, False),
-])
+xmlr.reflect(
+ SDF,
+ tag="sdf",
+ params=[
+ xmlr.Attribute("version", str, False),
+ xmlr.Element("model", Model, False),
+ ],
+)
xmlr.end_namespace()
diff --git a/src/pytorch_kinematics/urdf_parser_py/urdf.py b/src/pytorch_kinematics/urdf_parser_py/urdf.py
index 65a7046..be63511 100644
--- a/src/pytorch_kinematics/urdf_parser_py/urdf.py
+++ b/src/pytorch_kinematics/urdf_parser_py/urdf.py
@@ -1,14 +1,17 @@
-from .xml_reflection.basics import *
+from typing import ClassVar
+
from . import xml_reflection as xmlr
+from .xml_reflection.basics import node_add, xml_children
+
# Add a 'namespace' for names to avoid a conflict between URDF and SDF?
# A type registry? How to scope that? Just make a 'global' type pointer?
# Or just qualify names? urdf.geometric, sdf.geometric
-xmlr.start_namespace('urdf')
+xmlr.start_namespace("urdf")
-xmlr.add_type('element_link', xmlr.SimpleElementType('link', str))
-xmlr.add_type('element_xyz', xmlr.SimpleElementType('xyz', 'vector3'))
+xmlr.add_type("element_link", xmlr.SimpleElementType("link", str))
+xmlr.add_type("element_xyz", xmlr.SimpleElementType("xyz", "vector3"))
verbose = True
@@ -19,38 +22,45 @@ def __init__(self, xyz=None, rpy=None):
self.rpy = rpy
def check_valid(self):
- assert (self.xyz is None or len(self.xyz) == 3) and \
- (self.rpy is None or len(self.rpy) == 3)
+ assert (self.xyz is None or len(self.xyz) == 3) and (self.rpy is None or len(self.rpy) == 3)
# Aliases for backwards compatibility
@property
- def rotation(self): return self.rpy
+ def rotation(self):
+ return self.rpy
@rotation.setter
- def rotation(self, value): self.rpy = value
+ def rotation(self, value):
+ self.rpy = value
@property
- def position(self): return self.xyz
+ def position(self):
+ return self.xyz
@position.setter
- def position(self, value): self.xyz = value
+ def position(self, value):
+ self.xyz = value
-xmlr.reflect(Pose, tag='origin', params=[
- xmlr.Attribute('xyz', 'vector3', False, default=[0, 0, 0]),
- xmlr.Attribute('rpy', 'vector3', False, default=[0, 0, 0])
-])
+xmlr.reflect(
+ Pose,
+ tag="origin",
+ params=[
+ xmlr.Attribute("xyz", "vector3", False, default=[0, 0, 0]),
+ xmlr.Attribute("rpy", "vector3", False, default=[0, 0, 0]),
+ ],
+)
# Common stuff
-name_attribute = xmlr.Attribute('name', str)
-origin_element = xmlr.Element('origin', Pose, False)
+name_attribute = xmlr.Attribute("name", str)
+origin_element = xmlr.Element("origin", Pose, False)
class Color(xmlr.Object):
def __init__(self, *args):
# What about named colors?
count = len(args)
- if count == 4 or count == 3:
+ if count in (4, 3):
self.rgba = args
elif count == 1:
self.rgba = args[0]
@@ -58,14 +68,12 @@ def __init__(self, *args):
self.rgba = None
if self.rgba is not None:
if len(self.rgba) == 3:
- self.rgba += [1.]
+ self.rgba += [1.0]
if len(self.rgba) != 4:
- raise Exception('Invalid color argument count')
+ raise Exception("Invalid color argument count")
-xmlr.reflect(Color, tag='color', params=[
- xmlr.Attribute('rgba', 'vector4')
-])
+xmlr.reflect(Color, tag="color", params=[xmlr.Attribute("rgba", "vector4")])
class JointDynamics(xmlr.Object):
@@ -74,10 +82,14 @@ def __init__(self, damping=None, friction=None):
self.friction = friction
-xmlr.reflect(JointDynamics, tag='dynamics', params=[
- xmlr.Attribute('damping', float, False),
- xmlr.Attribute('friction', float, False)
-])
+xmlr.reflect(
+ JointDynamics,
+ tag="dynamics",
+ params=[
+ xmlr.Attribute("damping", float, False),
+ xmlr.Attribute("friction", float, False),
+ ],
+)
class Box(xmlr.Object):
@@ -85,9 +97,7 @@ def __init__(self, size=None):
self.size = size
-xmlr.reflect(Box, tag='box', params=[
- xmlr.Attribute('size', 'vector3')
-])
+xmlr.reflect(Box, tag="box", params=[xmlr.Attribute("size", "vector3")])
class Cylinder(xmlr.Object):
@@ -96,29 +106,32 @@ def __init__(self, radius=0.0, length=0.0):
self.length = length
-xmlr.reflect(Cylinder, tag='cylinder', params=[
- xmlr.Attribute('radius', float),
- xmlr.Attribute('length', float)
-])
+xmlr.reflect(
+ Cylinder,
+ tag="cylinder",
+ params=[xmlr.Attribute("radius", float), xmlr.Attribute("length", float)],
+)
+
class Capsule(xmlr.Object):
def __init__(self, radius=0.0, length=0.0):
self.radius = radius
self.length = length
-xmlr.reflect(Capsule, tag='capsule', params=[
- xmlr.Attribute('radius', float),
- xmlr.Attribute('length', float)
-])
+
+xmlr.reflect(
+ Capsule,
+ tag="capsule",
+ params=[xmlr.Attribute("radius", float), xmlr.Attribute("length", float)],
+)
+
class Sphere(xmlr.Object):
def __init__(self, radius=0.0):
self.radius = radius
-xmlr.reflect(Sphere, tag='sphere', params=[
- xmlr.Attribute('radius', float)
-])
+xmlr.reflect(Sphere, tag="sphere", params=[xmlr.Attribute("radius", float)])
class Mesh(xmlr.Object):
@@ -127,25 +140,32 @@ def __init__(self, filename=None, scale=None):
self.scale = scale
-xmlr.reflect(Mesh, tag='mesh', params=[
- xmlr.Attribute('filename', str),
- xmlr.Attribute('scale', 'vector3', required=False)
-])
+xmlr.reflect(
+ Mesh,
+ tag="mesh",
+ params=[
+ xmlr.Attribute("filename", str),
+ xmlr.Attribute("scale", "vector3", required=False),
+ ],
+)
class GeometricType(xmlr.ValueType):
def __init__(self):
- self.factory = xmlr.FactoryType('geometric', {
- 'box': Box,
- 'cylinder': Cylinder,
- 'sphere': Sphere,
- 'mesh': Mesh,
- 'capsule': Capsule
- })
+ self.factory = xmlr.FactoryType(
+ "geometric",
+ {
+ "box": Box,
+ "cylinder": Cylinder,
+ "sphere": Sphere,
+ "mesh": Mesh,
+ "capsule": Capsule,
+ },
+ )
def from_xml(self, node, path):
children = xml_children(node)
- assert len(children) == 1, 'One element only for geometric'
+ assert len(children) == 1, "One element only for geometric"
return self.factory.from_xml(children[0], path=path)
def write_xml(self, node, obj):
@@ -154,7 +174,7 @@ def write_xml(self, node, obj):
obj.write_xml(child)
-xmlr.add_type('geometric', GeometricType())
+xmlr.add_type("geometric", GeometricType())
class Collision(xmlr.Object):
@@ -163,10 +183,11 @@ def __init__(self, geometry=None, origin=None):
self.origin = origin
-xmlr.reflect(Collision, tag='collision', params=[
- origin_element,
- xmlr.Element('geometry', 'geometric')
-])
+xmlr.reflect(
+ Collision,
+ tag="collision",
+ params=[origin_element, xmlr.Element("geometry", "geometric")],
+)
class Texture(xmlr.Object):
@@ -174,9 +195,7 @@ def __init__(self, filename=None):
self.filename = filename
-xmlr.reflect(Texture, tag='texture', params=[
- xmlr.Attribute('filename', str)
-])
+xmlr.reflect(Texture, tag="texture", params=[xmlr.Attribute("filename", str)])
class Material(xmlr.Object):
@@ -190,11 +209,15 @@ def check_valid(self):
xmlr.on_error("Material has neither a color nor texture.")
-xmlr.reflect(Material, tag='material', params=[
- name_attribute,
- xmlr.Element('color', Color, False),
- xmlr.Element('texture', Texture, False)
-])
+xmlr.reflect(
+ Material,
+ tag="material",
+ params=[
+ name_attribute,
+ xmlr.Element("color", Color, False),
+ xmlr.Element("texture", Texture, False),
+ ],
+)
class LinkMaterial(Material):
@@ -209,15 +232,19 @@ def __init__(self, geometry=None, material=None, origin=None):
self.origin = origin
-xmlr.reflect(Visual, tag='visual', params=[
- origin_element,
- xmlr.Element('geometry', 'geometric'),
- xmlr.Element('material', LinkMaterial, False)
-])
+xmlr.reflect(
+ Visual,
+ tag="visual",
+ params=[
+ origin_element,
+ xmlr.Element("geometry", "geometric"),
+ xmlr.Element("material", LinkMaterial, False),
+ ],
+)
class Inertia(xmlr.Object):
- KEYS = ['ixx', 'ixy', 'ixz', 'iyy', 'iyz', 'izz']
+ KEYS: ClassVar = ["ixx", "ixy", "ixz", "iyy", "iyz", "izz"]
def __init__(self, ixx=0.0, ixy=0.0, ixz=0.0, iyy=0.0, iyz=0.0, izz=0.0):
self.ixx = ixx
@@ -231,11 +258,11 @@ def to_matrix(self):
return [
[self.ixx, self.ixy, self.ixz],
[self.ixy, self.iyy, self.iyz],
- [self.ixz, self.iyz, self.izz]]
+ [self.ixz, self.iyz, self.izz],
+ ]
-xmlr.reflect(Inertia, tag='inertia',
- params=[xmlr.Attribute(key, float) for key in Inertia.KEYS])
+xmlr.reflect(Inertia, tag="inertia", params=[xmlr.Attribute(key, float) for key in Inertia.KEYS])
class Inertial(xmlr.Object):
@@ -245,11 +272,15 @@ def __init__(self, mass=0.0, inertia=None, origin=None):
self.origin = origin
-xmlr.reflect(Inertial, tag='inertial', params=[
- origin_element,
- xmlr.Element('mass', 'element_value'),
- xmlr.Element('inertia', Inertia, False)
-])
+xmlr.reflect(
+ Inertial,
+ tag="inertial",
+ params=[
+ origin_element,
+ xmlr.Element("mass", "element_value"),
+ xmlr.Element("inertia", Inertia, False),
+ ],
+)
# FIXME: we are missing the reference position here.
@@ -259,10 +290,14 @@ def __init__(self, rising=None, falling=None):
self.falling = falling
-xmlr.reflect(JointCalibration, tag='calibration', params=[
- xmlr.Attribute('rising', float, False, 0),
- xmlr.Attribute('falling', float, False, 0)
-])
+xmlr.reflect(
+ JointCalibration,
+ tag="calibration",
+ params=[
+ xmlr.Attribute("rising", float, False, 0),
+ xmlr.Attribute("falling", float, False, 0),
+ ],
+)
class JointLimit(xmlr.Object):
@@ -273,12 +308,16 @@ def __init__(self, effort=None, velocity=None, lower=None, upper=None):
self.upper = upper
-xmlr.reflect(JointLimit, tag='limit', params=[
- xmlr.Attribute('effort', float),
- xmlr.Attribute('lower', float, False, 0),
- xmlr.Attribute('upper', float, False, 0),
- xmlr.Attribute('velocity', float)
-])
+xmlr.reflect(
+ JointLimit,
+ tag="limit",
+ params=[
+ xmlr.Attribute("effort", float),
+ xmlr.Attribute("lower", float, False, 0),
+ xmlr.Attribute("upper", float, False, 0),
+ xmlr.Attribute("velocity", float),
+ ],
+)
# FIXME: we are missing __str__ here.
@@ -291,11 +330,15 @@ def __init__(self, joint_name=None, multiplier=None, offset=None):
self.offset = offset
-xmlr.reflect(JointMimic, tag='mimic', params=[
- xmlr.Attribute('joint', str),
- xmlr.Attribute('multiplier', float, False),
- xmlr.Attribute('offset', float, False)
-])
+xmlr.reflect(
+ JointMimic,
+ tag="mimic",
+ params=[
+ xmlr.Attribute("joint", str),
+ xmlr.Attribute("multiplier", float, False),
+ xmlr.Attribute("offset", float, False),
+ ],
+)
class SafetyController(xmlr.Object):
@@ -306,22 +349,43 @@ def __init__(self, velocity=None, position=None, lower=None, upper=None):
self.soft_upper_limit = upper
-xmlr.reflect(SafetyController, tag='safety_controller', params=[
- xmlr.Attribute('k_velocity', float),
- xmlr.Attribute('k_position', float, False, 0),
- xmlr.Attribute('soft_lower_limit', float, False, 0),
- xmlr.Attribute('soft_upper_limit', float, False, 0)
-])
+xmlr.reflect(
+ SafetyController,
+ tag="safety_controller",
+ params=[
+ xmlr.Attribute("k_velocity", float),
+ xmlr.Attribute("k_position", float, False, 0),
+ xmlr.Attribute("soft_lower_limit", float, False, 0),
+ xmlr.Attribute("soft_upper_limit", float, False, 0),
+ ],
+)
class Joint(xmlr.Object):
- TYPES = ['unknown', 'revolute', 'continuous', 'prismatic',
- 'floating', 'planar', 'fixed']
-
- def __init__(self, name=None, parent=None, child=None, joint_type=None,
- axis=None, origin=None,
- limit=None, dynamics=None, safety_controller=None,
- calibration=None, mimic=None):
+ TYPES: ClassVar = [
+ "unknown",
+ "revolute",
+ "continuous",
+ "prismatic",
+ "floating",
+ "planar",
+ "fixed",
+ ]
+
+ def __init__(
+ self,
+ name=None,
+ parent=None,
+ child=None,
+ joint_type=None,
+ axis=None,
+ origin=None,
+ limit=None,
+ dynamics=None,
+ safety_controller=None,
+ calibration=None,
+ mimic=None,
+ ):
self.name = name
self.parent = parent
self.child = child
@@ -335,34 +399,39 @@ def __init__(self, name=None, parent=None, child=None, joint_type=None,
self.mimic = mimic
def check_valid(self):
- assert self.type in self.TYPES, "Invalid joint type: {}".format(self.type) # noqa
+ assert self.type in self.TYPES, f"Invalid joint type: {self.type}"
# Aliases
@property
- def joint_type(self): return self.type
+ def joint_type(self):
+ return self.type
@joint_type.setter
- def joint_type(self, value): self.type = value
-
-
-xmlr.reflect(Joint, tag='joint', params=[
- name_attribute,
- xmlr.Attribute('type', str),
- origin_element,
- xmlr.Element('axis', 'element_xyz', False),
- xmlr.Element('parent', 'element_link'),
- xmlr.Element('child', 'element_link'),
- xmlr.Element('limit', JointLimit, False),
- xmlr.Element('dynamics', JointDynamics, False),
- xmlr.Element('safety_controller', SafetyController, False),
- xmlr.Element('calibration', JointCalibration, False),
- xmlr.Element('mimic', JointMimic, False),
-])
+ def joint_type(self, value):
+ self.type = value
+
+
+xmlr.reflect(
+ Joint,
+ tag="joint",
+ params=[
+ name_attribute,
+ xmlr.Attribute("type", str),
+ origin_element,
+ xmlr.Element("axis", "element_xyz", False),
+ xmlr.Element("parent", "element_link"),
+ xmlr.Element("child", "element_link"),
+ xmlr.Element("limit", JointLimit, False),
+ xmlr.Element("dynamics", JointDynamics, False),
+ xmlr.Element("safety_controller", SafetyController, False),
+ xmlr.Element("calibration", JointCalibration, False),
+ xmlr.Element("mimic", JointMimic, False),
+ ],
+)
class Link(xmlr.Object):
- def __init__(self, name=None, visual=None, inertial=None, collision=None,
- origin=None):
+ def __init__(self, name=None, visual=None, inertial=None, collision=None, origin=None):
self.aggregate_init()
self.name = name
self.visuals = []
@@ -399,18 +468,21 @@ def __set_collision(self, collision):
collision = property(__get_collision, __set_collision)
-xmlr.reflect(Link, tag='link', params=[
- name_attribute,
- origin_element,
- xmlr.AggregateElement('visual', Visual),
- xmlr.AggregateElement('collision', Collision),
- xmlr.Element('inertial', Inertial, False),
-])
+xmlr.reflect(
+ Link,
+ tag="link",
+ params=[
+ name_attribute,
+ origin_element,
+ xmlr.AggregateElement("visual", Visual),
+ xmlr.AggregateElement("collision", Collision),
+ xmlr.Element("inertial", Inertial, False),
+ ],
+)
class PR2Transmission(xmlr.Object):
- def __init__(self, name=None, joint=None, actuator=None, type=None,
- mechanicalReduction=1):
+ def __init__(self, name=None, joint=None, actuator=None, type=None, mechanicalReduction=1):
self.name = name
self.type = type
self.joint = joint
@@ -418,13 +490,17 @@ def __init__(self, name=None, joint=None, actuator=None, type=None,
self.mechanicalReduction = mechanicalReduction
-xmlr.reflect(PR2Transmission, tag='pr2_transmission', params=[
- name_attribute,
- xmlr.Attribute('type', str),
- xmlr.Element('joint', 'element_name'),
- xmlr.Element('actuator', 'element_name'),
- xmlr.Element('mechanicalReduction', float)
-])
+xmlr.reflect(
+ PR2Transmission,
+ tag="pr2_transmission",
+ params=[
+ name_attribute,
+ xmlr.Attribute("type", str),
+ xmlr.Element("joint", "element_name"),
+ xmlr.Element("actuator", "element_name"),
+ xmlr.Element("mechanicalReduction", float),
+ ],
+)
class Actuator(xmlr.Object):
@@ -433,10 +509,11 @@ def __init__(self, name=None, mechanicalReduction=1):
self.mechanicalReduction = None
-xmlr.reflect(Actuator, tag='actuator', params=[
- name_attribute,
- xmlr.Element('mechanicalReduction', float, required=False)
-])
+xmlr.reflect(
+ Actuator,
+ tag="actuator",
+ params=[name_attribute, xmlr.Element("mechanicalReduction", float, required=False)],
+)
class TransmissionJoint(xmlr.Object):
@@ -449,14 +526,18 @@ def check_valid(self):
assert len(self.hardwareInterfaces) > 0, "no hardwareInterface defined"
-xmlr.reflect(TransmissionJoint, tag='joint', params=[
- name_attribute,
- xmlr.AggregateElement('hardwareInterface', str),
-])
+xmlr.reflect(
+ TransmissionJoint,
+ tag="joint",
+ params=[
+ name_attribute,
+ xmlr.AggregateElement("hardwareInterface", str),
+ ],
+)
class Transmission(xmlr.Object):
- """ New format: http://wiki.ros.org/urdf/XML/Transmission """
+ """New format: http://wiki.ros.org/urdf/XML/Transmission"""
def __init__(self, name=None):
self.aggregate_init()
@@ -469,16 +550,21 @@ def check_valid(self):
assert len(self.actuators) > 0, "no actuator defined"
-xmlr.reflect(Transmission, tag='new_transmission', params=[
- name_attribute,
- xmlr.Element('type', str),
- xmlr.AggregateElement('joint', TransmissionJoint),
- xmlr.AggregateElement('actuator', Actuator)
-])
+xmlr.reflect(
+ Transmission,
+ tag="new_transmission",
+ params=[
+ name_attribute,
+ xmlr.Element("type", str),
+ xmlr.AggregateElement("joint", TransmissionJoint),
+ xmlr.AggregateElement("actuator", Actuator),
+ ],
+)
-xmlr.add_type('transmission',
- xmlr.DuckTypedFactory('transmission',
- [Transmission, PR2Transmission]))
+xmlr.add_type(
+ "transmission",
+ xmlr.DuckTypedFactory("transmission", [Transmission, PR2Transmission]),
+)
class Robot(xmlr.Object):
@@ -501,7 +587,7 @@ def __init__(self, name=None):
def add_aggregate(self, typeName, elem):
xmlr.Object.add_aggregate(self, typeName, elem)
- if typeName == 'joint':
+ if typeName == "joint":
joint = elem
self.joint_map[joint.name] = joint
self.parent_map[joint.child] = (joint.name, joint.parent)
@@ -509,15 +595,15 @@ def add_aggregate(self, typeName, elem):
self.child_map[joint.parent].append((joint.name, joint.child))
else:
self.child_map[joint.parent] = [(joint.name, joint.child)]
- elif typeName == 'link':
+ elif typeName == "link":
link = elem
self.link_map[link.name] = link
def add_link(self, link):
- self.add_aggregate('link', link)
+ self.add_aggregate("link", link)
def add_joint(self, joint):
- self.add_aggregate('joint', joint)
+ self.add_aggregate("joint", joint)
def get_chain(self, root, tip, joints=True, links=True, fixed=True):
chain = []
@@ -526,9 +612,8 @@ def get_chain(self, root, tip, joints=True, links=True, fixed=True):
link = tip
while link != root:
(joint, parent) = self.parent_map[link]
- if joints:
- if fixed or self.joint_map[joint].joint_type != 'fixed':
- chain.append(joint)
+ if joints and fixed or self.joint_map[joint].joint_type != "fixed":
+ chain.append(joint)
if links:
chain.append(parent)
link = parent
@@ -545,7 +630,7 @@ def get_root(self):
return root
@classmethod
- def from_parameter_server(cls, key='robot_description'):
+ def from_parameter_server(cls, key="robot_description"):
"""
Retrieve the robot model on the parameter server
and parse it to create a URDF robot structure.
@@ -554,17 +639,22 @@ def from_parameter_server(cls, key='robot_description'):
"""
# Could move this into xml_reflection
import rospy
+
return cls.from_xml_string(rospy.get_param(key))
-xmlr.reflect(Robot, tag='robot', params=[
- xmlr.Attribute('name', str, False), # Is 'name' a required attribute?
- xmlr.AggregateElement('link', Link),
- xmlr.AggregateElement('joint', Joint),
- xmlr.AggregateElement('gazebo', xmlr.RawType()),
- xmlr.AggregateElement('transmission', 'transmission'),
- xmlr.AggregateElement('material', Material)
-])
+xmlr.reflect(
+ Robot,
+ tag="robot",
+ params=[
+ xmlr.Attribute("name", str, False), # Is 'name' a required attribute?
+ xmlr.AggregateElement("link", Link),
+ xmlr.AggregateElement("joint", Joint),
+ xmlr.AggregateElement("gazebo", xmlr.RawType()),
+ xmlr.AggregateElement("transmission", "transmission"),
+ xmlr.AggregateElement("material", Material),
+ ],
+)
# Make an alias
URDF = Robot
diff --git a/src/pytorch_kinematics/urdf_parser_py/xml_reflection/__init__.py b/src/pytorch_kinematics/urdf_parser_py/xml_reflection/__init__.py
index bb67a43..ac5de24 100644
--- a/src/pytorch_kinematics/urdf_parser_py/xml_reflection/__init__.py
+++ b/src/pytorch_kinematics/urdf_parser_py/xml_reflection/__init__.py
@@ -1 +1,56 @@
-from .core import *
+from .core import (
+ AggregateElement,
+ Attribute,
+ BasicType,
+ DuckTypedFactory,
+ Element,
+ FactoryType,
+ Info,
+ ListType,
+ Object,
+ ObjectType,
+ Param,
+ ParseError,
+ Path,
+ RawType,
+ Reflection,
+ SimpleElementType,
+ ValueType,
+ VectorType,
+ add_type,
+ end_namespace,
+ get_type,
+ make_type,
+ on_error_stderr,
+ reflect,
+ start_namespace,
+)
+
+
+__all__ = [
+ "reflect",
+ "on_error_stderr",
+ "start_namespace",
+ "end_namespace",
+ "add_type",
+ "get_type",
+ "make_type",
+ "Path",
+ "ParseError",
+ "ValueType",
+ "BasicType",
+ "ListType",
+ "VectorType",
+ "RawType",
+ "SimpleElementType",
+ "ObjectType",
+ "FactoryType",
+ "DuckTypedFactory",
+ "Param",
+ "Attribute",
+ "Element",
+ "AggregateElement",
+ "Info",
+ "Reflection",
+ "Object",
+]
diff --git a/src/pytorch_kinematics/urdf_parser_py/xml_reflection/basics.py b/src/pytorch_kinematics/urdf_parser_py/xml_reflection/basics.py
index 60a0d44..17b3ae2 100644
--- a/src/pytorch_kinematics/urdf_parser_py/xml_reflection/basics.py
+++ b/src/pytorch_kinematics/urdf_parser_py/xml_reflection/basics.py
@@ -1,90 +1,94 @@
-import string
-import yaml
import collections
+from typing import Any, Dict, Iterable, List, Mapping, Optional, Union
+
+import yaml
from lxml import etree
-def xml_string(rootXml, addHeader=True):
+def xml_string(rootXml: etree._Element, addHeader: bool = True) -> str:
# Meh
- xmlString = etree.tostring(rootXml, pretty_print=True, encoding='unicode')
+ xmlString: str = etree.tostring(rootXml, pretty_print=True, encoding="unicode")
if addHeader:
xmlString = '\n' + xmlString
return xmlString
-def dict_sub(obj, keys):
- return dict((key, obj[key]) for key in keys)
+def dict_sub(obj: Mapping[str, Any], keys: Iterable[str]) -> Dict[str, Any]:
+ return {key: obj[key] for key in keys}
-def node_add(doc, sub):
+def node_add(
+ doc: etree._Element,
+ sub: Optional[Union[str, etree._Element]],
+) -> Optional[etree._Element]:
if sub is None:
return None
- if type(sub) == str:
+ if isinstance(sub, str):
return etree.SubElement(doc, sub)
elif isinstance(sub, etree._Element):
doc.append(sub) # This screws up the rest of the tree for prettyprint
return sub
else:
- raise Exception('Invalid sub value')
+ raise Exception("Invalid sub value")
-def pfloat(x):
- return str(x).rstrip('.')
+def pfloat(x: Any) -> str:
+ return str(x).rstrip(".")
-def xml_children(node):
- children = node.getchildren()
+def xml_children(node: etree._Element) -> List[etree._Element]:
+ children: List[etree._Element] = node.getchildren()
- def predicate(node):
- return not isinstance(node, etree._Comment)
+ def predicate(n: etree._Element) -> bool:
+ return not isinstance(n, etree._Comment)
return list(filter(predicate, children))
-def isstring(obj):
+def isstring(obj: Any) -> bool:
try:
- return isinstance(obj, basestring)
+ return isinstance(obj, basestring) # type: ignore[name-defined]
except NameError:
return isinstance(obj, str)
-def to_yaml(obj):
- """ Simplify yaml representation for pretty printing """
+def to_yaml(obj: Any) -> Any:
+ """Simplify yaml representation for pretty printing"""
# Is there a better way to do this by adding a representation with
# yaml.Dumper?
# Ordered dict: http://pyyaml.org/ticket/29#comment:11
if obj is None or isstring(obj):
- out = str(obj)
+ out: Any = str(obj)
elif type(obj) in [int, float, bool]:
return obj
- elif hasattr(obj, 'to_yaml'):
+ elif hasattr(obj, "to_yaml"):
out = obj.to_yaml()
elif isinstance(obj, etree._Element):
out = etree.tostring(obj, pretty_print=True)
- elif type(obj) == dict:
+ elif isinstance(obj, dict):
out = {}
- for (var, value) in obj.items():
+ for var, value in obj.items():
out[str(var)] = to_yaml(value)
- elif hasattr(obj, 'tolist'):
+ elif hasattr(obj, "tolist"):
# For numpy objects
out = to_yaml(obj.tolist())
- elif isinstance(obj, collections.Iterable):
+ elif isinstance(obj, collections.abc.Iterable):
out = [to_yaml(item) for item in obj]
else:
out = str(obj)
return out
-class SelectiveReflection(object):
- def get_refl_vars(self):
+class SelectiveReflection:
+ def get_refl_vars(self) -> List[str]:
return list(vars(self).keys())
class YamlReflection(SelectiveReflection):
- def to_yaml(self):
- raw = dict((var, getattr(self, var)) for var in self.get_refl_vars())
+ def to_yaml(self) -> Any:
+ raw = {var: getattr(self, var) for var in self.get_refl_vars()}
return to_yaml(raw)
- def __str__(self):
+ def __str__(self) -> str:
# Good idea? Will it remove other important things?
return yaml.dump(self.to_yaml()).rstrip()
diff --git a/src/pytorch_kinematics/urdf_parser_py/xml_reflection/core.py b/src/pytorch_kinematics/urdf_parser_py/xml_reflection/core.py
index 8659bde..ac26cb9 100644
--- a/src/pytorch_kinematics/urdf_parser_py/xml_reflection/core.py
+++ b/src/pytorch_kinematics/urdf_parser_py/xml_reflection/core.py
@@ -1,6 +1,15 @@
-from .basics import *
-import sys
import copy
+import sys
+from typing import Any, Callable, Dict, List, Optional, Type
+
+from lxml import etree
+
+from .basics import (
+ YamlReflection,
+ node_add,
+ xml_children,
+ xml_string,
+)
# @todo Get rid of "import *"
@@ -10,12 +19,12 @@
# Rename?
# Do parent operations after, to allow child to 'override' parameters?
-# Need to make sure that duplicate entires do not get into the 'unset*' lists
+# Need to make sure that duplicate entries do not get into the 'unset*' lists
def reflect(cls, *args, **kwargs):
"""
- Simple wrapper to add XML reflection to an xml_reflection.Object class
+ Simple wrapper to add XML reflection to an xml_reflection.Object class.
"""
cls.XML_REFL = Reflection(*args, **kwargs)
@@ -24,49 +33,50 @@ def reflect(cls, *args, **kwargs):
# 'pre_dump' and 'post_load'?
# When dumping to yaml, include tag name?
+
# How to incorporate line number and all that jazz?
def on_error_stderr(message):
- """ What to do on an error. This can be changed to raise an exception. """
- sys.stderr.write(message + '\n')
+ """What to do on an error. This can be changed to raise an exception."""
+ sys.stderr.write(message + "\n")
-on_error = on_error_stderr
+on_error: Callable[[str], None] = on_error_stderr
-skip_default = False
-# defaultIfMatching = True # Not implemeneted yet
+skip_default: bool = False
+# defaultIfMatching = True # Not implemented yet
# Registering Types
-value_types = {}
-value_type_prefix = ''
+value_types: Dict[Any, "ValueType"] = {}
+value_type_prefix: str = ""
-def start_namespace(namespace):
+def start_namespace(namespace: str) -> None:
"""
Basic mechanism to prevent conflicts for string types for URDF and SDF
@note Does not handle nesting!
"""
global value_type_prefix
- value_type_prefix = namespace + '.'
+ value_type_prefix = namespace + "."
-def end_namespace():
+def end_namespace() -> None:
global value_type_prefix
- value_type_prefix = ''
+ value_type_prefix = ""
-def add_type(key, value):
+def add_type(key: Any, value: "ValueType") -> None:
if isinstance(key, str):
key = value_type_prefix + key
assert key not in value_types
value_types[key] = value
-def get_type(cur_type):
- """ Can wrap value types if needed """
+def get_type(cur_type: Any) -> "ValueType":
+ """Can wrap value types if needed"""
if value_type_prefix and isinstance(cur_type, str):
# See if it exists in current 'namespace'
- curKey = value_type_prefix + cur_type
- value_type = value_types.get(curKey)
+ cur_key = value_type_prefix + cur_type
+ value_type = value_types.get(cur_key)
else:
value_type = None
if value_type is None:
@@ -78,19 +88,16 @@ def get_type(cur_type):
return value_type
-def make_type(cur_type):
+def make_type(cur_type: Any) -> "ValueType":
if isinstance(cur_type, ValueType):
return cur_type
elif isinstance(cur_type, str):
- if cur_type.startswith('vector'):
+ if cur_type.startswith("vector"):
extra = cur_type[6:]
- if extra:
- count = float(extra)
- else:
- count = None
+ count: Optional[float] = float(extra) if extra else None
return VectorType(count)
else:
- raise Exception("Invalid value type: {}".format(cur_type))
+ raise Exception(f"Invalid value type: {cur_type}")
elif cur_type == list:
return ListType()
elif issubclass(cur_type, Object):
@@ -98,10 +105,10 @@ def make_type(cur_type):
elif cur_type in [str, float, bool]:
return BasicType(cur_type)
else:
- raise Exception("Invalid type: {}".format(cur_type))
+ raise Exception(f"Invalid type: {cur_type}")
-class Path(object):
+class Path:
def __init__(self, tag, parent=None, suffix="", tree=None):
self.parent = parent
self.tag = tag
@@ -110,24 +117,23 @@ def __init__(self, tag, parent=None, suffix="", tree=None):
def __str__(self):
if self.parent is not None:
- return "{}/{}{}".format(self.parent, self.tag, self.suffix)
+ return f"{self.parent}/{self.tag}{self.suffix}"
+ elif self.tag is not None and len(self.tag) > 0:
+ return f"/{self.tag}{self.suffix}"
else:
- if self.tag is not None and len(self.tag) > 0:
- return "/{}{}".format(self.tag, self.suffix)
- else:
- return self.suffix
+ return self.suffix
class ParseError(Exception):
def __init__(self, e, path):
self.e = e
self.path = path
- message = "ParseError in {}:\n{}".format(self.path, self.e)
- super(ParseError, self).__init__(message)
+ message = f"ParseError in {self.path}:\n{self.e}"
+ super().__init__(message)
-class ValueType(object):
- """ Primitive value type """
+class ValueType:
+ """Primitive value type"""
def from_xml(self, node, path):
return self.from_string(node.text)
@@ -158,13 +164,13 @@ def from_string(self, value):
class ListType(ValueType):
def to_string(self, values):
- return ' '.join(values)
+ return " ".join(values)
def from_string(self, text):
return text.split()
def equals(self, aValues, bValues):
- return len(aValues) == len(bValues) and all(a == b for (a, b) in zip(aValues, bValues)) # noqa
+ return len(aValues) == len(bValues) and all(a == b for (a, b) in zip(aValues, bValues))
class VectorType(ListType):
@@ -200,7 +206,7 @@ def write_xml(self, node, value):
children = xml_children(value)
list(map(node.append, children))
# Copy attributes
- for (attrib_key, attrib_value) in value.attrib.items():
+ for attrib_key, attrib_value in value.attrib.items():
node.set(attrib_key, attrib_value)
@@ -241,14 +247,14 @@ def __init__(self, name, typeMap):
self.name = name
self.typeMap = typeMap
self.nameMap = {}
- for (key, value) in typeMap.items():
+ for key, value in typeMap.items():
# Reverse lookup
self.nameMap[value] = key
def from_xml(self, node, path):
cur_type = self.typeMap.get(node.tag)
if cur_type is None:
- raise Exception("Invalid {} tag: {}".format(self.name, node.tag))
+ raise Exception(f"Invalid {self.name} tag: {node.tag}")
value_type = get_type(cur_type)
return value_type.from_xml(node, path)
@@ -256,7 +262,7 @@ def get_name(self, obj):
cur_type = type(obj)
name = self.nameMap.get(cur_type)
if name is None:
- raise Exception("Invalid {} type: {}".format(self.name, cur_type))
+ raise Exception(f"Invalid {self.name} type: {cur_type}")
return name
def write_xml(self, node, obj):
@@ -278,16 +284,16 @@ def from_xml(self, node, path):
error_set.append((value_type, e))
# Should have returned, we encountered errors
out = "Could not perform duck-typed parsing."
- for (value_type, e) in error_set:
- out += "\nValue Type: {}\nException: {}\n".format(value_type, e)
+ for value_type, e in error_set:
+ out += f"\nValue Type: {value_type}\nException: {e}\n"
raise ParseError(Exception(out), path)
def write_xml(self, node, obj):
obj.write_xml(node)
-class Param(object):
- """ Mirroring Gazebo's SDF api
+class Param:
+ """Mirroring Gazebo's SDF api
@param xml_var: Xml name
@todo If the value_type is an object with a tag defined in it's
@@ -296,8 +302,7 @@ class Param(object):
XML name
"""
- def __init__(self, xml_var, value_type, required=True, default=None,
- var=None):
+ def __init__(self, xml_var, value_type, required=True, default=None, var=None):
self.xml_var = xml_var
if var is None:
self.var = xml_var
@@ -307,25 +312,24 @@ def __init__(self, xml_var, value_type, required=True, default=None,
self.value_type = get_type(value_type)
self.default = default
if required:
- assert default is None, "Default does not make sense for a required field" # noqa
+ assert default is None, "Default does not make sense for a required field"
self.required = required
self.is_aggregate = False
def set_default(self, obj):
if self.required:
- raise Exception("Required {} not set in XML: {}".format(self.type, self.xml_var)) # noqa
+ raise Exception(f"Required {self.type} not set in XML: {self.xml_var}")
elif not skip_default:
setattr(obj, self.var, self.default)
class Attribute(Param):
- def __init__(self, xml_var, value_type, required=True, default=None,
- var=None):
+ def __init__(self, xml_var, value_type, required=True, default=None, var=None):
Param.__init__(self, xml_var, value_type, required, default, var)
- self.type = 'attribute'
+ self.type = "attribute"
def set_from_string(self, obj, value):
- """ Node is the parent node in this case """
+ """Node is the parent node in this case"""
# Duplicate attributes cannot occur at this point
setattr(obj, self.var, self.value_type.from_string(value))
@@ -337,7 +341,7 @@ def add_to_xml(self, obj, node):
# Do not set with default value if value is None
if value is None:
if self.required:
- raise Exception("Required attribute not set in object: {}".format(self.var)) # noqa
+ raise Exception(f"Required attribute not set in object: {self.var}")
elif not skip_default:
value = self.default
# Allow value type to handle None?
@@ -351,10 +355,9 @@ def add_to_xml(self, obj, node):
class Element(Param):
- def __init__(self, xml_var, value_type, required=True, default=None,
- var=None, is_raw=False):
+ def __init__(self, xml_var, value_type, required=True, default=None, var=None, is_raw=False):
Param.__init__(self, xml_var, value_type, required, default, var)
- self.type = 'element'
+ self.type = "element"
self.is_raw = is_raw
def set_from_xml(self, obj, node, path):
@@ -365,26 +368,22 @@ def add_to_xml(self, obj, parent):
value = getattr(obj, self.xml_var)
if value is None:
if self.required:
- raise Exception("Required element not defined in object: {}".format(self.var)) # noqa
+ raise Exception(f"Required element not defined in object: {self.var}")
elif not skip_default:
value = self.default
if value is not None:
self.add_scalar_to_xml(parent, value)
def add_scalar_to_xml(self, parent, value):
- if self.is_raw:
- node = parent
- else:
- node = node_add(parent, self.xml_var)
+ node = parent if self.is_raw else node_add(parent, self.xml_var)
self.value_type.write_xml(node, value)
class AggregateElement(Element):
def __init__(self, xml_var, value_type, var=None, is_raw=False):
if var is None:
- var = xml_var + 's'
- Element.__init__(self, xml_var, value_type, required=False, var=var,
- is_raw=is_raw)
+ var = xml_var + "s"
+ Element.__init__(self, xml_var, value_type, required=False, var=var, is_raw=is_raw)
self.is_aggregate = True
def add_from_xml(self, obj, node, path):
@@ -396,21 +395,23 @@ def set_default(self, obj):
class Info:
- """ Small container for keeping track of what's been consumed """
+ """Small container for keeping track of what's been consumed"""
def __init__(self, node):
self.attributes = list(node.attrib.keys())
self.children = xml_children(node)
-class Reflection(object):
- def __init__(self, params=[], parent_cls=None, tag=None):
- """ Construct a XML reflection thing
+class Reflection:
+ def __init__(self, params=None, parent_cls=None, tag=None):
+ """Construct a XML reflection thing
@param parent_cls: Parent class, to use it's reflection as well.
@param tag: Only necessary if you intend to use Object.write_xml_doc()
This does not override the name supplied in the reflection
definition thing.
"""
+ if params is None:
+ params = []
if parent_cls is not None:
self.parent = parent_cls.XML_REFL
else:
@@ -472,16 +473,16 @@ def set_from_xml(self, obj, node, path, info=None):
def get_attr_path(attribute):
attr_path = copy.copy(path)
- attr_path.suffix += '[@{}]'.format(attribute.xml_var)
+ attr_path.suffix += f"[@{attribute.xml_var}]"
return attr_path
def get_element_path(element):
element_path = Path(element.xml_var, parent=path)
- # Add an index (allow this to be overriden)
+ # Add an index (allow this to be overridden)
if element.is_aggregate:
values = obj.get_aggregate_list(element.xml_var)
index = 1 + len(values) # 1-based indexing for W3C XPath
- element_path.suffix = "[{}]".format(index)
+ element_path.suffix = f"[{index}]"
return element_path
id_var = "name"
@@ -495,7 +496,7 @@ def get_element_path(element):
attribute.set_from_string(obj, value)
if attribute.xml_var == id_var:
# Add id_var suffix to current path (do not copy so it propagates)
- path.suffix = "[@{}='{}']".format(id_var, attribute.get_value(obj))
+ path.suffix = f"[@{id_var}='{attribute.get_value(obj)}']"
except ParseError:
raise
except Exception as e:
@@ -512,12 +513,11 @@ def get_element_path(element):
element_path = get_element_path(element)
if element.is_aggregate:
element.add_from_xml(obj, child, element_path)
+ elif tag in unset_scalars:
+ element.set_from_xml(obj, child, element_path)
+ unset_scalars.remove(tag)
else:
- if tag in unset_scalars:
- element.set_from_xml(obj, child, element_path)
- unset_scalars.remove(tag)
- else:
- on_error("Scalar element defined multiple times: {}".format(tag)) # noqa
+ on_error(f"Scalar element defined multiple times: {tag}")
info.children.remove(child)
# For unset attributes and scalar elements, we should not pass the attribute
@@ -542,9 +542,9 @@ def get_element_path(element):
if is_final:
for xml_var in info.attributes:
- on_error('Unknown attribute "{}" in {}'.format(xml_var, path))
+ on_error(f'Unknown attribute "{xml_var}" in {path}')
for node in info.children:
- on_error('Unknown tag "{}" in {}'.format(node.tag, path))
+ on_error(f'Unknown tag "{node.tag}" in {path}')
# Allow children parsers to adopt this current path (if modified with id_var)
return path
@@ -561,7 +561,8 @@ def add_to_xml(self, obj, node):
class Object(YamlReflection):
- """ Raw python object for yaml / xml representation """
+ """Raw python object for yaml / xml representation"""
+
XML_REFL = None
def get_refl_vars(self):
@@ -571,20 +572,19 @@ def check_valid(self):
pass
def pre_write_xml(self):
- """ If anything needs to be converted prior to dumping to xml
- i.e., getting the names of objects and such """
- pass
+ """If anything needs to be converted prior to dumping to xml
+ i.e., getting the names of objects and such"""
def write_xml(self, node):
- """ Adds contents directly to XML node """
+ """Adds contents directly to XML node"""
self.check_valid()
self.pre_write_xml()
self.XML_REFL.add_to_xml(self, node)
def to_xml(self):
- """ Creates an overarching tag and adds its contents to the node """
+ """Creates an overarching tag and adds its contents to the node"""
tag = self.XML_REFL.tag
- assert tag is not None, "Must define 'tag' in reflection to use this function" # noqa
+ assert tag is not None, "Must define 'tag' in reflection to use this function"
doc = etree.Element(tag)
self.write_xml(doc)
return doc
@@ -606,66 +606,58 @@ def read_xml(self, node, path):
raise ParseError(e, path)
@classmethod
- def from_xml(cls, node, path):
+ def from_xml(cls: Type["Object"], node: etree._Element, path: Path) -> "Object":
cur_type = get_type(cls)
return cur_type.from_xml(node, path)
@classmethod
- def from_xml_string(cls, xml_string):
- node = etree.fromstring(xml_string)
+ def from_xml_string(cls: Type["Object"], xml_str: str) -> "Object":
+ node = etree.fromstring(xml_str)
path = Path(cls.XML_REFL.tag, tree=etree.ElementTree(node))
return cls.from_xml(node, path)
@classmethod
- def from_xml_file(cls, file_path):
- xml_string = open(file_path, 'r').read()
- return cls.from_xml_string(xml_string)
-
- # Confusing distinction between loading code in object and reflection
- # registry thing...
+ def from_xml_file(cls: Type["Object"], file_path: str) -> "Object":
+ with open(file_path) as fp:
+ return cls.from_xml_string(fp.read())
- def get_aggregate_list(self, xml_var):
+ def get_aggregate_list(self, xml_var: str) -> List[Any]:
var = self.XML_REFL.paramMap[xml_var].var
values = getattr(self, var)
assert isinstance(values, list)
return values
- def aggregate_init(self):
- """ Must be called in constructor! """
- self.aggregate_order = []
- # Store this info in the loaded object??? Nah
- self.aggregate_type = {}
+ def aggregate_init(self) -> None:
+ """Must be called in constructor!"""
+ self.aggregate_order: List[Any] = []
+ self.aggregate_type: Dict[Any, str] = {}
- def add_aggregate(self, xml_var, obj):
- """ NOTE: One must keep careful track of aggregate types for this system.
- Can use 'lump_aggregates()' before writing if you don't care. """
+ def add_aggregate(self, xml_var: str, obj: Any) -> None:
self.get_aggregate_list(xml_var).append(obj)
self.aggregate_order.append(obj)
self.aggregate_type[obj] = xml_var
- def add_aggregates_to_xml(self, node):
+ def add_aggregates_to_xml(self, node: etree._Element) -> None:
for value in self.aggregate_order:
typeName = self.aggregate_type[value]
element = self.XML_REFL.element_map[typeName]
element.add_scalar_to_xml(node, value)
- def remove_aggregate(self, obj):
+ def remove_aggregate(self, obj: Any) -> None:
self.aggregate_order.remove(obj)
xml_var = self.aggregate_type[obj]
del self.aggregate_type[obj]
self.get_aggregate_list(xml_var).remove(obj)
- def lump_aggregates(self):
- """ Put all aggregate types together, just because """
+ def lump_aggregates(self) -> None:
+ """Put all aggregate types together, just because."""
self.aggregate_init()
for param in self.XML_REFL.aggregates:
for obj in self.get_aggregate_list(param.xml_var):
self.add_aggregate(param.var, obj)
- """ Compatibility """
-
- def parse(self, xml_string):
- node = etree.fromstring(xml_string)
+ def parse(self, xml_str: str) -> "Object":
+ node = etree.fromstring(xml_str)
path = Path(self.XML_REFL.tag, tree=etree.ElementTree(node))
self.read_xml(node, path)
return self
@@ -673,10 +665,10 @@ def parse(self, xml_string):
# Really common types
# Better name: element_with_name? Attributed element?
-add_type('element_name', SimpleElementType('name', str))
-add_type('element_value', SimpleElementType('value', float))
+add_type("element_name", SimpleElementType("name", str))
+add_type("element_value", SimpleElementType("value", float))
# Add in common vector types so they aren't absorbed into the namespaces
-get_type('vector3')
-get_type('vector4')
-get_type('vector6')
+get_type("vector3")
+get_type("vector4")
+get_type("vector6")
diff --git a/tests/gen_fk_perf.py b/tests/gen_fk_perf.py
index 334abad..8d9eaa9 100644
--- a/tests/gen_fk_perf.py
+++ b/tests/gen_fk_perf.py
@@ -1,10 +1,10 @@
""" Generate performance data for multiple models, devices, data types, batch sizes, etc. """
import timeit
-from time import perf_counter
+
+import numpy as np
import torch
import pytorch_kinematics as pk
-import numpy as np
def main():
@@ -12,18 +12,18 @@ def main():
torch.set_printoptions(precision=3, sci_mode=False, linewidth=220)
chains = {
- 'val': pk.build_chain_from_mjcf(open('val.xml').read()),
- 'val_serial': pk.build_serial_chain_from_mjcf(open('val.xml').read(), end_link_name='left_tool'),
- 'kuka_iiwa': pk.build_serial_chain_from_urdf(open('kuka_iiwa.urdf').read(), end_link_name='lbr_iiwa_link_7'),
+ "val": pk.build_chain_from_mjcf(open("val.xml").read()),
+ "val_serial": pk.build_serial_chain_from_mjcf(open("val.xml").read(), end_link_name="left_tool"),
+ "kuka_iiwa": pk.build_serial_chain_from_urdf(open("kuka_iiwa.urdf").read(), end_link_name="lbr_iiwa_link_7"),
}
- devices = ['cpu', 'cuda']
+ devices = ["cpu", "cuda"]
dtypes = [torch.float32, torch.float64]
batch_sizes = [1, 10, 100, 1_000, 10_000, 100_000]
number = 100
# iterate over all combinations and store in a pandas dataframe
- headers = ['method', 'chain', 'device', 'dtype', 'batch_size', 'time']
+ headers = ["method", "chain", "device", "dtype", "batch_size", "time"]
data = []
def _fk(th):
@@ -42,9 +42,10 @@ def _fk(th):
# pickle the data for visualization in jupyter notebook
import pickle
- with open('fk_perf.pkl', 'wb') as f:
+
+ with open("fk_perf.pkl", "wb") as f:
pickle.dump([headers, data], f)
-if __name__ == '__main__':
+if __name__ == "__main__":
main()
diff --git a/tests/hopper.xml b/tests/hopper.xml
index 3c98b55..4dfc44c 100644
--- a/tests/hopper.xml
+++ b/tests/hopper.xml
@@ -38,4 +38,4 @@
-
\ No newline at end of file
+
diff --git a/tests/humanoid.xml b/tests/humanoid.xml
index 5c50f7d..9dcc726 100644
--- a/tests/humanoid.xml
+++ b/tests/humanoid.xml
@@ -18,12 +18,12 @@
-
-
-
+
+
+
@@ -35,7 +35,7 @@
-
+
diff --git a/tests/joint_limit_robot.urdf b/tests/joint_limit_robot.urdf
index ee3b989..9df0e43 100644
--- a/tests/joint_limit_robot.urdf
+++ b/tests/joint_limit_robot.urdf
@@ -1,6 +1,6 @@
-
@@ -284,6 +284,5 @@
-
-
+
diff --git a/tests/prismatic_robot.urdf b/tests/prismatic_robot.urdf
index 1eb0537..f482d8c 100644
--- a/tests/prismatic_robot.urdf
+++ b/tests/prismatic_robot.urdf
@@ -1,4 +1,4 @@
-
+
diff --git a/tests/simple_y_arm.urdf b/tests/simple_y_arm.urdf
index a62edf3..247f18a 100644
--- a/tests/simple_y_arm.urdf
+++ b/tests/simple_y_arm.urdf
@@ -36,4 +36,4 @@
-
\ No newline at end of file
+
diff --git a/tests/test_attributes.py b/tests/test_attributes.py
index d6b3a22..ea41d8e 100644
--- a/tests/test_attributes.py
+++ b/tests/test_attributes.py
@@ -2,8 +2,10 @@
import pytorch_kinematics as pk
+
TEST_DIR = os.path.dirname(__file__)
+
def test_limits():
chain = pk.build_serial_chain_from_urdf(open(os.path.join(TEST_DIR, "kuka_iiwa.urdf")).read(), "lbr_iiwa_link_7")
iiwa_low_individual = []
@@ -35,7 +37,7 @@ def test_limits():
assert low == [x + 8 for x in nums]
assert high == [x + 9 for x in nums]
assert v_low == [-x for x in nums]
- assert v_high == [x for x in nums]
+ assert v_high == list(nums)
assert e_low == [-(x + 4) for x in nums]
assert e_high == [x + 4 for x in nums]
@@ -60,7 +62,7 @@ def test_empty_limits():
assert low == [0] * len(nums)
assert high == [0] * len(nums)
assert v_low == [-x for x in nums]
- assert v_high == [x for x in nums]
+ assert v_high == list(nums)
assert e_low == [-(x + 4) for x in nums]
assert e_high == [x + 4 for x in nums]
diff --git a/tests/test_inverse_kinematics.py b/tests/test_inverse_kinematics.py
index 2659f16..a90ca8d 100644
--- a/tests/test_inverse_kinematics.py
+++ b/tests/test_inverse_kinematics.py
@@ -2,13 +2,13 @@
from timeit import default_timer as timer
import numpy as np
+import pybullet as p
+import pybullet_data
+import pytorch_seed
import torch
import pytorch_kinematics as pk
-import pytorch_seed
-import pybullet as p
-import pybullet_data
visualize = False
@@ -24,6 +24,7 @@ def make_transparent(link):
for link in visual_data:
make_transparent(link)
+
def create_test_chain(robot="kuka_iiwa", device="cpu"):
if robot == "kuka_iiwa":
urdf = "kuka_iiwa/model.urdf"
@@ -40,6 +41,7 @@ def create_test_chain(robot="kuka_iiwa", device="cpu"):
raise NotImplementedError(f"Robot {robot} not implemented")
return chain, urdf
+
def test_jacobian_follower(robot="kuka_iiwa"):
pytorch_seed.seed(2)
device = "cuda" if torch.cuda.is_available() else "cpu"
@@ -68,13 +70,17 @@ def test_jacobian_follower(robot="kuka_iiwa"):
goal_rot = pk.matrix_to_euler_angles(goal[..., :3, :3], "XYZ")
num_retries = 10
- ik = pk.PseudoInverseIK(chain, max_iterations=30, num_retries=num_retries,
- joint_limits=lim.T,
- early_stopping_any_converged=True,
- early_stopping_no_improvement="all",
- # line_search=pk.BacktrackingLineSearch(max_lr=0.2),
- debug=False,
- lr=0.2)
+ ik = pk.PseudoInverseIK(
+ chain,
+ max_iterations=30,
+ num_retries=num_retries,
+ joint_limits=lim.T,
+ early_stopping_any_converged=True,
+ early_stopping_no_improvement="all",
+ # line_search=pk.BacktrackingLineSearch(max_lr=0.2),
+ debug=False,
+ lr=0.2,
+ )
# do IK
timer_start = timer()
@@ -101,7 +107,7 @@ def test_jacobian_follower(robot="kuka_iiwa"):
pitch = -65
# dist = 1.
dist = 2.4
- target = np.array([2., 1.5, 0])
+ target = np.array([2.0, 1.5, 0])
p.resetDebugVisualizerCamera(dist, yaw, pitch, target)
plane_id = p.loadURDF("plane.urdf", [0, 0, 0], useFixedBase=True)
@@ -121,7 +127,12 @@ def test_jacobian_follower(robot="kuka_iiwa"):
for i in range(num_robots):
this_offset = np.array([i % 4 * offset, i // 4 * offset, 0])
- armId = p.loadURDF(urdf, basePosition=pos + this_offset, baseOrientation=rot, useFixedBase=True)
+ armId = p.loadURDF(
+ urdf,
+ basePosition=pos + this_offset,
+ baseOrientation=rot,
+ useFixedBase=True,
+ )
# _make_robot_translucent(armId, alpha=0.6)
robots.append({"id": armId, "offset": this_offset, "pos": pos})
@@ -129,13 +140,18 @@ def test_jacobian_follower(robot="kuka_iiwa"):
goals = []
# draw cone to indicate pose instead of sphere
- visId = p.createVisualShape(p.GEOM_MESH, fileName="meshes/cone.obj", meshScale=1.0,
- rgbaColor=[0., 1., 0., 0.5])
- for i in range(num_robots):
+ visId = p.createVisualShape(
+ p.GEOM_MESH,
+ fileName="meshes/cone.obj",
+ meshScale=1.0,
+ rgbaColor=[0.0, 1.0, 0.0, 0.5],
+ )
+ for _ in range(num_robots):
goals.append(p.createMultiBody(baseMass=0, baseVisualShapeIndex=visId))
try:
import window_recorder
+
with window_recorder.WindowRecorder(save_dir="."):
# batch over goals with num_robots
for j in range(0, M, num_robots):
@@ -154,9 +170,11 @@ def test_jacobian_follower(robot="kuka_iiwa"):
if ii > show_max_num_retries_per_goal:
break
for jj in range(num_robots):
- p.resetBasePositionAndOrientation(goals[jj],
- goal_pos[j + jj].cpu().numpy() + robots[jj]["offset"],
- xyzw[jj].cpu().numpy())
+ p.resetBasePositionAndOrientation(
+ goals[jj],
+ goal_pos[j + jj].cpu().numpy() + robots[jj]["offset"],
+ xyzw[jj].cpu().numpy(),
+ )
armId = robots[jj]["id"]
q = solutions[jj, ii, :]
for dof in range(q.shape[0]):
@@ -192,17 +210,21 @@ def test_ik_in_place_no_err(robot="kuka_iiwa"):
# transform to world frame for visualization
goal_tf = rob_tf.compose(goal_in_rob_frame_tf)
goal = goal_tf.get_matrix()
- goal_pos = goal[..., :3, 3]
- goal_rot = pk.matrix_to_euler_angles(goal[..., :3, :3], "XYZ")
-
- ik = pk.PseudoInverseIK(chain, max_iterations=30, num_retries=10,
- joint_limits=lim.T,
- early_stopping_any_converged=True,
- early_stopping_no_improvement="all",
- retry_configs=cur_q.reshape(1, -1),
- # line_search=pk.BacktrackingLineSearch(max_lr=0.2),
- debug=False,
- lr=0.2)
+ goal[..., :3, 3]
+ pk.matrix_to_euler_angles(goal[..., :3, :3], "XYZ")
+
+ ik = pk.PseudoInverseIK(
+ chain,
+ max_iterations=30,
+ num_retries=10,
+ joint_limits=lim.T,
+ early_stopping_any_converged=True,
+ early_stopping_no_improvement="all",
+ retry_configs=cur_q.reshape(1, -1),
+ # line_search=pk.BacktrackingLineSearch(max_lr=0.2),
+ debug=False,
+ lr=0.2,
+ )
# do IK
sol = ik.solve(goal_in_rob_frame_tf)
@@ -212,12 +234,10 @@ def test_ik_in_place_no_err(robot="kuka_iiwa"):
assert torch.allclose(sol.err_rot[0], torch.zeros(1, device=device), atol=1e-6)
-
-
if __name__ == "__main__":
print("Testing kuka_iiwa IK")
test_jacobian_follower(robot="kuka_iiwa")
test_ik_in_place_no_err(robot="kuka_iiwa")
print("Testing widowx IK")
test_jacobian_follower(robot="widowx")
- test_ik_in_place_no_err(robot="widowx")
\ No newline at end of file
+ test_ik_in_place_no_err(robot="widowx")
diff --git a/tests/test_jacobian.py b/tests/test_jacobian.py
index 960c722..5b8ac11 100644
--- a/tests/test_jacobian.py
+++ b/tests/test_jacobian.py
@@ -6,58 +6,84 @@
import pytorch_kinematics as pk
+
TEST_DIR = os.path.dirname(__file__)
def test_correctness():
- chain = pk.build_serial_chain_from_urdf(open(os.path.join(TEST_DIR, "kuka_iiwa.urdf")).read(),
- "lbr_iiwa_link_7")
+ chain = pk.build_serial_chain_from_urdf(open(os.path.join(TEST_DIR, "kuka_iiwa.urdf")).read(), "lbr_iiwa_link_7")
th = torch.tensor([0.0, -math.pi / 4.0, 0.0, math.pi / 2.0, 0.0, math.pi / 4.0, 0.0])
J = chain.jacobian(th)
- J_expected = torch.tensor([[[0, 1.41421356e-02, 0, 2.82842712e-01, 0, 0, 0],
- [-6.60827561e-01, 0, -4.57275649e-01, 0, 5.72756493e-02, 0, 0],
- [0, 6.60827561e-01, 0, -3.63842712e-01, 0, 8.10000000e-02, 0],
- [0, 0, -7.07106781e-01, 0, -7.07106781e-01, 0, -1],
- [0, 1, 0, -1, 0, 1, 0],
- [1, 0, 7.07106781e-01, 0, -7.07106781e-01, 0, 0]]])
+ J_expected = torch.tensor(
+ [
+ [
+ [0, 1.41421356e-02, 0, 2.82842712e-01, 0, 0, 0],
+ [-6.60827561e-01, 0, -4.57275649e-01, 0, 5.72756493e-02, 0, 0],
+ [0, 6.60827561e-01, 0, -3.63842712e-01, 0, 8.10000000e-02, 0],
+ [0, 0, -7.07106781e-01, 0, -7.07106781e-01, 0, -1],
+ [0, 1, 0, -1, 0, 1, 0],
+ [1, 0, 7.07106781e-01, 0, -7.07106781e-01, 0, 0],
+ ]
+ ]
+ )
assert torch.allclose(J, J_expected, atol=1e-7)
chain = pk.build_chain_from_sdf(open(os.path.join(TEST_DIR, "simple_arm.sdf")).read())
chain = pk.SerialChain(chain, "arm_wrist_roll")
th = torch.tensor([0.8, 0.2, -0.5, -0.3])
J = chain.jacobian(th)
- torch.allclose(J, torch.tensor([[[0., -1.51017878, -0.46280904, 0.],
- [0., 0.37144033, 0.29716627, 0.],
- [0., 0., 0., 0.],
- [0., 0., 0., 0.],
- [0., 0., 0., 0.],
- [0., 1., 1., 1.]]]))
+ torch.allclose(
+ J,
+ torch.tensor(
+ [
+ [
+ [0.0, -1.51017878, -0.46280904, 0.0],
+ [0.0, 0.37144033, 0.29716627, 0.0],
+ [0.0, 0.0, 0.0, 0.0],
+ [0.0, 0.0, 0.0, 0.0],
+ [0.0, 0.0, 0.0, 0.0],
+ [0.0, 1.0, 1.0, 1.0],
+ ]
+ ]
+ ),
+ )
def test_jacobian_at_different_loc_than_ee():
- chain = pk.build_serial_chain_from_urdf(open(os.path.join(TEST_DIR, "kuka_iiwa.urdf")).read(),
- "lbr_iiwa_link_7")
+ chain = pk.build_serial_chain_from_urdf(open(os.path.join(TEST_DIR, "kuka_iiwa.urdf")).read(), "lbr_iiwa_link_7")
th = torch.tensor([0.0, -math.pi / 4.0, 0.0, math.pi / 2.0, 0.0, math.pi / 4.0, 0.0])
loc = torch.tensor([0.1, 0, 0])
J = chain.jacobian(th, locations=loc)
- J_c1 = torch.tensor([[[-0., 0.11414214, -0., 0.18284271, 0., 0.1, 0.],
- [-0.66082756, -0., -0.38656497, -0., 0.12798633, -0., 0.1],
- [-0., 0.66082756, -0., -0.36384271, 0., 0.081, -0.],
- [-0., -0., -0.70710678, -0., -0.70710678, 0., -1.],
- [0., 1., 0., -1., 0., 1., 0.],
- [1., 0., 0.70710678, 0., -0.70710678, -0., 0.]]])
+ J_c1 = torch.tensor(
+ [
+ [
+ [-0.0, 0.11414214, -0.0, 0.18284271, 0.0, 0.1, 0.0],
+ [-0.66082756, -0.0, -0.38656497, -0.0, 0.12798633, -0.0, 0.1],
+ [-0.0, 0.66082756, -0.0, -0.36384271, 0.0, 0.081, -0.0],
+ [-0.0, -0.0, -0.70710678, -0.0, -0.70710678, 0.0, -1.0],
+ [0.0, 1.0, 0.0, -1.0, 0.0, 1.0, 0.0],
+ [1.0, 0.0, 0.70710678, 0.0, -0.70710678, -0.0, 0.0],
+ ]
+ ]
+ )
assert torch.allclose(J, J_c1, atol=1e-7)
loc = torch.tensor([-0.1, 0.05, 0])
J = chain.jacobian(th, locations=loc)
- J_c2 = torch.tensor([[[-0.05, -0.08585786, -0.03535534, 0.38284271, 0.03535534, -0.1, -0.],
- [-0.66082756, -0., -0.52798633, -0., -0.01343503, 0., -0.1],
- [-0., 0.66082756, -0.03535534, -0.36384271, -0.03535534, 0.081, -0.05],
- [-0., -0., -0.70710678, -0., -0.70710678, 0., -1.],
- [0., 1., 0., -1., 0., 1., 0.],
- [1., 0., 0.70710678, 0., -0.70710678, -0., 0.]]])
+ J_c2 = torch.tensor(
+ [
+ [
+ [-0.05, -0.08585786, -0.03535534, 0.38284271, 0.03535534, -0.1, -0.0],
+ [-0.66082756, -0.0, -0.52798633, -0.0, -0.01343503, 0.0, -0.1],
+ [-0.0, 0.66082756, -0.03535534, -0.36384271, -0.03535534, 0.081, -0.05],
+ [-0.0, -0.0, -0.70710678, -0.0, -0.70710678, 0.0, -1.0],
+ [0.0, 1.0, 0.0, -1.0, 0.0, 1.0, 0.0],
+ [1.0, 0.0, 0.70710678, 0.0, -0.70710678, -0.0, 0.0],
+ ]
+ ]
+ )
assert torch.allclose(J, J_c2, atol=1e-7)
@@ -70,18 +96,21 @@ def test_jacobian_at_different_loc_than_ee():
def test_jacobian_y_joint_axis():
chain = pk.build_serial_chain_from_urdf(open(os.path.join(TEST_DIR, "simple_y_arm.urdf")).read(), "eef")
- th = torch.tensor([0.])
+ th = torch.tensor([0.0])
J = chain.jacobian(th)
- J_c3 = torch.tensor([[[0.], [0.], [-0.3], [0.], [1.], [0.]]])
+ J_c3 = torch.tensor([[[0.0], [0.0], [-0.3], [0.0], [1.0], [0.0]]])
assert torch.allclose(J, J_c3, atol=1e-7)
def test_parallel():
N = 100
- chain = pk.build_serial_chain_from_urdf(open(os.path.join(TEST_DIR, "kuka_iiwa.urdf")).read(),
- "lbr_iiwa_link_7")
+ chain = pk.build_serial_chain_from_urdf(open(os.path.join(TEST_DIR, "kuka_iiwa.urdf")).read(), "lbr_iiwa_link_7")
th = torch.cat(
- (torch.tensor([[0.0, -math.pi / 4.0, 0.0, math.pi / 2.0, 0.0, math.pi / 4.0, 0.0]]), torch.rand(N, 7)))
+ (
+ torch.tensor([[0.0, -math.pi / 4.0, 0.0, math.pi / 2.0, 0.0, math.pi / 4.0, 0.0]]),
+ torch.rand(N, 7),
+ )
+ )
J = chain.jacobian(th)
for i in range(N):
J_i = chain.jacobian(th[i])
@@ -93,8 +122,7 @@ def test_dtype_device():
d = "cuda" if torch.cuda.is_available() else "cpu"
dtype = torch.float64
- chain = pk.build_serial_chain_from_urdf(open(os.path.join(TEST_DIR, "kuka_iiwa.urdf")).read(),
- "lbr_iiwa_link_7")
+ chain = pk.build_serial_chain_from_urdf(open(os.path.join(TEST_DIR, "kuka_iiwa.urdf")).read(), "lbr_iiwa_link_7")
chain = chain.to(dtype=dtype, device=d)
th = torch.rand(N, 7, dtype=dtype, device=d)
J = chain.jacobian(th)
@@ -106,8 +134,7 @@ def test_gradient():
d = "cuda" if torch.cuda.is_available() else "cpu"
dtype = torch.float64
- chain = pk.build_serial_chain_from_urdf(open(os.path.join(TEST_DIR, "kuka_iiwa.urdf")).read(),
- "lbr_iiwa_link_7")
+ chain = pk.build_serial_chain_from_urdf(open(os.path.join(TEST_DIR, "kuka_iiwa.urdf")).read(), "lbr_iiwa_link_7")
chain = chain.to(dtype=dtype, device=d)
th = torch.rand(N, 7, dtype=dtype, device=d, requires_grad=True)
J = chain.jacobian(th)
@@ -122,12 +149,12 @@ def test_jacobian_prismatic():
tg = chain.forward_kinematics(th)
m = tg.get_matrix()
pos = m[0, :3, 3]
- assert torch.allclose(pos, torch.tensor([0, 0, 1.]))
+ assert torch.allclose(pos, torch.tensor([0, 0, 1.0]))
th = torch.tensor([0, 0.1, 0])
tg = chain.forward_kinematics(th)
m = tg.get_matrix()
pos = m[0, :3, 3]
- assert torch.allclose(pos, torch.tensor([0, -0.1, 1.]))
+ assert torch.allclose(pos, torch.tensor([0, -0.1, 1.0]))
th = torch.tensor([0.1, 0.1, 0])
tg = chain.forward_kinematics(th)
m = tg.get_matrix()
@@ -140,28 +167,40 @@ def test_jacobian_prismatic():
assert torch.allclose(pos, torch.tensor([0.1, -0.1, 1.1]))
J = chain.jacobian(th)
- assert torch.allclose(J, torch.tensor([[[0., 0., 1.],
- [0., -1., 0.],
- [1., 0., 0.],
- [0., 0., 0.],
- [0., 0., 0.],
- [0., 0., 0.]]]))
+ assert torch.allclose(
+ J,
+ torch.tensor(
+ [
+ [
+ [0.0, 0.0, 1.0],
+ [0.0, -1.0, 0.0],
+ [1.0, 0.0, 0.0],
+ [0.0, 0.0, 0.0],
+ [0.0, 0.0, 0.0],
+ [0.0, 0.0, 0.0],
+ ]
+ ]
+ ),
+ )
def test_comparison_to_autograd():
- chain = pk.build_serial_chain_from_urdf(open(os.path.join(TEST_DIR, "kuka_iiwa.urdf")).read(),
- "lbr_iiwa_link_7")
+ chain = pk.build_serial_chain_from_urdf(open(os.path.join(TEST_DIR, "kuka_iiwa.urdf")).read(), "lbr_iiwa_link_7")
d = "cuda" if torch.cuda.is_available() else "cpu"
chain = chain.to(device=d)
def get_pt(th):
- return chain.forward_kinematics(th).transform_points(
- torch.zeros((1, 3), device=th.device, dtype=th.dtype)).squeeze(1)
+ return chain.forward_kinematics(th).transform_points(torch.zeros((1, 3), device=th.device, dtype=th.dtype)).squeeze(1)
# compare the time taken
N = 1000
- ths = (torch.tensor([[0.0, -math.pi / 4.0, 0.0, math.pi / 2.0, 0.0, math.pi / 4.0, 0.0]], device=d),
- torch.rand(N - 1, 7, device=d))
+ ths = (
+ torch.tensor(
+ [[0.0, -math.pi / 4.0, 0.0, math.pi / 2.0, 0.0, math.pi / 4.0, 0.0]],
+ device=d,
+ ),
+ torch.rand(N - 1, 7, device=d),
+ )
th = torch.cat(ths)
autograd_start = timer()
@@ -181,6 +220,7 @@ def get_pt(th):
# if we have functools (for pytorch>=1.13.0 it comes with installing pytorch)
try:
import functorch
+
ft_start = timer()
grad_func = torch.vmap(functorch.jacrev(get_pt))
j3 = grad_func(th).squeeze(1)
@@ -188,7 +228,7 @@ def get_pt(th):
assert torch.allclose(j1_, j3, atol=1e-6)
assert torch.allclose(j3, j2[:, :3], atol=1e-6)
print(f"for N={N} on {d} functorch:{(ft_end - ft_start) * 1000}ms")
- except:
+ except ImportError:
pass
diff --git a/tests/test_kinematics.py b/tests/test_kinematics.py
index f70acbc..f38c47c 100644
--- a/tests/test_kinematics.py
+++ b/tests/test_kinematics.py
@@ -7,6 +7,7 @@
import pytorch_kinematics as pk
from pytorch_kinematics.transforms.math import quaternion_close
+
TEST_DIR = os.path.dirname(__file__)
@@ -25,45 +26,53 @@ def test_fk_mjcf():
print(chain.get_joint_parameter_names())
th = {joint: 0.0 for joint in chain.get_joint_parameter_names()}
- th.update({'hip_1': 1.0, 'ankle_1': 1})
+ th.update({"hip_1": 1.0, "ankle_1": 1})
ret = chain.forward_kinematics(th)
- tg = ret['aux_1']
+ tg = ret["aux_1"]
pos, rot = quat_pos_from_transform3d(tg)
- assert quaternion_close(rot, torch.tensor([0.87758256, 0., 0., 0.47942554], dtype=torch.float64))
+ assert quaternion_close(rot, torch.tensor([0.87758256, 0.0, 0.0, 0.47942554], dtype=torch.float64))
assert torch.allclose(pos, torch.tensor([0.2, 0.2, 0.75], dtype=torch.float64))
- tg = ret['front_left_foot']
+ tg = ret["front_left_foot"]
pos, rot = quat_pos_from_transform3d(tg)
- assert quaternion_close(rot, torch.tensor([0.77015115, -0.4600326, 0.13497724, 0.42073549], dtype=torch.float64))
+ assert quaternion_close(
+ rot,
+ torch.tensor([0.77015115, -0.4600326, 0.13497724, 0.42073549], dtype=torch.float64),
+ )
assert torch.allclose(pos, torch.tensor([0.13976626, 0.47635466, 0.75], dtype=torch.float64))
print(ret)
def test_fk_serial_mjcf():
- chain = pk.build_serial_chain_from_mjcf(open(os.path.join(TEST_DIR, "ant.xml")).read(), 'front_left_foot')
+ chain = pk.build_serial_chain_from_mjcf(open(os.path.join(TEST_DIR, "ant.xml")).read(), "front_left_foot")
chain = chain.to(dtype=torch.float64)
tg = chain.forward_kinematics([1.0, 1.0])
pos, rot = quat_pos_from_transform3d(tg)
- assert quaternion_close(rot, torch.tensor([0.77015115, -0.4600326, 0.13497724, 0.42073549], dtype=torch.float64))
+ assert quaternion_close(
+ rot,
+ torch.tensor([0.77015115, -0.4600326, 0.13497724, 0.42073549], dtype=torch.float64),
+ )
assert torch.allclose(pos, torch.tensor([0.13976626, 0.47635466, 0.75], dtype=torch.float64))
def test_fkik():
- data = '' \
- '' \
- '' \
- '' \
- '' \
- '' \
- '' \
- '' \
- '' \
- '' \
- '' \
- '' \
- '' \
- '' \
- ''
- chain = pk.build_serial_chain_from_urdf(data, 'link3')
+ data = (
+ ''
+ ''
+ ''
+ ''
+ ''
+ ''
+ ''
+ ''
+ ""
+ ''
+ ''
+ ''
+ ''
+ ""
+ ""
+ )
+ chain = pk.build_serial_chain_from_urdf(data, "link3")
th1 = torch.tensor([0.42553542, 0.17529176])
tg = chain.forward_kinematics(th1)
pos, rot = quat_pos_from_transform3d(tg)
@@ -92,10 +101,14 @@ def test_urdf():
chain.to(dtype=torch.float64)
th = [0.0, -math.pi / 4.0, 0.0, math.pi / 2.0, 0.0, math.pi / 4.0, 0.0]
ret = chain.forward_kinematics(th)
- tg = ret['lbr_iiwa_link_7']
+ tg = ret["lbr_iiwa_link_7"]
pos, rot = quat_pos_from_transform3d(tg)
assert quaternion_close(rot, torch.tensor([7.07106781e-01, 0, -7.07106781e-01, 0], dtype=torch.float64))
- assert torch.allclose(pos, torch.tensor([-6.60827561e-01, 0, 3.74142136e-01], dtype=torch.float64), atol=1e-6)
+ assert torch.allclose(
+ pos,
+ torch.tensor([-6.60827561e-01, 0, 3.74142136e-01], dtype=torch.float64),
+ atol=1e-6,
+ )
def test_urdf_serial():
@@ -104,10 +117,14 @@ def test_urdf_serial():
th = [0.0, -math.pi / 4.0, 0.0, math.pi / 2.0, 0.0, math.pi / 4.0, 0.0]
ret = chain.forward_kinematics(th, end_only=False)
- tg = ret['lbr_iiwa_link_7']
+ tg = ret["lbr_iiwa_link_7"]
pos, rot = quat_pos_from_transform3d(tg)
assert quaternion_close(rot, torch.tensor([7.07106781e-01, 0, -7.07106781e-01, 0], dtype=torch.float64))
- assert torch.allclose(pos, torch.tensor([-6.60827561e-01, 0, 3.74142136e-01], dtype=torch.float64), atol=1e-6)
+ assert torch.allclose(
+ pos,
+ torch.tensor([-6.60827561e-01, 0, 3.74142136e-01], dtype=torch.float64),
+ atol=1e-6,
+ )
N = 1000
d = "cuda" if torch.cuda.is_available() else "cpu"
@@ -126,18 +143,18 @@ def test_urdf_serial():
def _fk_parallel():
tg_batch = chain.forward_kinematics(th_batch)
- m = tg_batch.get_matrix()
+ tg_batch.get_matrix()
dt_parallel = timeit(_fk_parallel, number=number) / number
- print("elapsed {}s for N={} when parallel".format(dt_parallel, N))
+ print(f"elapsed {dt_parallel}s for N={N} when parallel")
def _fk_serial():
for i in range(N):
tg = chain.forward_kinematics(th_batch[i])
- m = tg.get_matrix()
+ tg.get_matrix()
dt_serial = timeit(_fk_serial, number=number) / number
- print("elapsed {}s for N={} when serial".format(dt_serial, N))
+ print(f"elapsed {dt_serial}s for N={N} when serial")
# assert torch.allclose(tg.get_matrix().view(4, 4), m[i])
@@ -148,29 +165,31 @@ def test_fk_simple_arm():
chain = chain.to(dtype=torch.float64)
# print(chain)
# print(chain.get_joint_parameter_names())
- ret = chain.forward_kinematics({
- 'arm_shoulder_pan_joint': 0.,
- 'arm_elbow_pan_joint': math.pi / 2.0,
- 'arm_wrist_lift_joint': -0.5,
- 'arm_wrist_roll_joint': 0.,
- })
- tg = ret['arm_wrist_roll']
+ ret = chain.forward_kinematics(
+ {
+ "arm_shoulder_pan_joint": 0.0,
+ "arm_elbow_pan_joint": math.pi / 2.0,
+ "arm_wrist_lift_joint": -0.5,
+ "arm_wrist_roll_joint": 0.0,
+ }
+ )
+ tg = ret["arm_wrist_roll"]
pos, rot = quat_pos_from_transform3d(tg)
- assert quaternion_close(rot, torch.tensor([0.70710678, 0., 0., 0.70710678], dtype=torch.float64))
+ assert quaternion_close(rot, torch.tensor([0.70710678, 0.0, 0.0, 0.70710678], dtype=torch.float64))
assert torch.allclose(pos, torch.tensor([1.05, 0.55, 0.5], dtype=torch.float64))
N = 100
ret = chain.forward_kinematics({k: torch.rand(N) for k in chain.get_joint_parameter_names()})
- tg = ret['arm_wrist_roll']
+ tg = ret["arm_wrist_roll"]
assert list(tg.get_matrix().shape) == [N, 4, 4]
def test_sdf_serial_chain():
- chain = pk.build_serial_chain_from_sdf(open(os.path.join(TEST_DIR, "simple_arm.sdf")).read(), 'arm_wrist_roll')
+ chain = pk.build_serial_chain_from_sdf(open(os.path.join(TEST_DIR, "simple_arm.sdf")).read(), "arm_wrist_roll")
chain = chain.to(dtype=torch.float64)
- tg = chain.forward_kinematics([0., math.pi / 2.0, -0.5, 0.])
+ tg = chain.forward_kinematics([0.0, math.pi / 2.0, -0.5, 0.0])
pos, rot = quat_pos_from_transform3d(tg)
- assert quaternion_close(rot, torch.tensor([0.70710678, 0., 0., 0.70710678], dtype=torch.float64))
+ assert quaternion_close(rot, torch.tensor([0.70710678, 0.0, 0.0, 0.70710678], dtype=torch.float64))
assert torch.allclose(pos, torch.tensor([1.05, 0.55, 0.5], dtype=torch.float64))
@@ -187,33 +206,37 @@ def test_cuda():
chain = pk.build_chain_from_sdf(open(os.path.join(TEST_DIR, "simple_arm.sdf")).read())
chain = chain.to(dtype=dtype, device=d)
- ret = chain.forward_kinematics({
- 'arm_shoulder_pan_joint': 0,
- 'arm_elbow_pan_joint': math.pi / 2.0,
- 'arm_wrist_lift_joint': -0.5,
- 'arm_wrist_roll_joint': 0,
- })
- tg = ret['arm_wrist_roll']
+ ret = chain.forward_kinematics(
+ {
+ "arm_shoulder_pan_joint": 0,
+ "arm_elbow_pan_joint": math.pi / 2.0,
+ "arm_wrist_lift_joint": -0.5,
+ "arm_wrist_roll_joint": 0,
+ }
+ )
+ tg = ret["arm_wrist_roll"]
pos, rot = quat_pos_from_transform3d(tg)
- assert quaternion_close(rot, torch.tensor([0.70710678, 0., 0., 0.70710678], dtype=dtype, device=d))
+ assert quaternion_close(rot, torch.tensor([0.70710678, 0.0, 0.0, 0.70710678], dtype=dtype, device=d))
assert torch.allclose(pos, torch.tensor([1.05, 0.55, 0.5], dtype=dtype, device=d))
- data = '' \
- '' \
- '' \
- '' \
- '' \
- '' \
- '' \
- '' \
- '' \
- '' \
- '' \
- '' \
- '' \
- '' \
- ''
- chain = pk.build_serial_chain_from_urdf(data, 'link3')
+ data = (
+ ''
+ ''
+ ''
+ ''
+ ''
+ ''
+ ''
+ ''
+ ""
+ ''
+ ''
+ ''
+ ''
+ ""
+ ""
+ )
+ chain = pk.build_serial_chain_from_urdf(data, "link3")
chain = chain.to(dtype=dtype, device=d)
N = 20
th_batch = torch.rand(N, 2).to(device=d, dtype=dtype)
@@ -247,7 +270,7 @@ def test_fk_val():
chain = pk.build_chain_from_mjcf(open(os.path.join(TEST_DIR, "val.xml")).read())
chain = chain.to(dtype=torch.float64)
ret = chain.forward_kinematics(torch.zeros([1000, chain.n_joints], dtype=torch.float64))
- tg = ret['drive45']
+ tg = ret["drive45"]
pos, rot = quat_pos_from_transform3d(tg)
torch.set_printoptions(precision=6, sci_mode=False)
assert quaternion_close(rot, torch.tensor([0.5, 0.5, -0.5, 0.5], dtype=torch.float64))
@@ -256,45 +279,47 @@ def test_fk_val():
def test_fk_partial_batched_dict():
# Test that you can pass in dict of batched joint configs for a subset of the joints
- chain = pk.build_serial_chain_from_mjcf(open(os.path.join(TEST_DIR, "val.xml")).read(), 'left_tool')
+ chain = pk.build_serial_chain_from_mjcf(open(os.path.join(TEST_DIR, "val.xml")).read(), "left_tool")
th = {
- 'joint56': torch.zeros([1000], dtype=torch.float64),
- 'joint57': torch.zeros([1000], dtype=torch.float64),
- 'joint41': torch.zeros([1000], dtype=torch.float64),
- 'joint42': torch.zeros([1000], dtype=torch.float64),
- 'joint43': torch.zeros([1000], dtype=torch.float64),
- 'joint44': torch.zeros([1000], dtype=torch.float64),
- 'joint45': torch.zeros([1000], dtype=torch.float64),
- 'joint46': torch.zeros([1000], dtype=torch.float64),
- 'joint47': torch.zeros([1000], dtype=torch.float64),
+ "joint56": torch.zeros([1000], dtype=torch.float64),
+ "joint57": torch.zeros([1000], dtype=torch.float64),
+ "joint41": torch.zeros([1000], dtype=torch.float64),
+ "joint42": torch.zeros([1000], dtype=torch.float64),
+ "joint43": torch.zeros([1000], dtype=torch.float64),
+ "joint44": torch.zeros([1000], dtype=torch.float64),
+ "joint45": torch.zeros([1000], dtype=torch.float64),
+ "joint46": torch.zeros([1000], dtype=torch.float64),
+ "joint47": torch.zeros([1000], dtype=torch.float64),
}
chain = chain.to(dtype=torch.float64)
- tg = chain.forward_kinematics(th)
+ chain.forward_kinematics(th)
def test_fk_partial_batched():
# Test that you can pass in dict of batched joint configs for a subset of the joints
- chain = pk.build_serial_chain_from_mjcf(open(os.path.join(TEST_DIR, "val.xml")).read(), 'left_tool')
+ chain = pk.build_serial_chain_from_mjcf(open(os.path.join(TEST_DIR, "val.xml")).read(), "left_tool")
th = torch.zeros([1000, 9], dtype=torch.float64)
chain = chain.to(dtype=torch.float64)
- tg = chain.forward_kinematics(th)
+ chain.forward_kinematics(th)
def test_ur5_fk():
urdf = os.path.join(TEST_DIR, "ur5.urdf")
- pk_chain = pk.build_serial_chain_from_urdf(open(urdf).read(), 'ee_link', 'base_link')
+ pk_chain = pk.build_serial_chain_from_urdf(open(urdf).read(), "ee_link", "base_link")
th = [0.0, -math.pi / 4.0, 0.0, math.pi / 2.0, 0.0, math.pi / 4.0]
try:
import ikpy.chain
- ik_chain = ikpy.chain.Chain.from_urdf_file(urdf,
- active_links_mask=[False, True, True, True, True, True, True, False])
+
+ ik_chain = ikpy.chain.Chain.from_urdf_file(urdf, active_links_mask=[False, True, True, True, True, True, True, False])
ik_ret = ik_chain.forward_kinematics([0, *th, 0])
except ImportError:
- ik_ret = [[-6.44330720e-18, 3.58979314e-09, -1.00000000e+00, 5.10955359e-01],
- [1.00000000e+00, 1.79489651e-09, 0.00000000e+00, 1.91450000e-01],
- [1.79489651e-09, -1.00000000e+00, -3.58979312e-09, 6.00114361e-01],
- [0.00000000e+00, 0.00000000e+00, 0.00000000e+00, 1.00000000e+00]]
+ ik_ret = [
+ [-6.44330720e-18, 3.58979314e-09, -1.00000000e00, 5.10955359e-01],
+ [1.00000000e00, 1.79489651e-09, 0.00000000e00, 1.91450000e-01],
+ [1.79489651e-09, -1.00000000e00, -3.58979312e-09, 6.00114361e-01],
+ [0.00000000e00, 0.00000000e00, 0.00000000e00, 1.00000000e00],
+ ]
ret = pk_chain.forward_kinematics(th, end_only=True)
print(ret.get_matrix())
diff --git a/tests/test_menagerie.py b/tests/test_menagerie.py
index 6d7ae4b..6e10825 100644
--- a/tests/test_menagerie.py
+++ b/tests/test_menagerie.py
@@ -5,28 +5,29 @@
import pytorch_kinematics as pk
+
# Find all files named "scene*.xml" in the "mujoco_menagerie" directory
-_MENAGERIE_ROOT = pathlib.Path(__file__).parent / 'mujoco_menagerie'
+_MENAGERIE_ROOT = pathlib.Path(__file__).parent / "mujoco_menagerie"
_XMLS_AND_BODIES = {
# 'agility_cassie/scene.xml': 'cassie-pelvis', # not supported because it has a ball joint
- 'anybotics_anymal_b/scene.xml': 'base',
- 'anybotics_anymal_c/scene.xml': 'base',
- 'franka_emika_panda/scene.xml': 'link0',
- 'google_barkour_v0/scene.xml': 'chassis',
- 'google_barkour_v0/scene_barkour.xml': 'chassis',
+ "anybotics_anymal_b/scene.xml": "base",
+ "anybotics_anymal_c/scene.xml": "base",
+ "franka_emika_panda/scene.xml": "link0",
+ "google_barkour_v0/scene.xml": "chassis",
+ "google_barkour_v0/scene_barkour.xml": "chassis",
# 'hello_robot_stretch/scene.xml': 'base_link', # not supported because it has composite joints
- 'kuka_iiwa_14/scene.xml': 'base',
- 'rethink_robotics_sawyer/scene.xml': 'base',
- 'robotiq_2f85/scene.xml': 'base_mount',
- 'robotis_op3/scene.xml': 'body_link',
- 'shadow_hand/scene_left.xml': 'lh_forearm',
- 'shadow_hand/scene_right.xml': 'rh_forearm',
- 'ufactory_xarm7/scene.xml': 'link_base',
- 'unitree_a1/scene.xml': 'trunk',
- 'unitree_go1/scene.xml': 'trunk',
- 'universal_robots_ur5e/scene.xml': 'base',
- 'wonik_allegro/scene_left.xml': 'palm',
- 'wonik_allegro/scene_right.xml': 'palm',
+ "kuka_iiwa_14/scene.xml": "base",
+ "rethink_robotics_sawyer/scene.xml": "base",
+ "robotiq_2f85/scene.xml": "base_mount",
+ "robotis_op3/scene.xml": "body_link",
+ "shadow_hand/scene_left.xml": "lh_forearm",
+ "shadow_hand/scene_right.xml": "rh_forearm",
+ "ufactory_xarm7/scene.xml": "link_base",
+ "unitree_a1/scene.xml": "trunk",
+ "unitree_go1/scene.xml": "trunk",
+ "universal_robots_ur5e/scene.xml": "base",
+ "wonik_allegro/scene_left.xml": "palm",
+ "wonik_allegro/scene_right.xml": "palm",
}
@@ -36,7 +37,7 @@ def test_menagerie():
xml_dir = xml_filename.parent
# Menagerie files assume the current working directory is the directory of the scene.xml
os.chdir(xml_dir)
- with xml_filename.open('r') as f:
+ with xml_filename.open("r") as f:
xml = f.read()
chain = pk.build_chain_from_mjcf(xml, body)
print(xml_filename)
@@ -44,8 +45,8 @@ def test_menagerie():
print(f"\t {chain.get_frame_names()}")
print(f"\t {chain.get_joint_parameter_names()}")
th = np.zeros(len(chain.get_joint_parameter_names()))
- fk_dict = chain.forward_kinematics(th)
+ chain.forward_kinematics(th)
-if __name__ == '__main__':
+if __name__ == "__main__":
test_menagerie()
diff --git a/tests/test_rotation_conversions.py b/tests/test_rotation_conversions.py
index 46cbb25..b76002e 100644
--- a/tests/test_rotation_conversions.py
+++ b/tests/test_rotation_conversions.py
@@ -4,36 +4,42 @@
import torch
from pytorch_kinematics.transforms.math import quaternion_close
-from pytorch_kinematics.transforms.rotation_conversions import axis_and_angle_to_matrix_33, axis_angle_to_matrix, \
- pos_rot_to_matrix, matrix_to_pos_rot, random_rotations, quaternion_from_euler
+from pytorch_kinematics.transforms.rotation_conversions import (
+ axis_and_angle_to_matrix_33,
+ axis_angle_to_matrix,
+ matrix_to_pos_rot,
+ pos_rot_to_matrix,
+ quaternion_from_euler,
+ random_rotations,
+)
def test_axis_angle_to_matrix_perf():
number = 100
N = 1_000
- device = 'cuda' if torch.cuda.is_available() else 'cpu'
+ device = "cuda" if torch.cuda.is_available() else "cpu"
axis_angle = torch.randn([N, 3], device=device, dtype=torch.float64)
- axis_1d = torch.tensor([1., 0, 0], device=device, dtype=torch.float64) # in the FK code this is NOT batched!
+ axis_1d = torch.tensor([1.0, 0, 0], device=device, dtype=torch.float64) # in the FK code this is NOT batched!
theta = axis_angle.norm(dim=1, keepdim=True)
dt1 = timeit.timeit(lambda: axis_angle_to_matrix(axis_angle), number=number)
- print(f'Old method: {dt1:.5f}')
+ print(f"Old method: {dt1:.5f}")
dt2 = timeit.timeit(lambda: axis_and_angle_to_matrix_33(axis=axis_1d, theta=theta), number=number)
- print(f'New method: {dt2:.5f}')
+ print(f"New method: {dt2:.5f}")
def test_quaternion_not_close():
# ensure it returns false for quaternions that are far apart
- q1 = torch.tensor([1., 0, 0, 0])
- q2 = torch.tensor([0., 1, 0, 0])
+ q1 = torch.tensor([1.0, 0, 0, 0])
+ q2 = torch.tensor([0.0, 1, 0, 0])
assert not quaternion_close(q1, q2)
def test_quaternion_from_euler():
- q = quaternion_from_euler(torch.tensor([0., 0, 0]))
- assert quaternion_close(q, torch.tensor([1., 0, 0, 0]))
+ q = quaternion_from_euler(torch.tensor([0.0, 0, 0]))
+ assert quaternion_close(q, torch.tensor([1.0, 0, 0, 0]))
root2_over_2 = np.sqrt(2) / 2
q = quaternion_from_euler(torch.tensor([0, 0, np.pi / 2]))
@@ -67,6 +73,6 @@ def test_pos_rot_conversion():
assert torch.allclose(T, TT, atol=1e-6)
-if __name__ == '__main__':
+if __name__ == "__main__":
test_axis_angle_to_matrix_perf()
test_pos_rot_conversion()
diff --git a/tests/test_serial_chain_creation.py b/tests/test_serial_chain_creation.py
index 057d54c..c09e7a0 100644
--- a/tests/test_serial_chain_creation.py
+++ b/tests/test_serial_chain_creation.py
@@ -1,10 +1,8 @@
import os
-from timeit import default_timer as timer
-
-import torch
import pytorch_kinematics as pk
+
TEST_DIR = os.path.dirname(__file__)
diff --git a/tests/test_transform.py b/tests/test_transform.py
index ffdcc1b..06f4acb 100644
--- a/tests/test_transform.py
+++ b/tests/test_transform.py
@@ -1,7 +1,7 @@
import torch
-import pytorch_kinematics.transforms as tf
import pytorch_kinematics as pk
+import pytorch_kinematics.transforms as tf
def test_transform():
@@ -15,18 +15,17 @@ def test_transform():
assert torch.allclose(mats, mats_recovered)
quat_identity = tf.quaternion_multiply(quat, tf.quaternion_invert(quat))
- assert torch.allclose(tf.quaternion_to_matrix(quat_identity), torch.eye(3, dtype=torch.float64).repeat(N, 1, 1))
+ assert torch.allclose(
+ tf.quaternion_to_matrix(quat_identity),
+ torch.eye(3, dtype=torch.float64).repeat(N, 1, 1),
+ )
def test_translations():
t = tf.Translate(1, 2, 3)
- points = torch.tensor([[1.0, 0.0, 0.0], [0.0, 1.0, 0.0], [0.5, 0.5, 0.0]]).view(
- 1, 3, 3
- )
+ points = torch.tensor([[1.0, 0.0, 0.0], [0.0, 1.0, 0.0], [0.5, 0.5, 0.0]]).view(1, 3, 3)
points_out = t.transform_points(points)
- points_out_expected = torch.tensor(
- [[2.0, 2.0, 3.0], [1.0, 3.0, 3.0], [1.5, 2.5, 3.0]]
- ).view(1, 3, 3)
+ points_out_expected = torch.tensor([[2.0, 2.0, 3.0], [1.0, 3.0, 3.0], [1.5, 2.5, 3.0]]).view(1, 3, 3)
assert torch.allclose(points_out, points_out_expected)
N = 20
@@ -41,20 +40,12 @@ def test_translations():
def test_rotate_axis_angle():
t = tf.Transform3d().rotate_axis_angle(90.0, axis="Z")
- points = torch.tensor([[0.0, 0.0, 0.0], [0.0, 1.0, 0.0], [0.0, 1.0, 1.0]]).view(
- 1, 3, 3
- )
- normals = torch.tensor(
- [[1.0, 0.0, 0.0], [1.0, 0.0, 0.0], [1.0, 0.0, 0.0]]
- ).view(1, 3, 3)
+ points = torch.tensor([[0.0, 0.0, 0.0], [0.0, 1.0, 0.0], [0.0, 1.0, 1.0]]).view(1, 3, 3)
+ normals = torch.tensor([[1.0, 0.0, 0.0], [1.0, 0.0, 0.0], [1.0, 0.0, 0.0]]).view(1, 3, 3)
points_out = t.transform_points(points)
normals_out = t.transform_normals(normals)
- points_out_expected = torch.tensor(
- [[0.0, 0.0, 0.0], [-1.0, 0.0, 0.0], [-1.0, 0.0, 1.0]]
- ).view(1, 3, 3)
- normals_out_expected = torch.tensor(
- [[0.0, 1.0, 0.0], [0.0, 1.0, 0.0], [0.0, 1.0, 0.0]]
- ).view(1, 3, 3)
+ points_out_expected = torch.tensor([[0.0, 0.0, 0.0], [-1.0, 0.0, 0.0], [-1.0, 0.0, 1.0]]).view(1, 3, 3)
+ normals_out_expected = torch.tensor([[0.0, 1.0, 0.0], [0.0, 1.0, 0.0], [0.0, 1.0, 0.0]]).view(1, 3, 3)
assert torch.allclose(points_out, points_out_expected)
assert torch.allclose(normals_out, normals_out_expected)
@@ -62,12 +53,8 @@ def test_rotate_axis_angle():
def test_rotate():
R = tf.so3_exp_map(torch.randn((1, 3)))
t = tf.Transform3d().rotate(R)
- points = torch.tensor([[1.0, 0.0, 0.0], [0.0, 1.0, 0.0], [0.5, 0.5, 0.0]]).view(
- 1, 3, 3
- )
- normals = torch.tensor(
- [[1.0, 0.0, 0.0], [0.0, 1.0, 0.0], [1.0, 1.0, 0.0]]
- ).view(1, 3, 3)
+ points = torch.tensor([[1.0, 0.0, 0.0], [0.0, 1.0, 0.0], [0.5, 0.5, 0.0]]).view(1, 3, 3)
+ normals = torch.tensor([[1.0, 0.0, 0.0], [0.0, 1.0, 0.0], [1.0, 1.0, 0.0]]).view(1, 3, 3)
points_out = t.transform_points(points)
normals_out = t.transform_normals(normals)
points_out_expected = torch.bmm(points, R.transpose(-1, -2))
@@ -96,18 +83,25 @@ def test_transform_combined():
def test_euler():
euler_angles = torch.tensor([1, 0, 0.5])
t = tf.Transform3d(rot=euler_angles)
- sxyz_matrix = torch.tensor([[0.87758256, -0.47942554, 0., 0., ],
- [0.25903472, 0.47415988, -0.84147098, 0.],
- [0.40342268, 0.73846026, 0.54030231, 0.],
- [0., 0., 0., 1.]])
- # from tf.transformations import euler_matrix
- # print(euler_matrix(*euler_angles, "rxyz"))
- # print(t.get_matrix())
+ sxyz_matrix = torch.tensor(
+ [
+ [
+ 0.87758256,
+ -0.47942554,
+ 0.0,
+ 0.0,
+ ],
+ [0.25903472, 0.47415988, -0.84147098, 0.0],
+ [0.40342268, 0.73846026, 0.54030231, 0.0],
+ [0.0, 0.0, 0.0, 1.0],
+ ]
+ )
assert torch.allclose(sxyz_matrix, t.get_matrix())
def test_quaternions():
import pytorch_seed
+
pytorch_seed.seed(0)
n = 10
@@ -143,6 +137,7 @@ def test_quaternions():
def test_compose():
import torch
+
theta = 1.5707
a2b = tf.Transform3d(pos=[0.1, 0, 0]) # joint.offset
b2j = tf.Transform3d(rot=tf.axis_angle_to_quaternion(theta * torch.tensor([0.0, 0, 1]))) # joint.axis
diff --git a/tests/ur5.urdf b/tests/ur5.urdf
index 59242bc..761d900 100644
--- a/tests/ur5.urdf
+++ b/tests/ur5.urdf
@@ -16,7 +16,7 @@
5
power_state
10.0
- 87.78
+ 87.78
-474
525
15.52
@@ -302,4 +302,3 @@
-
diff --git a/tests/widowx/README.md b/tests/widowx/README.md
index 0e6d916..94acd4f 100644
--- a/tests/widowx/README.md
+++ b/tests/widowx/README.md
@@ -1,5 +1,3 @@
# WidowX 250 S Robot Description (URDF)
The robot model here is based on the real2sim project: https://github.com/simpler-env/SimplerEnv
-
-