diff --git a/docs/tutorials/custom_robot_import.md b/docs/tutorials/custom_robot_import.md index d8606e815..87800952f 100644 --- a/docs/tutorials/custom_robot_import.md +++ b/docs/tutorials/custom_robot_import.md @@ -134,8 +134,8 @@ Now that we have the USD file for the robot, let's write our own robot class. Fo raise ValueError("Stretch does not support discrete actions!") @property - def controller_order(self): - # Controller ordering. Usually determined by general robot kinematics chain + def _raw_controller_order(self): + # Raw controller ordering. Usually determined by general robot kinematics chain # You can usually simply take a subset of these based on the type of robot interfaces inherited for your robot class return ["base", "camera", f"arm_{self.default_arm}", f"gripper_{self.default_arm}"] diff --git a/omnigibson/configs/fetch_behavior.yaml b/omnigibson/configs/fetch_behavior.yaml index 5932b0d7e..cfde1f7d0 100644 --- a/omnigibson/configs/fetch_behavior.yaml +++ b/omnigibson/configs/fetch_behavior.yaml @@ -45,8 +45,6 @@ robots: action_normalize: true action_type: continuous grasping_mode: physical - rigid_trunk: false - default_trunk_offset: 0.365 default_arm_pose: diagonal30 default_reset_mode: tuck sensor_config: @@ -61,8 +59,11 @@ robots: controller_config: base: name: DifferentialDriveController + trunk: + name: JointController arm_0: name: InverseKinematicsController + subsume_controllers: [trunk] gripper_0: name: MultiFingerGripperController mode: binary diff --git a/omnigibson/configs/fetch_primitives.yaml b/omnigibson/configs/fetch_primitives.yaml index e4b2adab2..09978bd0d 100644 --- a/omnigibson/configs/fetch_primitives.yaml +++ b/omnigibson/configs/fetch_primitives.yaml @@ -38,14 +38,15 @@ robots: action_normalize: false action_type: continuous grasping_mode: sticky - rigid_trunk: false - default_trunk_offset: 0.365 default_arm_pose: diagonal30 controller_config: base: name: DifferentialDriveController + trunk: + name: JointController arm_0: name: JointController + subsume_controllers: [trunk] motor_type: position command_input_limits: null use_delta_commands: false diff --git a/omnigibson/configs/robots/fetch.yaml b/omnigibson/configs/robots/fetch.yaml index 9d8054e17..bb1a53b1d 100644 --- a/omnigibson/configs/robots/fetch.yaml +++ b/omnigibson/configs/robots/fetch.yaml @@ -16,8 +16,6 @@ robot: scale: 1.0 self_collision: true grasping_mode: physical - rigid_trunk: false - default_trunk_offset: 0.365 default_arm_pose: vertical sensor_config: VisionSensor: @@ -31,8 +29,11 @@ robot: controller_config: base: name: DifferentialDriveController + trunk: + name: JointController arm_0: name: InverseKinematicsController + subsume_controllers: [trunk] gripper_0: name: MultiFingerGripperController camera: diff --git a/omnigibson/configs/tiago_primitives.yaml b/omnigibson/configs/tiago_primitives.yaml index 44af0a663..9b03932a0 100644 --- a/omnigibson/configs/tiago_primitives.yaml +++ b/omnigibson/configs/tiago_primitives.yaml @@ -38,8 +38,6 @@ robots: action_normalize: false action_type: continuous grasping_mode: sticky - rigid_trunk: false - default_trunk_offset: 0.15 default_arm_pose: vertical sensor_config: VisionSensor: @@ -49,6 +47,8 @@ robots: controller_config: base: name: JointController + trunk: + name: JointController arm_left: name: JointController motor_type: position @@ -57,6 +57,7 @@ robots: use_delta_commands: false arm_right: name: JointController + subsume_controllers: [trunk] motor_type: position command_input_limits: null command_output_limits: null diff --git a/omnigibson/controllers/ik_controller.py b/omnigibson/controllers/ik_controller.py index 07fd4d160..841e5c3e9 100644 --- a/omnigibson/controllers/ik_controller.py +++ b/omnigibson/controllers/ik_controller.py @@ -38,9 +38,6 @@ class InverseKinematicsController(JointController, ManipulationController): def __init__( self, task_name, - robot_description_path, - robot_urdf_path, - eef_name, control_freq, reset_joint_pos, control_limits, @@ -63,9 +60,6 @@ def __init__( task_name (str): name assigned to this task frame for computing IK control. During control calculations, the inputted control_dict should include entries named <@task_name>_pos_relative and <@task_name>_quat_relative. See self._command_to_control() for what these values should entail. - robot_description_path (str): path to robot descriptor yaml file - robot_urdf_path (str): path to robot urdf file - eef_name (str): end effector frame name control_freq (int): controller loop frequency reset_joint_pos (Array[float]): reset joint positions, used as part of nullspace controller in IK. Note that this should correspond to ALL the joints; the exact indices will be extracted via @dof_idx diff --git a/omnigibson/objects/controllable_object.py b/omnigibson/objects/controllable_object.py index 691a7e0ae..dc996f99b 100644 --- a/omnigibson/objects/controllable_object.py +++ b/omnigibson/objects/controllable_object.py @@ -231,10 +231,44 @@ def _load_controllers(self): # Initialize controllers to create self._controllers = dict() - # Loop over all controllers, in the order corresponding to @action dim - for name in self.controller_order: + # Keep track of any controllers that are subsumed by other controllers + # We will not instantiate subsumed controllers + controller_subsumes = dict() # Maps independent controller name to list of subsumed controllers + subsume_names = set() + for name in self._raw_controller_order: + # Make sure we have the valid controller name specified assert_valid_key(key=name, valid_keys=self._controller_config, name="controller name") cfg = self._controller_config[name] + subsume_controllers = cfg.pop("subsume_controllers", []) + # If this controller subsumes other controllers, it cannot be subsumed by another controller + # (i.e.: we don't allow nested / cyclical subsuming) + if len(subsume_controllers) > 0: + assert ( + name not in subsume_names + ), f"Controller {name} subsumes other controllers, and therefore cannot be subsumed by another controller!" + controller_subsumes[name] = subsume_controllers + for subsume_name in subsume_controllers: + # Make sure it doesn't already exist -- a controller should only be subsumed by up to one other + assert ( + subsume_name not in subsume_names + ), f"Controller {subsume_name} cannot be subsumed by more than one other controller!" + assert ( + subsume_name not in controller_subsumes + ), f"Controller {name} subsumes other controllers, and therefore cannot be subsumed by another controller!" + subsume_names.add(subsume_name) + + # Loop over all controllers, in the order corresponding to @action dim + for name in self._raw_controller_order: + # If this controller is subsumed by another controller, simply skip it + if name in subsume_names: + continue + cfg = self._controller_config[name] + # If we subsume other controllers, prepend the subsumed' dof idxs to this controller's idxs + if name in controller_subsumes: + for subsumed_name in controller_subsumes[name]: + subsumed_cfg = self._controller_config[subsumed_name] + cfg["dof_idx"] = th.concatenate([subsumed_cfg["dof_idx"], cfg["dof_idx"]]) + # If we're using normalized action space, override the inputs for all controllers if self._action_normalize: cfg["command_input_limits"] = "default" # default is normalized (-1, 1) @@ -254,8 +288,12 @@ def update_controller_mode(self): Helper function to force the joints to use the internal specified control mode and gains """ # Update the control modes of each joint based on the outputted control from the controllers + unused_dofs = {i for i in range(self.n_dof)} for name in self._controllers: for dof in self._controllers[name].dof_idx: + # Make sure the DOF has not already been set yet, and remove it afterwards + assert dof.item() in unused_dofs + unused_dofs.remove(dof.item()) control_type = self._controllers[name].control_type self._joints[self.dof_names_ordered[dof]].set_control_type( control_type=control_type, @@ -267,6 +305,21 @@ def update_controller_mode(self): ), ) + # For all remaining DOFs not controlled, we assume these are free DOFs (e.g.: virtual joints representing free + # motion wrt a specific axis), so explicitly set kp / kd to 0 to avoid silent bugs when + # joint positions / velocities are set + for unused_dof in unused_dofs: + unused_joint = self._joints[self.dof_names_ordered[unused_dof]] + assert not unused_joint.driven, ( + f"All unused joints not mapped to any controller should not have DriveAPI attached to it! " + f"However, joint {unused_joint.name} is driven!" + ) + unused_joint.set_control_type( + control_type=ControlType.NONE, + kp=None, + kd=None, + ) + def _generate_controller_config(self, custom_config=None): """ Generates a fully-populated controller config, overriding any default values with the corresponding values @@ -283,7 +336,7 @@ def _generate_controller_config(self, custom_config=None): controller_config = {} if custom_config is None else deepcopy(custom_config) # Update the configs - for group in self.controller_order: + for group in self._raw_controller_order: group_controller_name = ( controller_config[group]["name"] if group in controller_config and "name" in controller_config[group] @@ -623,6 +676,41 @@ def get_control_dict(self): return fcns + def _add_task_frame_control_dict(self, fcns, task_name, link_name): + """ + Internally helper function to generate per-link control dictionary entries. Useful for generating relevant + control values needed for IK / OSC for a given @task_name. Should be called within @get_control_dict() + + Args: + fcns (CachedFunctions): Keyword-mapped control values for this object, mapping names to n-arrays. + task_name (str): name to assign for this task_frame. It will be prepended to all fcns generated + link_name (str): the corresponding link name from this controllable object that @task_name is referencing + """ + fcns[f"_{task_name}_pos_quat_relative"] = ( + lambda: ControllableObjectViewAPI.get_link_relative_position_orientation( + self.articulation_root_path, link_name + ) + ) + fcns[f"{task_name}_pos_relative"] = lambda: fcns[f"_{task_name}_pos_quat_relative"][0] + fcns[f"{task_name}_quat_relative"] = lambda: fcns[f"_{task_name}_pos_quat_relative"][1] + fcns[f"{task_name}_lin_vel_relative"] = lambda: ControllableObjectViewAPI.get_link_relative_linear_velocity( + self.articulation_root_path, link_name + ) + fcns[f"{task_name}_ang_vel_relative"] = lambda: ControllableObjectViewAPI.get_link_relative_angular_velocity( + self.articulation_root_path, link_name + ) + # -n_joints because there may be an additional 6 entries at the beginning of the array, if this robot does + # not have a fixed base (i.e.: the 6DOF --> "floating" joint) + # see self.get_relative_jacobian() for more info + # We also count backwards for the link frame because if the robot is fixed base, the jacobian returned has one + # less index than the number of links. This is presumably because the 1st link of a fixed base robot will + # always have a zero jacobian since it can't move. Counting backwards resolves this issue. + start_idx = 0 if self.fixed_base else 6 + link_idx = self._articulation_view.get_body_index(link_name) + fcns[f"{task_name}_jacobian_relative"] = lambda: ControllableObjectViewAPI.get_relative_jacobian( + self.articulation_root_path + )[-(self.n_links - link_idx), :, start_idx : start_idx + self.n_joints] + def dump_action(self): """ Dump the last action applied to this object. For use in demo collection. @@ -755,13 +843,27 @@ def controllers(self): return self._controllers @property - @abstractmethod def controller_order(self): """ Returns: list: Ordering of the actions, corresponding to the controllers. e.g., ["base", "arm", "gripper"], to denote that the action vector should be interpreted as first the base action, then arm command, then - gripper command + gripper command. Note that this may be a subset of all possible controllers due to some controllers + subsuming others (e.g.: arm controller subsuming the trunk controller if using IK) + """ + assert self._controllers is not None, "Can only view controller_order after controllers are loaded!" + return list(self._controllers.keys()) + + @property + @abstractmethod + def _raw_controller_order(self): + """ + Returns: + list: Raw ordering of the actions, corresponding to the controllers. e.g., ["base", "arm", "gripper"], + to denote that the action vector should be interpreted as first the base action, then arm command, then + gripper command. Note that external users should query @controller_order, which is the post-processed + ordering of actions, which may be a subset of the controllers due to some controllers subsuming others + (e.g.: arm controller subsuming the trunk controller if using IK) """ raise NotImplementedError diff --git a/omnigibson/prims/entity_prim.py b/omnigibson/prims/entity_prim.py index 18e5d4d9c..6676097fc 100644 --- a/omnigibson/prims/entity_prim.py +++ b/omnigibson/prims/entity_prim.py @@ -1612,7 +1612,6 @@ def _dump_state(self): if self.n_joints > 0: state["joint_pos"] = self.get_joint_positions() state["joint_vel"] = self.get_joint_velocities() - state["joint_eff"] = self.get_joint_efforts() # We do NOT save joint pos / vel targets because this is only relevant for motorized joints (e.g.: robots). # Such control (a) only relies on the joint state, and not joint targets, when computing control, and @@ -1633,7 +1632,6 @@ def _load_state(self, state): elif self.n_joints > 0: self.set_joint_positions(state["joint_pos"]) self.set_joint_velocities(state["joint_vel"]) - self.set_joint_efforts(state["joint_eff"]) # Make sure this object is awake self.wake() @@ -1646,7 +1644,6 @@ def serialize(self, state): state_flat += [ state["joint_pos"], state["joint_vel"], - state["joint_eff"], ] return th.cat(state_flat) @@ -1657,7 +1654,7 @@ def deserialize(self, state): root_link_state, idx = self.root_link.deserialize(state=state) state_dict = dict(root_link=root_link_state) if self.n_joints > 0: - for jnt_state in ("pos", "vel", "eff"): + for jnt_state in ("pos", "vel"): state_dict[f"joint_{jnt_state}"] = state[idx : idx + self.n_joints] idx += self.n_joints diff --git a/omnigibson/prims/geom_prim.py b/omnigibson/prims/geom_prim.py index 0c12beada..7b03dad9e 100644 --- a/omnigibson/prims/geom_prim.py +++ b/omnigibson/prims/geom_prim.py @@ -117,7 +117,7 @@ def opacity(self, opacity): if self.has_material(): self.material.opacity_constant = opacity else: - self.set_attribute("primvars:displayOpacity", th.tensor([opacity])) + self.set_attribute("primvars:displayOpacity", np.array([opacity])) @property def points(self): diff --git a/omnigibson/prims/joint_prim.py b/omnigibson/prims/joint_prim.py index 92831a50b..901e978be 100644 --- a/omnigibson/prims/joint_prim.py +++ b/omnigibson/prims/joint_prim.py @@ -173,11 +173,12 @@ def update_handles(self): def set_control_type(self, control_type, kp=None, kd=None): """ - Sets the control type for this joint. + Sets the control type for this joint. Note that ControlType.NONE is equivalent to + ControlType.EFFORT with 0 kp / kd Args: control_type (ControlType): What type of control to use for this joint. - Valid options are: {ControlType.POSITION, ControlType.VELOCITY, ControlType.EFFORT} + Valid options are: {ControlType.POSITION, ControlType.VELOCITY, ControlType.EFFORT, ControlType.NONE} kp (None or float): If specified, sets the kp gain value for this joint. Should only be set if setting ControlType.POSITION kd (None or float): If specified, sets the kd gain value for this joint. Should only be set if @@ -194,7 +195,7 @@ def set_control_type(self, control_type, kp=None, kd=None): assert kp is None, "kp gain must not be specified for setting VELOCITY control!" assert kd is not None, "kd gain must be specified for setting VELOCITY control!" kp = 0.0 - else: # Efforts + else: # Efforts (or NONE -- equivalent) assert kp is None, "kp gain must not be specified for setting EFFORT control!" assert kd is None, "kd gain must not be specified for setting EFFORT control!" kp, kd = 0.0, 0.0 @@ -867,12 +868,11 @@ def keep_still(self): self.set_effort(th.zeros(self.n_dof)) def _dump_state(self): - pos, vel, effort = self.get_state() if self.articulated else (th.empty(0), th.empty(0), th.empty(0)) + pos, vel, _ = self.get_state() if self.articulated else (th.empty(0), th.empty(0), th.empty(0)) target_pos, target_vel = self.get_target() if self.articulated else (th.empty(0), th.empty(0)) return dict( pos=pos, vel=vel, - effort=effort, target_pos=target_pos, target_vel=target_vel, ) @@ -881,8 +881,6 @@ def _load_state(self, state): if self.articulated: self.set_pos(state["pos"], drive=False) self.set_vel(state["vel"], drive=False) - if self.driven: - self.set_effort(state["effort"]) if self._control_type == ControlType.POSITION: self.set_pos(state["target_pos"], drive=True) elif self._control_type == ControlType.VELOCITY: @@ -893,7 +891,6 @@ def serialize(self, state): [ state["pos"], state["vel"], - state["effort"], state["target_pos"], state["target_vel"], ] @@ -905,9 +902,8 @@ def deserialize(self, state): dict( pos=state[0 : self.n_dof], vel=state[self.n_dof : 2 * self.n_dof], - effort=state[2 * self.n_dof : 3 * self.n_dof], - target_pos=state[3 * self.n_dof : 4 * self.n_dof], - target_vel=state[4 * self.n_dof : 5 * self.n_dof], + target_pos=state[2 * self.n_dof : 3 * self.n_dof], + target_vel=state[3 * self.n_dof : 4 * self.n_dof], ), - 5 * self.n_dof, + 4 * self.n_dof, ) diff --git a/omnigibson/robots/a1.py b/omnigibson/robots/a1.py index b226bd3e5..1b3e233c2 100644 --- a/omnigibson/robots/a1.py +++ b/omnigibson/robots/a1.py @@ -149,14 +149,14 @@ def _create_discrete_action_space(self): raise ValueError("A1 does not support discrete actions!") @property - def controller_order(self): - return ["arm_{}".format(self.default_arm), "gripper_{}".format(self.default_arm)] + def _raw_controller_order(self): + return [f"arm_{self.default_arm}", f"gripper_{self.default_arm}"] @property def _default_controllers(self): controllers = super()._default_controllers - controllers["arm_{}".format(self.default_arm)] = "InverseKinematicsController" - controllers["gripper_{}".format(self.default_arm)] = "MultiFingerGripperController" + controllers[f"arm_{self.default_arm}"] = "InverseKinematicsController" + controllers[f"gripper_{self.default_arm}"] = "MultiFingerGripperController" return controllers @property diff --git a/omnigibson/robots/active_camera_robot.py b/omnigibson/robots/active_camera_robot.py index 66b92f191..cb56e4c77 100644 --- a/omnigibson/robots/active_camera_robot.py +++ b/omnigibson/robots/active_camera_robot.py @@ -52,7 +52,7 @@ def default_proprio_obs(self): return obs_keys + ["camera_qpos_sin", "camera_qpos_cos"] @property - def controller_order(self): + def _raw_controller_order(self): # By default, only camera is supported return ["camera"] diff --git a/omnigibson/robots/articulated_trunk_robot.py b/omnigibson/robots/articulated_trunk_robot.py index 628cd6cbd..bac1cc950 100644 --- a/omnigibson/robots/articulated_trunk_robot.py +++ b/omnigibson/robots/articulated_trunk_robot.py @@ -1,5 +1,3 @@ -from abc import abstractmethod - import torch as th from omnigibson.robots.manipulation_robot import ManipulationRobot @@ -26,111 +24,25 @@ class ArticulatedTrunkRobot(ManipulationRobot): values specified, but setting these individual kwargs will override them """ - def __init__( - self, - # Shared kwargs in hierarchy - name, - relative_prim_path=None, - scale=None, - visible=True, - visual_only=False, - self_collisions=True, - load_config=None, - fixed_base=False, - # Unique to USDObject hierarchy - abilities=None, - # Unique to ControllableObject hierarchy - control_freq=None, - controller_config=None, - action_type="continuous", - action_normalize=True, - reset_joint_pos=None, - # Unique to BaseRobot - obs_modalities=("rgb", "proprio"), - proprio_obs="default", - sensor_config=None, - # Unique to ManipulationRobot - grasping_mode="physical", - disable_grasp_handling=False, - # Unique to ArticulatedTrunkRobot - rigid_trunk=False, - default_trunk_offset=0.2, - **kwargs, - ): - """ - Args: - name (str): Name for the object. Names need to be unique per scene - relative_prim_path (str): Scene-local prim path of the Prim to encapsulate or create. - scale (None or float or 3-array): if specified, sets either the uniform (float) or x,y,z (3-array) scale - for this object. A single number corresponds to uniform scaling along the x,y,z axes, whereas a - 3-array specifies per-axis scaling. - visible (bool): whether to render this object or not in the stage - visual_only (bool): Whether this object should be visual only (and not collide with any other objects) - self_collisions (bool): Whether to enable self collisions for this object - load_config (None or dict): If specified, should contain keyword-mapped values that are relevant for - loading this prim at runtime. - abilities (None or dict): If specified, manually adds specific object states to this object. It should be - a dict in the form of {ability: {param: value}} containing object abilities and parameters to pass to - the object state instance constructor. - control_freq (float): control frequency (in Hz) at which to control the object. If set to be None, - we will automatically set the control frequency to be at the render frequency by default. - controller_config (None or dict): nested dictionary mapping controller name(s) to specific controller - configurations for this object. This will override any default values specified by this class. - action_type (str): one of {discrete, continuous} - what type of action space to use - action_normalize (bool): whether to normalize inputted actions. This will override any default values - specified by this class. - reset_joint_pos (None or n-array): if specified, should be the joint positions that the object should - be set to during a reset. If None (default), self._default_joint_pos will be used instead. - Note that _default_joint_pos are hardcoded & precomputed, and thus should not be modified by the user. - Set this value instead if you want to initialize the robot with a different rese joint position. - obs_modalities (str or list of str): Observation modalities to use for this robot. Default is ["rgb", "proprio"]. - Valid options are "all", or a list containing any subset of omnigibson.sensors.ALL_SENSOR_MODALITIES. - Note: If @sensor_config explicitly specifies `modalities` for a given sensor class, it will - override any values specified from @obs_modalities! - proprio_obs (str or list of str): proprioception observation key(s) to use for generating proprioceptive - observations. If str, should be exactly "default" -- this results in the default proprioception - observations being used, as defined by self.default_proprio_obs. See self._get_proprioception_dict - for valid key choices - sensor_config (None or dict): nested dictionary mapping sensor class name(s) to specific sensor - configurations for this object. This will override any default values specified by this class. - grasping_mode (str): One of {"physical", "assisted", "sticky"}. - If "physical", no assistive grasping will be applied (relies on contact friction + finger force). - If "assisted", will magnetize any object touching and within the gripper's fingers. - If "sticky", will magnetize any object touching the gripper's fingers. - disable_grasp_handling (bool): If True, will disable all grasp handling for this object. This means that - sticky and assisted grasp modes will not work unless the connection/release methodsare manually called. - rigid_trunk (bool): If True, will prevent the trunk from moving during execution. - default_trunk_offset (float): The default height of the robot's trunk - kwargs (dict): Additional keyword arguments that are used for other super() calls from subclasses, allowing - for flexible compositions of various object subclasses (e.g.: Robot is USDObject + ControllableObject). - """ - self.rigid_trunk = rigid_trunk - self.default_trunk_offset = default_trunk_offset - - # Run super init - super().__init__( - relative_prim_path=relative_prim_path, - name=name, - scale=scale, - visible=visible, - fixed_base=fixed_base, - visual_only=visual_only, - self_collisions=self_collisions, - load_config=load_config, - abilities=abilities, - control_freq=control_freq, - controller_config=controller_config, - action_type=action_type, - action_normalize=action_normalize, - reset_joint_pos=reset_joint_pos, - obs_modalities=obs_modalities, - proprio_obs=proprio_obs, - sensor_config=sensor_config, - grasping_mode=grasping_mode, - disable_grasp_handling=disable_grasp_handling, - **kwargs, + def get_control_dict(self): + # In addition to super method, add in trunk endpoint state + fcns = super().get_control_dict() + + # Add relevant trunk values + self._add_task_frame_control_dict( + fcns=fcns, task_name="trunk", link_name=self.joints[self.trunk_joint_names[-1]].body1.split("/")[-1] ) + return fcns + + @property + def trunk_links(self): + return [self.links[name] for name in self.trunk_link_names] + + @property + def trunk_link_names(self): + raise NotImplementedError + @property def trunk_joint_names(self): raise NotImplementedError("trunk_joint_names must be implemented in subclass") @@ -143,21 +55,110 @@ def trunk_control_idx(self): """ return th.tensor([list(self.joints.keys()).index(name) for name in self.trunk_joint_names]) + @property + def trunk_action_idx(self): + controller_idx = self.controller_order.index("trunk") + action_start_idx = sum([self.controllers[self.controller_order[i]].command_dim for i in range(controller_idx)]) + return th.arange(action_start_idx, action_start_idx + self.controllers["trunk"].command_dim) + + @property + def _default_controllers(self): + # Always call super first + controllers = super()._default_controllers + + # For best generalizability use, joint controller as default + controllers["trunk"] = "JointController" + + return controllers + + @property + def _default_trunk_ik_controller_config(self): + """ + Returns: + dict: Default controller config for an Inverse kinematics controller to control this robot's trunk + """ + return { + "name": "InverseKinematicsController", + "task_name": "trunk", + "control_freq": self._control_freq, + "reset_joint_pos": self.reset_joint_pos, + "control_limits": self.control_limits, + "dof_idx": self.trunk_control_idx, + "command_output_limits": ( + th.tensor([-0.2, -0.2, -0.2, -0.5, -0.5, -0.5]), + th.tensor([0.2, 0.2, 0.2, 0.5, 0.5, 0.5]), + ), + "mode": "pose_delta_ori", + "smoothing_filter_size": 2, + "workspace_pose_limiter": None, + } + + @property + def _default_trunk_osc_controller_config(self): + """ + Returns: + dict: Default controller config for an Operational Space controller to control this robot's trunk + """ + return { + "name": "OperationalSpaceController", + "task_name": "trunk", + "control_freq": self._control_freq, + "reset_joint_pos": self.reset_joint_pos, + "control_limits": self.control_limits, + "dof_idx": self.trunk_control_idx, + "command_output_limits": ( + th.tensor([-0.2, -0.2, -0.2, -0.5, -0.5, -0.5]), + th.tensor([0.2, 0.2, 0.2, 0.5, 0.5, 0.5]), + ), + "mode": "pose_delta_ori", + "workspace_pose_limiter": None, + } + + @property + def _default_trunk_joint_controller_config(self): + """ + Returns: + dict: Default base joint controller config to control this robot's base. Uses position + control by default. + """ + return { + "name": "JointController", + "control_freq": self._control_freq, + "motor_type": "position", + "control_limits": self.control_limits, + "dof_idx": self.trunk_control_idx, + "command_output_limits": None, + "use_delta_commands": True, + } + + @property + def _default_trunk_null_joint_controller_config(self): + """ + Returns: + dict: Default null joint controller config to control this robot's base i.e. dummy controller + """ + return { + "name": "NullJointController", + "control_freq": self._control_freq, + "motor_type": "position", + "control_limits": self.control_limits, + "dof_idx": self.trunk_control_idx, + "default_command": self.reset_joint_pos[self.trunk_control_idx], + "use_impedances": False, + } + @property def _default_controller_config(self): - # Grab defaults from super method first + # Always run super method first cfg = super()._default_controller_config - if self.rigid_trunk: - return cfg - - # Need to override joint idx being controlled to include trunk in default arm controller configs - for arm_cfg in cfg[f"arm_{self.default_arm}"].values(): - arm_control_idx = th.cat([self.trunk_control_idx, self.arm_control_idx[self.default_arm]]) - arm_cfg["dof_idx"] = arm_control_idx - # Need to modify the default joint positions also if this is a null joint controller - if arm_cfg["name"] == "NullJointController": - arm_cfg["default_command"] = self.reset_joint_pos[arm_control_idx] + # Add supported base controllers + cfg["trunk"] = { + self._default_trunk_joint_controller_config["name"]: self._default_trunk_joint_controller_config, + self._default_trunk_null_joint_controller_config["name"]: self._default_trunk_null_joint_controller_config, + self._default_trunk_ik_controller_config["name"]: self._default_trunk_ik_controller_config, + self._default_trunk_osc_controller_config["name"]: self._default_trunk_osc_controller_config, + } return cfg diff --git a/omnigibson/robots/behavior_robot.py b/omnigibson/robots/behavior_robot.py index 8c5369e16..b41aa0cf3 100644 --- a/omnigibson/robots/behavior_robot.py +++ b/omnigibson/robots/behavior_robot.py @@ -209,7 +209,7 @@ def _default_joint_pos(self): return th.zeros(self.n_joints) @property - def controller_order(self): + def _raw_controller_order(self): controllers = ["base", "camera"] for arm_name in self.arm_names: controllers += [f"arm_{arm_name}", f"gripper_{arm_name}"] diff --git a/omnigibson/robots/fetch.py b/omnigibson/robots/fetch.py index ec706bc36..0fd6496b2 100644 --- a/omnigibson/robots/fetch.py +++ b/omnigibson/robots/fetch.py @@ -50,9 +50,6 @@ def __init__( # Unique to ManipulationRobot grasping_mode="physical", disable_grasp_handling=False, - # Unique to ArticulatedTrunkRobot - rigid_trunk=False, - default_trunk_offset=0.2, # Unique to MobileManipulationRobot default_reset_mode="untuck", # Unique to UntuckedArmPoseRobot @@ -102,8 +99,6 @@ def __init__( If "sticky", will magnetize any object touching the gripper's fingers. disable_grasp_handling (bool): If True, will disable all grasp handling for this object. This means that sticky and assisted grasp modes will not work unless the connection/release methodsare manually called. - rigid_trunk (bool): If True, will prevent the trunk from moving during execution. - default_trunk_offset (float): The default height of the robot's trunk default_reset_mode (str): Default reset mode for the robot. Should be one of: {"tuck", "untuck"} If reset_joint_pos is not None, this will be ignored (since _default_joint_pos won't be used during initialization). default_arm_pose (str): Default pose for the robot arm. Should be one of: @@ -134,8 +129,6 @@ def __init__( sensor_config=sensor_config, grasping_mode=grasping_mode, disable_grasp_handling=disable_grasp_handling, - rigid_trunk=rigid_trunk, - default_trunk_offset=default_trunk_offset, default_reset_mode=default_reset_mode, default_arm_pose=default_arm_pose, **kwargs, @@ -166,7 +159,7 @@ def tucked_default_joint_pos(self): def untucked_default_joint_pos(self): pos = super().untucked_default_joint_pos pos[self.base_control_idx] = 0.0 - pos[self.trunk_control_idx] = 0.02 + self.default_trunk_offset + pos[self.trunk_control_idx] = 0.385 pos[self.camera_control_idx] = th.tensor([0.0, 0.45]) pos[self.gripper_control_idx[self.default_arm]] = th.tensor([0.05, 0.05]) # open gripper return pos @@ -207,9 +200,9 @@ def _create_discrete_action_space(self): raise ValueError("Fetch does not support discrete actions!") @property - def controller_order(self): + def _raw_controller_order(self): # Ordered by general robot kinematics chain - return ["base", "camera", "arm_{}".format(self.default_arm), "gripper_{}".format(self.default_arm)] + return ["base", "trunk", "camera", f"arm_{self.default_arm}", f"gripper_{self.default_arm}"] @property def _default_controllers(self): @@ -218,9 +211,10 @@ def _default_controllers(self): # We use multi finger gripper, differential drive, and IK controllers as default controllers["base"] = "DifferentialDriveController" + controllers["trunk"] = "JointController" controllers["camera"] = "JointController" - controllers["arm_{}".format(self.default_arm)] = "InverseKinematicsController" - controllers["gripper_{}".format(self.default_arm)] = "MultiFingerGripperController" + controllers[f"arm_{self.default_arm}"] = "InverseKinematicsController" + controllers[f"gripper_{self.default_arm}"] = "MultiFingerGripperController" return controllers diff --git a/omnigibson/robots/franka.py b/omnigibson/robots/franka.py index 18f55d49a..76211f691 100644 --- a/omnigibson/robots/franka.py +++ b/omnigibson/robots/franka.py @@ -213,14 +213,14 @@ def _create_discrete_action_space(self): raise ValueError("Franka does not support discrete actions!") @property - def controller_order(self): - return ["arm_{}".format(self.default_arm), "gripper_{}".format(self.default_arm)] + def _raw_controller_order(self): + return [f"arm_{self.default_arm}", f"gripper_{self.default_arm}"] @property def _default_controllers(self): controllers = super()._default_controllers - controllers["arm_{}".format(self.default_arm)] = "InverseKinematicsController" - controllers["gripper_{}".format(self.default_arm)] = "MultiFingerGripperController" + controllers[f"arm_{self.default_arm}"] = "InverseKinematicsController" + controllers[f"gripper_{self.default_arm}"] = "MultiFingerGripperController" return controllers @property diff --git a/omnigibson/robots/franka_mounted.py b/omnigibson/robots/franka_mounted.py index 89206f96b..566720833 100644 --- a/omnigibson/robots/franka_mounted.py +++ b/omnigibson/robots/franka_mounted.py @@ -11,14 +11,14 @@ class FrankaMounted(FrankaPanda): """ @property - def controller_order(self): - return ["arm_{}".format(self.default_arm), "gripper_{}".format(self.default_arm)] + def _raw_controller_order(self): + return [f"arm_{self.default_arm}", f"gripper_{self.default_arm}"] @property def _default_controllers(self): controllers = super()._default_controllers - controllers["arm_{}".format(self.default_arm)] = "InverseKinematicsController" - controllers["gripper_{}".format(self.default_arm)] = "MultiFingerGripperController" + controllers[f"arm_{self.default_arm}"] = "InverseKinematicsController" + controllers[f"gripper_{self.default_arm}"] = "MultiFingerGripperController" return controllers @property diff --git a/omnigibson/robots/holonomic_base_robot.py b/omnigibson/robots/holonomic_base_robot.py index da8d9ac2a..0796bb8af 100644 --- a/omnigibson/robots/holonomic_base_robot.py +++ b/omnigibson/robots/holonomic_base_robot.py @@ -282,11 +282,18 @@ def _postprocess_control(self, control, control_type): base_orn = self.base_footprint_link.get_position_orientation()[1] root_link_orn = self.root_link.get_position_orientation()[1] - cur_orn = T.mat2quat(T.quat2mat(root_link_orn).T @ T.quat2mat(base_orn)) + cur_orn_mat = T.quat2mat(root_link_orn).T @ T.quat2mat(base_orn) + cur_pose = th.zeros((2, 4, 4)) + cur_pose[:, :3, :3] = cur_orn_mat + cur_pose[:, 3, 3] = 1.0 + + local_pose = th.zeros((2, 4, 4)) + local_pose[:] = th.eye(4) + local_pose[:, :3, 3] = u_vec[self.base_idx].view(2, 3) # Rotate the linear and angular velocity to the desired frame - lin_vel_global, _ = T.pose_transform(th.zeros(3), cur_orn, u_vec[self.base_idx[:3]], th.tensor([0, 0, 0, 1])) - ang_vel_global, _ = T.pose_transform(th.zeros(3), cur_orn, u_vec[self.base_idx[3:]], th.tensor([0, 0, 0, 1])) + global_pose = cur_pose @ local_pose + lin_vel_global, ang_vel_global = global_pose[0, :3, 3], global_pose[1, :3, 3] u_vec[self.base_control_idx] = th.tensor([lin_vel_global[0], lin_vel_global[1], ang_vel_global[2]]) diff --git a/omnigibson/robots/locomotion_robot.py b/omnigibson/robots/locomotion_robot.py index b613d7ef4..3860d5564 100644 --- a/omnigibson/robots/locomotion_robot.py +++ b/omnigibson/robots/locomotion_robot.py @@ -56,7 +56,7 @@ def default_proprio_obs(self): return obs_keys + ["base_qpos_sin", "base_qpos_cos", "robot_lin_vel", "robot_ang_vel"] @property - def controller_order(self): + def _raw_controller_order(self): # By default, only base is supported return ["base"] @@ -184,6 +184,14 @@ def turn_right(self, delta=0.03): quat = quat_multiply((euler2quat(delta, 0, 0)), quat) self.set_position_orientation(orientation=quat) + @property + def base_links(self): + return [self.links[name] for name in self.base_link_names] + + @property + def base_link_names(self): + raise NotImplementedError + @property def base_action_idx(self): controller_idx = self.controller_order.index("base") diff --git a/omnigibson/robots/manipulation_robot.py b/omnigibson/robots/manipulation_robot.py index 527334b12..098049064 100644 --- a/omnigibson/robots/manipulation_robot.py +++ b/omnigibson/robots/manipulation_robot.py @@ -404,44 +404,10 @@ def get_control_dict(self): fcns = super().get_control_dict() for arm in self.arm_names: - self._add_arm_control_dict(fcns=fcns, arm=arm) + self._add_task_frame_control_dict(fcns=fcns, task_name=f"eef_{arm}", link_name=self.eef_link_names[arm]) return fcns - def _add_arm_control_dict(self, fcns, arm): - """ - Internally helper function to generate per-arm control dictionary entries. Needed because otherwise generated - functions inadvertently point to the same arm, if directly iterated in a for loop! - - Args: - fcns (CachedFunctions): Keyword-mapped control values for this object, mapping names to n-arrays. - arm (str): specific arm to generate necessary control dict entries for - """ - fcns[f"_eef_{arm}_pos_quat_relative"] = ( - lambda: ControllableObjectViewAPI.get_link_relative_position_orientation( - self.articulation_root_path, self.eef_link_names[arm] - ) - ) - fcns[f"eef_{arm}_pos_relative"] = lambda: fcns[f"_eef_{arm}_pos_quat_relative"][0] - fcns[f"eef_{arm}_quat_relative"] = lambda: fcns[f"_eef_{arm}_pos_quat_relative"][1] - fcns[f"eef_{arm}_lin_vel_relative"] = lambda: ControllableObjectViewAPI.get_link_relative_linear_velocity( - self.articulation_root_path, self.eef_link_names[arm] - ) - fcns[f"eef_{arm}_ang_vel_relative"] = lambda: ControllableObjectViewAPI.get_link_relative_angular_velocity( - self.articulation_root_path, self.eef_link_names[arm] - ) - # -n_joints because there may be an additional 6 entries at the beginning of the array, if this robot does - # not have a fixed base (i.e.: the 6DOF --> "floating" joint) - # see self.get_relative_jacobian() for more info - # We also count backwards for the link frame because if the robot is fixed base, the jacobian returned has one - # less index than the number of links. This is presumably because the 1st link of a fixed base robot will - # always have a zero jacobian since it can't move. Counting backwards resolves this issue. - start_idx = 0 if self.fixed_base else 6 - eef_link_idx = self._articulation_view.get_body_index(self.eef_links[arm].body_name) - fcns[f"eef_{arm}_jacobian_relative"] = lambda: ControllableObjectViewAPI.get_relative_jacobian( - self.articulation_root_path - )[-(self.n_links - eef_link_idx), :, start_idx : start_idx + self.n_joints] - def _get_proprioception_dict(self): dic = super()._get_proprioception_dict() @@ -492,7 +458,7 @@ def grasping_mode(self): return self._grasping_mode @property - def controller_order(self): + def _raw_controller_order(self): # Assumes we have arm(s) and corresponding gripper(s) controllers = [] for arm in self.arm_names: @@ -1037,9 +1003,6 @@ def _default_arm_ik_controller_configs(self): dic[arm] = { "name": "InverseKinematicsController", "task_name": f"eef_{arm}", - "robot_description_path": self.robot_arm_descriptor_yamls[arm], - "robot_urdf_path": self.urdf_path, - "eef_name": self.eef_link_names[arm], "control_freq": self._control_freq, "reset_joint_pos": self.reset_joint_pos, "control_limits": self.control_limits, diff --git a/omnigibson/robots/r1.py b/omnigibson/robots/r1.py index b687fd7d4..c5d6980ea 100644 --- a/omnigibson/robots/r1.py +++ b/omnigibson/robots/r1.py @@ -44,8 +44,6 @@ def __init__( # Unique to ManipulationRobot grasping_mode="physical", disable_grasp_handling=False, - # Unique to ArticulatedTrunkRobot - rigid_trunk=True, # Unique to MobileManipulationRobot default_reset_mode="untuck", **kwargs, @@ -92,7 +90,6 @@ def __init__( If "sticky", will magnetize any object touching the gripper's fingers. disable_grasp_handling (bool): If True, will disable all grasp handling for this object. This means that sticky and assisted grasp modes will not work unless the connection/release methodsare manually called. - rigid_trunk (bool): If True, will prevent the trunk from moving during execution. default_reset_mode (str): Default reset mode for the robot. Should be one of: {"tuck", "untuck"} If reset_joint_pos is not None, this will be ignored (since _default_joint_pos won't be used during initialization). kwargs (dict): Additional keyword arguments that are used for other super() calls from subclasses, allowing @@ -119,7 +116,6 @@ def __init__( sensor_config=sensor_config, grasping_mode=grasping_mode, disable_grasp_handling=disable_grasp_handling, - rigid_trunk=rigid_trunk, default_trunk_offset=0.0, # not applicable for R1 default_reset_mode=default_reset_mode, **kwargs, @@ -139,8 +135,8 @@ def _create_discrete_action_space(self): raise ValueError("R1 does not support discrete actions!") @property - def controller_order(self): - controllers = ["base"] + def _raw_controller_order(self): + controllers = ["base", "trunk"] for arm in self.arm_names: controllers += [f"arm_{arm}", f"gripper_{arm}"] return controllers @@ -150,6 +146,7 @@ def _default_controllers(self): controllers = super()._default_controllers # We use joint controllers for base as default controllers["base"] = "JointController" + controllers["trunk"] = "JointController" # We use IK and multi finger gripper controllers as default for arm in self.arm_names: controllers["arm_{}".format(arm)] = "InverseKinematicsController" @@ -196,6 +193,14 @@ def assisted_grasp_end_points(self): for arm in self.arm_names } + @property + def base_link_names(self): + return ["base_link"] # , "wheel_link1", "wheel_link2", "wheel_link3"] + + @property + def trunk_link_names(self): + return ["torso_link1", "torso_link2", "torso_link3", "torso_link4"] + @property def trunk_joint_names(self): return [f"torso_joint{i}" for i in range(1, 5)] @@ -210,7 +215,7 @@ def arm_names(cls): @property def arm_link_names(self): - return {arm: [f"{arm}_arm_link{i}" for i in range(1, 3)] for arm in self.arm_names} + return {arm: [f"{arm}_arm_link{i}" for i in range(1, 7)] for arm in self.arm_names} @property def arm_joint_names(self): diff --git a/omnigibson/robots/stretch.py b/omnigibson/robots/stretch.py index bb48eabc0..b2c5c38fc 100644 --- a/omnigibson/robots/stretch.py +++ b/omnigibson/robots/stretch.py @@ -39,7 +39,7 @@ def _create_discrete_action_space(self): raise ValueError("Stretch does not support discrete actions!") @property - def controller_order(self): + def _raw_controller_order(self): # Ordered by general robot kinematics chain return ["base", "camera", f"arm_{self.default_arm}", f"gripper_{self.default_arm}"] diff --git a/omnigibson/robots/tiago.py b/omnigibson/robots/tiago.py index 970d6c2fe..2bde8a748 100644 --- a/omnigibson/robots/tiago.py +++ b/omnigibson/robots/tiago.py @@ -48,9 +48,6 @@ def __init__( # Unique to ManipulationRobot grasping_mode="physical", disable_grasp_handling=False, - # Unique to ArticulatedTrunkRobot - rigid_trunk=False, - default_trunk_offset=0.2, # Unique to MobileManipulationRobot default_reset_mode="untuck", # Unique to UntuckedArmPoseRobot @@ -101,8 +98,6 @@ def __init__( If "sticky", will magnetize any object touching the gripper's fingers. disable_grasp_handling (bool): If True, will disable all grasp handling for this object. This means that sticky and assisted grasp modes will not work unless the connection/release methodsare manually called. - rigid_trunk (bool): If True, will prevent the trunk from moving during execution. - default_trunk_offset (float): The default height of the robot's trunk default_reset_mode (str): Default reset mode for the robot. Should be one of: {"tuck", "untuck"} If reset_joint_pos is not None, this will be ignored (since _default_joint_pos won't be used during initialization). default_arm_pose (str): Default pose for the robot arm. Should be one of: @@ -138,8 +133,6 @@ def __init__( sensor_config=sensor_config, grasping_mode=grasping_mode, disable_grasp_handling=disable_grasp_handling, - rigid_trunk=rigid_trunk, - default_trunk_offset=default_trunk_offset, default_reset_mode=default_reset_mode, default_arm_pose=default_arm_pose, **kwargs, @@ -177,7 +170,7 @@ def untucked_default_joint_pos(self): pos = super().untucked_default_joint_pos # Keep the current joint positions for the base joints pos[self.base_idx] = self.get_joint_positions()[self.base_idx] - pos[self.trunk_control_idx] = 0.02 + self.default_trunk_offset + pos[self.trunk_control_idx] = 0.17 pos[self.camera_control_idx] = th.tensor([0.0, -0.45]) for arm in self.arm_names: pos[self.gripper_control_idx[arm]] = th.tensor([0.045, 0.045]) # open gripper @@ -215,8 +208,8 @@ def base_footprint_link_name(self): return "base_footprint" @property - def controller_order(self): - controllers = ["base", "camera"] + def _raw_controller_order(self): + controllers = ["base", "trunk", "camera"] for arm in self.arm_names: controllers += ["arm_{}".format(arm), "gripper_{}".format(arm)] @@ -228,6 +221,7 @@ def _default_controllers(self): controllers = super()._default_controllers # We use joint controllers for base and camera as default controllers["base"] = "JointController" + controllers["trunk"] = "JointController" controllers["camera"] = "JointController" # We use multi finger gripper, and IK controllers for eefs as default for arm in self.arm_names: diff --git a/omnigibson/robots/vx300s.py b/omnigibson/robots/vx300s.py index a27ea44f2..4dabc1ba4 100644 --- a/omnigibson/robots/vx300s.py +++ b/omnigibson/robots/vx300s.py @@ -115,7 +115,7 @@ def _create_discrete_action_space(self): raise ValueError("VX300S does not support discrete actions!") @property - def controller_order(self): + def _raw_controller_order(self): return [f"arm_{self.default_arm}", f"gripper_{self.default_arm}"] @property diff --git a/omnigibson/utils/python_utils.py b/omnigibson/utils/python_utils.py index 6d654d374..301d68191 100644 --- a/omnigibson/utils/python_utils.py +++ b/omnigibson/utils/python_utils.py @@ -581,6 +581,7 @@ class CachedFunctions: def __init__(self, **kwargs): # Create internal dict to store functions self._fcns = dict() + self._cache = dict() for kwarg in kwargs: self._fcns[kwarg] = kwargs[kwarg] @@ -590,20 +591,20 @@ def __getitem__(self, item): def __setitem__(self, key, value): self.add_fcn(name=key, fcn=value) - def get(self, name, *args, **kwargs): + def get(self, name): """ Computes the function referenced by @name with the corresponding @args and @kwargs. Note that for a unique set of arguments, this value will be internally cached Args: name (str): The name of the function to call - *args (tuple): Positional arguments to pass into the function call - **kwargs (tuple): Keyword arguments to pass into the function call Returns: any: Output of the function referenced by @name """ - return self._fcns[name](*args, **kwargs) + if name not in self._cache: + self._cache[name] = self._fcns[name]() + return self._cache[name] def get_fcn(self, name): """ diff --git a/omnigibson/utils/transform_utils.py b/omnigibson/utils/transform_utils.py index 3503554ee..2bbdfc598 100644 --- a/omnigibson/utils/transform_utils.py +++ b/omnigibson/utils/transform_utils.py @@ -7,10 +7,10 @@ import math from typing import List, Optional, Tuple -import torch as th +import torch PI = math.pi -EPS = th.finfo(th.float32).eps * 4.0 +EPS = torch.finfo(torch.float32).eps * 4.0 # map axes strings to/from tuples of inner axis, parity, repetition, frame _AXES2TUPLE = { @@ -41,27 +41,27 @@ } -@th.jit.script +@torch.compile def copysign(a, b): - # type: (float, th.Tensor) -> th.Tensor - a = th.tensor(a, device=b.device, dtype=th.float).repeat(b.shape[0]) - return th.abs(a) * th.sign(b) + # type: (float, torch.Tensor) -> torch.Tensor + a = torch.tensor(a, device=b.device, dtype=torch.float).repeat(b.shape[0]) + return torch.abs(a) * torch.sign(b) -@th.jit.script -def anorm(x: th.Tensor, dim: Optional[int] = None, keepdim: bool = False) -> th.Tensor: +@torch.compile +def anorm(x: torch.Tensor, dim: Optional[int] = None, keepdim: bool = False) -> torch.Tensor: """Compute L2 norms along specified axes.""" - return th.norm(x, dim=dim, keepdim=keepdim) + return torch.norm(x, dim=dim, keepdim=keepdim) -@th.jit.script -def normalize(v: th.Tensor, dim: Optional[int] = None, eps: float = 1e-10) -> th.Tensor: +@torch.compile +def normalize(v: torch.Tensor, dim: Optional[int] = None, eps: float = 1e-10) -> torch.Tensor: """L2 Normalize along specified axes.""" norm = anorm(v, dim=dim, keepdim=True) - return v / th.where(norm < eps, th.full_like(norm, eps), norm) + return v / torch.where(norm < eps, torch.full_like(norm, eps), norm) -@th.jit.script +@torch.compile def dot(v1, v2, dim=-1, keepdim=False): """ Computes dot product between two vectors along the provided dim, optionally keeping the dimension @@ -76,30 +76,30 @@ def dot(v1, v2, dim=-1, keepdim=False): tensor: (..., [1,] ...) dot product of vectors, with optional dimension kept if @keepdim is True """ # type: (Tensor, Tensor, int, bool) -> Tensor - return th.sum(v1 * v2, dim=dim, keepdim=keepdim) + return torch.sum(v1 * v2, dim=dim, keepdim=keepdim) -@th.jit.script -def unit_vector(data: th.Tensor, dim: Optional[int] = None, out: Optional[th.Tensor] = None) -> th.Tensor: +@torch.compile +def unit_vector(data: torch.Tensor, dim: Optional[int] = None, out: Optional[torch.Tensor] = None) -> torch.Tensor: """ Returns tensor normalized by length, i.e. Euclidean norm, along axis. Args: - data (th.Tensor): data to normalize + data (torch.Tensor): data to normalize dim (Optional[int]): If specified, determines specific dimension along data to normalize - out (Optional[th.Tensor]): If specified, will store computation in this variable + out (Optional[torch.Tensor]): If specified, will store computation in this variable Returns: - th.Tensor: Normalized vector + torch.Tensor: Normalized vector """ if out is None: - if not isinstance(data, th.Tensor): - data = th.tensor(data, dtype=th.float32) + if not isinstance(data, torch.Tensor): + data = torch.tensor(data, dtype=torch.float32) else: - data = data.clone().to(th.float32) + data = data.clone().to(torch.float32) if data.ndim == 1: - return data / th.sqrt(th.dot(data, data)) + return data / torch.sqrt(torch.dot(data, data)) else: if out is not data: out.copy_(data) @@ -108,21 +108,21 @@ def unit_vector(data: th.Tensor, dim: Optional[int] = None, out: Optional[th.Ten if dim is None: dim = -1 - length = th.sum(data * data, dim=dim, keepdim=True).sqrt() + length = torch.sum(data * data, dim=dim, keepdim=True).sqrt() data = data / (length + 1e-8) # Add small epsilon to avoid division by zero return data -@th.jit.script -def quat_apply(quat: th.Tensor, vec: th.Tensor) -> th.Tensor: +@torch.compile +def quat_apply(quat: torch.Tensor, vec: torch.Tensor) -> torch.Tensor: """ Apply a quaternion rotation to a vector (equivalent to R.from_quat(x).apply(y)) Args: - quat (th.Tensor): (4,) or (N, 4) or (N, 1, 4) quaternion in (x, y, z, w) format - vec (th.Tensor): (3,) or (M, 3) or (1, M, 3) vector to rotate + quat (torch.Tensor): (4,) or (N, 4) or (N, 1, 4) quaternion in (x, y, z, w) format + vec (torch.Tensor): (3,) or (M, 3) or (1, M, 3) vector to rotate Returns: - th.Tensor: (M, 3) or (N, M, 3) rotated vector + torch.Tensor: (M, 3) or (N, M, 3) rotated vector """ assert quat.shape[-1] == 4, "Quaternion must have 4 components in last dimension" assert vec.shape[-1] == 3, "Vector must have 3 components in last dimension" @@ -143,7 +143,7 @@ def quat_apply(quat: th.Tensor, vec: th.Tensor) -> th.Tensor: qx, qy, qz, qw = quat.unbind(-1) # Compute the quaternion multiplication - t = th.stack( + t = torch.stack( [ 2 * (qy * vec[..., 2] - qz * vec[..., 1]), 2 * (qz * vec[..., 0] - qx * vec[..., 2]), @@ -153,50 +153,50 @@ def quat_apply(quat: th.Tensor, vec: th.Tensor) -> th.Tensor: ) # Compute the final rotated vector - result = vec + qw.unsqueeze(-1) * t + th.cross(quat[..., :3], t, dim=-1) + result = vec + qw.unsqueeze(-1) * t + torch.cross(quat[..., :3], t, dim=-1) # Remove any extra dimensions return result.squeeze() -@th.jit.script -def convert_quat(q: th.Tensor, to: str = "xyzw") -> th.Tensor: +@torch.compile +def convert_quat(q: torch.Tensor, to: str = "xyzw") -> torch.Tensor: """ Converts quaternion from one convention to another. The convention to convert TO is specified as an optional argument. If to == 'xyzw', then the input is in 'wxyz' format, and vice-versa. Args: - q (th.Tensor): a 4-dim array corresponding to a quaternion + q (torch.Tensor): a 4-dim array corresponding to a quaternion to (str): either 'xyzw' or 'wxyz', determining which convention to convert to. Returns: - th.Tensor: The converted quaternion + torch.Tensor: The converted quaternion """ if to == "xyzw": - return th.stack([q[1], q[2], q[3], q[0]], dim=0) + return torch.stack([q[1], q[2], q[3], q[0]], dim=0) elif to == "wxyz": - return th.stack([q[3], q[0], q[1], q[2]], dim=0) + return torch.stack([q[3], q[0], q[1], q[2]], dim=0) else: raise ValueError("convert_quat: choose a valid `to` argument (xyzw or wxyz)") -@th.jit.script -def quat_multiply(quaternion1: th.Tensor, quaternion0: th.Tensor) -> th.Tensor: +@torch.compile +def quat_multiply(quaternion1: torch.Tensor, quaternion0: torch.Tensor) -> torch.Tensor: """ Return multiplication of two quaternions (q1 * q0). Args: - quaternion1 (th.Tensor): (x,y,z,w) quaternion - quaternion0 (th.Tensor): (x,y,z,w) quaternion + quaternion1 (torch.Tensor): (x,y,z,w) quaternion + quaternion0 (torch.Tensor): (x,y,z,w) quaternion Returns: - th.Tensor: (x,y,z,w) multiplied quaternion + torch.Tensor: (x,y,z,w) multiplied quaternion """ x0, y0, z0, w0 = quaternion0[0], quaternion0[1], quaternion0[2], quaternion0[3] x1, y1, z1, w1 = quaternion1[0], quaternion1[1], quaternion1[2], quaternion1[3] - return th.stack( + return torch.stack( [ x1 * w0 + y1 * z0 - z1 * y0 + w1 * x0, -x1 * z0 + y1 * w0 + z1 * x0 + w1 * y0, @@ -207,56 +207,56 @@ def quat_multiply(quaternion1: th.Tensor, quaternion0: th.Tensor) -> th.Tensor: ) -@th.jit.script -def quat_conjugate(quaternion: th.Tensor) -> th.Tensor: +@torch.compile +def quat_conjugate(quaternion: torch.Tensor) -> torch.Tensor: """ Return conjugate of quaternion. Args: - quaternion (th.Tensor): (x,y,z,w) quaternion + quaternion (torch.Tensor): (x,y,z,w) quaternion Returns: - th.Tensor: (x,y,z,w) quaternion conjugate + torch.Tensor: (x,y,z,w) quaternion conjugate """ - return th.cat([-quaternion[:3], quaternion[3:]]) + return torch.cat([-quaternion[:3], quaternion[3:]]) -@th.jit.script -def quat_inverse(quaternion: th.Tensor) -> th.Tensor: +@torch.compile +def quat_inverse(quaternion: torch.Tensor) -> torch.Tensor: """ Return inverse of quaternion. E.g.: >>> q0 = random_quaternion() >>> q1 = quat_inverse(q0) - >>> th.allclose(quat_multiply(q0, q1), [0, 0, 0, 1]) + >>> torch.allclose(quat_multiply(q0, q1), [0, 0, 0, 1]) True Args: - quaternion (th.tensor): (x,y,z,w) quaternion + quaternion (torch.tensor): (x,y,z,w) quaternion Returns: - th.tensor: (x,y,z,w) quaternion inverse + torch.tensor: (x,y,z,w) quaternion inverse """ - return quat_conjugate(quaternion) / th.dot(quaternion, quaternion) + return quat_conjugate(quaternion) / torch.dot(quaternion, quaternion) -@th.jit.script +@torch.compile def quat_distance(quaternion1, quaternion0): """ Returns distance between two quaternions, such that distance * quaternion0 = quaternion1 Args: - quaternion1 (th.tensor): (x,y,z,w) quaternion - quaternion0 (th.tensor): (x,y,z,w) quaternion + quaternion1 (torch.tensor): (x,y,z,w) quaternion + quaternion0 (torch.tensor): (x,y,z,w) quaternion Returns: - th.tensor: (x,y,z,w) quaternion distance + torch.tensor: (x,y,z,w) quaternion distance """ return quat_multiply(quaternion1, quat_inverse(quaternion0)) -@th.jit.script +@torch.compile def quat_slerp(quat0, quat1, frac, shortestpath=True, eps=1.0e-15): """ Return spherical linear interpolation between two quaternions. @@ -282,22 +282,22 @@ def quat_slerp(quat0, quat1, frac, shortestpath=True, eps=1.0e-15): d = dot(quat0, quat1, dim=-1, keepdim=True) if shortestpath: - quat1 = th.where(d < 0.0, -quat1, quat1) - d = th.abs(d) - angle = th.acos(th.clip(d, -1.0, 1.0)) + quat1 = torch.where(d < 0.0, -quat1, quat1) + d = torch.abs(d) + angle = torch.acos(torch.clip(d, -1.0, 1.0)) # Check for small quantities (i.e.: q0 = q1) - where_small_diff = th.abs(th.abs(d) - 1.0) < eps + where_small_diff = torch.abs(torch.abs(d) - 1.0) < eps where_small_angle = abs(angle) < eps - isin = 1.0 / th.sin(angle) - val = quat0 * th.sin((1.0 - frac) * angle) * isin + quat1 * th.sin(frac * angle) * isin + isin = 1.0 / torch.sin(angle) + val = quat0 * torch.sin((1.0 - frac) * angle) * isin + quat1 * torch.sin(frac * angle) * isin # Filter edge cases - val = th.where( + val = torch.where( where_small_diff | where_small_angle | where_start, quat0, - th.where( + torch.where( where_end, quat1, val, @@ -308,7 +308,7 @@ def quat_slerp(quat0, quat1, frac, shortestpath=True, eps=1.0e-15): return val.reshape(list(quat_shape)) -@th.jit.script +@torch.compile def random_axis_angle(angle_limit: float = 2.0 * math.pi): """ Samples an axis-angle rotation by first sampling a random axis @@ -324,13 +324,13 @@ def random_axis_angle(angle_limit: float = 2.0 * math.pi): # sample random axis using a normalized sample from spherical Gaussian. # see (http://extremelearning.com.au/how-to-generate-uniformly-random-points-on-n-spheres-and-n-balls/) # for why it works. - random_axis = th.randn(3) - random_axis /= th.norm(random_axis) - random_angle = th.rand(1) * angle_limit + random_axis = torch.randn(3) + random_axis /= torch.norm(random_axis) + random_angle = torch.rand(1) * angle_limit return random_axis, random_angle.item() -@th.jit.script +@torch.compile def quat2mat(quaternion): """ Convert quaternions into rotation matrices. @@ -341,18 +341,22 @@ def quat2mat(quaternion): Returns: torch.Tensor: A tensor of shape (..., 3, 3) representing batches of rotation matrices. """ - quaternion = quaternion / th.norm(quaternion, dim=-1, keepdim=True) + quaternion = quaternion / torch.norm(quaternion, dim=-1, keepdim=True) - x = quaternion[..., 0] - y = quaternion[..., 1] - z = quaternion[..., 2] - w = quaternion[..., 3] + outer = quaternion.unsqueeze(-1) * quaternion.unsqueeze(-2) - xx, yy, zz = x * x, y * y, z * z - xy, xz, yz = x * y, x * z, y * z - xw, yw, zw = x * w, y * w, z * w + # Extract the necessary components + xx = outer[..., 0, 0] + yy = outer[..., 1, 1] + zz = outer[..., 2, 2] + xy = outer[..., 0, 1] + xz = outer[..., 0, 2] + yz = outer[..., 1, 2] + xw = outer[..., 0, 3] + yw = outer[..., 1, 3] + zw = outer[..., 2, 3] - rotation_matrix = th.empty(quaternion.shape[:-1] + (3, 3), dtype=quaternion.dtype, device=quaternion.device) + rotation_matrix = torch.empty(quaternion.shape[:-1] + (3, 3), dtype=quaternion.dtype, device=quaternion.device) rotation_matrix[..., 0, 0] = 1 - 2 * (yy + zz) rotation_matrix[..., 0, 1] = 2 * (xy - zw) @@ -369,14 +373,14 @@ def quat2mat(quaternion): return rotation_matrix -@th.jit.script -def mat2quat(rmat: th.Tensor) -> th.Tensor: +@torch.compile +def mat2quat(rmat: torch.Tensor) -> torch.Tensor: """ Converts given rotation matrix to quaternion. Args: - rmat (th.Tensor): (3, 3) or (..., 3, 3) rotation matrix + rmat (torch.Tensor): (3, 3) or (..., 3, 3) rotation matrix Returns: - th.Tensor: (4,) or (..., 4) (x,y,z,w) float quaternion angles + torch.Tensor: (4,) or (..., 4) (x,y,z,w) float quaternion angles """ # Check if input is a single matrix or a batch is_single = rmat.dim() == 2 @@ -398,37 +402,37 @@ def mat2quat(rmat: th.Tensor) -> th.Tensor: cond3 = ~(trace_positive | cond1 | cond2) # Trace positive condition - sq = th.where(trace_positive, th.sqrt(trace + 1.0) * 2.0, th.zeros_like(trace)) - qw = th.where(trace_positive, 0.25 * sq, th.zeros_like(trace)) - qx = th.where(trace_positive, (m21 - m12) / sq, th.zeros_like(trace)) - qy = th.where(trace_positive, (m02 - m20) / sq, th.zeros_like(trace)) - qz = th.where(trace_positive, (m10 - m01) / sq, th.zeros_like(trace)) + sq = torch.where(trace_positive, torch.sqrt(trace + 1.0) * 2.0, torch.zeros_like(trace)) + qw = torch.where(trace_positive, 0.25 * sq, torch.zeros_like(trace)) + qx = torch.where(trace_positive, (m21 - m12) / sq, torch.zeros_like(trace)) + qy = torch.where(trace_positive, (m02 - m20) / sq, torch.zeros_like(trace)) + qz = torch.where(trace_positive, (m10 - m01) / sq, torch.zeros_like(trace)) # Condition 1 - sq = th.where(cond1, th.sqrt(1.0 + m00 - m11 - m22) * 2.0, sq) - qw = th.where(cond1, (m21 - m12) / sq, qw) - qx = th.where(cond1, 0.25 * sq, qx) - qy = th.where(cond1, (m01 + m10) / sq, qy) - qz = th.where(cond1, (m02 + m20) / sq, qz) + sq = torch.where(cond1, torch.sqrt(1.0 + m00 - m11 - m22) * 2.0, sq) + qw = torch.where(cond1, (m21 - m12) / sq, qw) + qx = torch.where(cond1, 0.25 * sq, qx) + qy = torch.where(cond1, (m01 + m10) / sq, qy) + qz = torch.where(cond1, (m02 + m20) / sq, qz) # Condition 2 - sq = th.where(cond2, th.sqrt(1.0 + m11 - m00 - m22) * 2.0, sq) - qw = th.where(cond2, (m02 - m20) / sq, qw) - qx = th.where(cond2, (m01 + m10) / sq, qx) - qy = th.where(cond2, 0.25 * sq, qy) - qz = th.where(cond2, (m12 + m21) / sq, qz) + sq = torch.where(cond2, torch.sqrt(1.0 + m11 - m00 - m22) * 2.0, sq) + qw = torch.where(cond2, (m02 - m20) / sq, qw) + qx = torch.where(cond2, (m01 + m10) / sq, qx) + qy = torch.where(cond2, 0.25 * sq, qy) + qz = torch.where(cond2, (m12 + m21) / sq, qz) # Condition 3 - sq = th.where(cond3, th.sqrt(1.0 + m22 - m00 - m11) * 2.0, sq) - qw = th.where(cond3, (m10 - m01) / sq, qw) - qx = th.where(cond3, (m02 + m20) / sq, qx) - qy = th.where(cond3, (m12 + m21) / sq, qy) - qz = th.where(cond3, 0.25 * sq, qz) + sq = torch.where(cond3, torch.sqrt(1.0 + m22 - m00 - m11) * 2.0, sq) + qw = torch.where(cond3, (m10 - m01) / sq, qw) + qx = torch.where(cond3, (m02 + m20) / sq, qx) + qy = torch.where(cond3, (m12 + m21) / sq, qy) + qz = torch.where(cond3, 0.25 * sq, qz) - quat = th.stack([qx, qy, qz, qw], dim=-1) + quat = torch.stack([qx, qy, qz, qw], dim=-1) # Normalize the quaternion - quat = quat / th.norm(quat, dim=-1, keepdim=True) + quat = quat / torch.norm(quat, dim=-1, keepdim=True) # Reshape to match input batch shape quat = quat.reshape(batch_shape + (4,)) @@ -440,36 +444,36 @@ def mat2quat(rmat: th.Tensor) -> th.Tensor: return quat -@th.jit.script +@torch.compile def mat2pose(hmat): """ Converts a homogeneous 4x4 matrix into pose. Args: - hmat (th.tensor): a 4x4 homogeneous matrix + hmat (torch.tensor): a 4x4 homogeneous matrix Returns: 2-tuple: - - (th.tensor) (x,y,z) position array in cartesian coordinates - - (th.tensor) (x,y,z,w) orientation array in quaternion form + - (torch.tensor) (x,y,z) position array in cartesian coordinates + - (torch.tensor) (x,y,z,w) orientation array in quaternion form """ - assert th.allclose(hmat[:3, :3].det(), th.tensor(1.0)), "Rotation matrix must not be scaled" + assert torch.allclose(hmat[:3, :3].det(), torch.tensor(1.0)), "Rotation matrix must not be scaled" pos = hmat[:3, 3] orn = mat2quat(hmat[:3, :3]) return pos, orn -@th.jit.script -def vec2quat(vec: th.Tensor, up: th.Tensor = th.tensor([0.0, 0.0, 1.0])) -> th.Tensor: +@torch.compile +def vec2quat(vec: torch.Tensor, up: torch.Tensor = torch.tensor([0.0, 0.0, 1.0])) -> torch.Tensor: """ Converts given 3d-direction vector @vec to quaternion orientation with respect to another direction vector @up Args: - vec (th.Tensor): (x,y,z) direction vector (possibly non-normalized) - up (th.Tensor): (x,y,z) direction vector representing the canonical up direction (possibly non-normalized) + vec (torch.Tensor): (x,y,z) direction vector (possibly non-normalized) + up (torch.Tensor): (x,y,z) direction vector representing the canonical up direction (possibly non-normalized) Returns: - th.Tensor: (x,y,z,w) quaternion + torch.Tensor: (x,y,z,w) quaternion """ # Ensure inputs are 2D if vec.dim() == 1: @@ -477,27 +481,27 @@ def vec2quat(vec: th.Tensor, up: th.Tensor = th.tensor([0.0, 0.0, 1.0])) -> th.T if up.dim() == 1: up = up.unsqueeze(0) - vec_n = th.nn.functional.normalize(vec, dim=-1) - up_n = th.nn.functional.normalize(up, dim=-1) + vec_n = torch.nn.functional.normalize(vec, dim=-1) + up_n = torch.nn.functional.normalize(up, dim=-1) - s_n = th.cross(up_n, vec_n, dim=-1) - u_n = th.cross(vec_n, s_n, dim=-1) + s_n = torch.cross(up_n, vec_n, dim=-1) + u_n = torch.cross(vec_n, s_n, dim=-1) - rotation_matrix = th.stack([vec_n, s_n, u_n], dim=-1) + rotation_matrix = torch.stack([vec_n, s_n, u_n], dim=-1) return mat2quat(rotation_matrix) -@th.jit.script -def euler2quat(euler: th.Tensor) -> th.Tensor: +@torch.compile +def euler2quat(euler: torch.Tensor) -> torch.Tensor: """ Converts euler angles into quaternion form Args: - euler (th.Tensor): (..., 3) (r,p,y) angles + euler (torch.Tensor): (..., 3) (r,p,y) angles Returns: - th.Tensor: (..., 4) (x,y,z,w) float quaternion angles + torch.Tensor: (..., 4) (x,y,z,w) float quaternion angles Raises: AssertionError: [Invalid input shape] @@ -508,12 +512,12 @@ def euler2quat(euler: th.Tensor) -> th.Tensor: roll, pitch, yaw = euler.unbind(-1) # Compute sines and cosines of half angles - cy = th.cos(yaw * 0.5) - sy = th.sin(yaw * 0.5) - cr = th.cos(roll * 0.5) - sr = th.sin(roll * 0.5) - cp = th.cos(pitch * 0.5) - sp = th.sin(pitch * 0.5) + cy = torch.cos(yaw * 0.5) + sy = torch.sin(yaw * 0.5) + cr = torch.cos(roll * 0.5) + sr = torch.sin(roll * 0.5) + cp = torch.cos(pitch * 0.5) + sp = torch.sin(pitch * 0.5) # Compute quaternion components qw = cy * cr * cp + sy * sr * sp @@ -522,10 +526,10 @@ def euler2quat(euler: th.Tensor) -> th.Tensor: qz = sy * cr * cp - cy * sr * sp # Stack and return - return th.stack([qx, qy, qz, qw], dim=-1) + return torch.stack([qx, qy, qz, qw], dim=-1) -@th.jit.script +@torch.compile def quat2euler(q): single_dim = q.dim() == 1 @@ -537,16 +541,16 @@ def quat2euler(q): # roll (x-axis rotation) sinr_cosp = 2.0 * (q[:, qw] * q[:, qx] + q[:, qy] * q[:, qz]) cosr_cosp = q[:, qw] * q[:, qw] - q[:, qx] * q[:, qx] - q[:, qy] * q[:, qy] + q[:, qz] * q[:, qz] - roll = th.atan2(sinr_cosp, cosr_cosp) + roll = torch.atan2(sinr_cosp, cosr_cosp) # pitch (y-axis rotation) sinp = 2.0 * (q[:, qw] * q[:, qy] - q[:, qz] * q[:, qx]) - pitch = th.where(th.abs(sinp) >= 1, copysign(math.pi / 2.0, sinp), th.asin(sinp)) + pitch = torch.where(torch.abs(sinp) >= 1, copysign(math.pi / 2.0, sinp), torch.asin(sinp)) # yaw (z-axis rotation) siny_cosp = 2.0 * (q[:, qw] * q[:, qz] + q[:, qx] * q[:, qy]) cosy_cosp = q[:, qw] * q[:, qw] + q[:, qx] * q[:, qx] - q[:, qy] * q[:, qy] - q[:, qz] * q[:, qz] - yaw = th.atan2(siny_cosp, cosy_cosp) + yaw = torch.atan2(siny_cosp, cosy_cosp) - euler = th.stack([roll, pitch, yaw], dim=-1) % (2 * math.pi) + euler = torch.stack([roll, pitch, yaw], dim=-1) % (2 * math.pi) euler[euler > math.pi] -= 2 * math.pi if single_dim: @@ -555,21 +559,21 @@ def quat2euler(q): return euler -@th.jit.script +@torch.compile def euler2mat(euler): """ Converts euler angles into rotation matrix form Args: - euler (th.tensor): (r,p,y) angles + euler (torch.tensor): (r,p,y) angles Returns: - th.tensor: 3x3 rotation matrix + torch.tensor: 3x3 rotation matrix Raises: AssertionError: [Invalid input shape] """ - euler = th.as_tensor(euler, dtype=th.float32) + euler = torch.as_tensor(euler, dtype=torch.float32) assert euler.shape[-1] == 3, f"Invalid shaped euler {euler}" # Convert Euler angles to quaternion @@ -579,18 +583,18 @@ def euler2mat(euler): return quat2mat(quat) -@th.jit.script +@torch.compile def mat2euler(rmat): """ Converts given rotation matrix to euler angles in radian. Args: - rmat (th.tensor): 3x3 rotation matrix + rmat (torch.tensor): 3x3 rotation matrix Returns: - th.tensor: (r,p,y) converted euler angles in radian vec3 float + torch.tensor: (r,p,y) converted euler angles in radian vec3 float """ - M = th.as_tensor(rmat, dtype=th.float32)[:3, :3] + M = torch.as_tensor(rmat, dtype=torch.float32)[:3, :3] # Convert rotation matrix to quaternion # Note: You'll need to implement mat2quat function @@ -602,25 +606,25 @@ def mat2euler(rmat): pitch = euler[..., 1] yaw = euler[..., 2] - return th.stack([roll, pitch, yaw], dim=-1) + return torch.stack([roll, pitch, yaw], dim=-1) -@th.jit.script -def pose2mat(pose: Tuple[th.Tensor, th.Tensor]) -> th.Tensor: +@torch.compile +def pose2mat(pose: Tuple[torch.Tensor, torch.Tensor]) -> torch.Tensor: pos, orn = pose # Ensure pos and orn are the expected shape and dtype - pos = pos.to(dtype=th.float32).reshape(3) - orn = orn.to(dtype=th.float32).reshape(4) + pos = pos.to(dtype=torch.float32).reshape(3) + orn = orn.to(dtype=torch.float32).reshape(4) - homo_pose_mat = th.eye(4, dtype=th.float32) + homo_pose_mat = torch.eye(4, dtype=torch.float32) homo_pose_mat[:3, :3] = quat2mat(orn) homo_pose_mat[:3, 3] = pos return homo_pose_mat -@th.jit.script +@torch.compile def quat2axisangle(quat): """ Converts quaternion to axis-angle format. @@ -634,15 +638,15 @@ def quat2axisangle(quat): quat_shape = quat.shape[:-1] # ignore last dim quat = quat.reshape(-1, 4) # clip quaternion - quat[:, 3] = th.clip(quat[:, 3], -1.0, 1.0) + quat[:, 3] = torch.clip(quat[:, 3], -1.0, 1.0) # Calculate denominator - den = th.sqrt(1.0 - quat[:, 3] * quat[:, 3]) + den = torch.sqrt(1.0 - quat[:, 3] * quat[:, 3]) # Map this into a mask # Create return array - ret = th.zeros_like(quat)[:, :3] - idx = th.nonzero(den).reshape(-1) - ret[idx, :] = (quat[idx, :3] * 2.0 * th.acos(quat[idx, 3]).unsqueeze(-1)) / den[idx].unsqueeze(-1) + ret = torch.zeros_like(quat)[:, :3] + idx = torch.nonzero(den).reshape(-1) + ret[idx, :] = (quat[idx, :3] * 2.0 * torch.acos(quat[idx, 3]).unsqueeze(-1)) / den[idx].unsqueeze(-1) # Reshape and return output ret = ret.reshape( @@ -654,7 +658,7 @@ def quat2axisangle(quat): return ret -@th.jit.script +@torch.compile def axisangle2quat(vec, eps=1e-6): """ Converts scaled axis-angle to quat. @@ -671,16 +675,16 @@ def axisangle2quat(vec, eps=1e-6): vec = vec.reshape(-1, 3) # Grab angle - angle = th.norm(vec, dim=-1, keepdim=True) + angle = torch.norm(vec, dim=-1, keepdim=True) # Create return array - quat = th.zeros(th.prod(th.tensor(input_shape, dtype=th.int)), 4, device=vec.device) + quat = torch.zeros(torch.prod(torch.tensor(input_shape, dtype=torch.int)), 4, device=vec.device) quat[:, 3] = 1.0 # Grab indexes where angle is not zero an convert the input to its quaternion form - idx = angle.reshape(-1) > eps # th.nonzero(angle).reshape(-1) - quat[idx, :] = th.cat( - [vec[idx, :] * th.sin(angle[idx, :] / 2.0) / angle[idx, :], th.cos(angle[idx, :] / 2.0)], dim=-1 + idx = angle.reshape(-1) > eps # torch.nonzero(angle).reshape(-1) + quat[idx, :] = torch.cat( + [vec[idx, :] * torch.sin(angle[idx, :] / 2.0) / angle[idx, :], torch.cos(angle[idx, :] / 2.0)], dim=-1 ) # Reshape and return output @@ -693,18 +697,18 @@ def axisangle2quat(vec, eps=1e-6): return quat -@th.jit.script +@torch.compile def pose_in_A_to_pose_in_B(pose_A, pose_A_in_B): """ Converts a homogenous matrix corresponding to a point C in frame A to a homogenous matrix corresponding to the same point C in frame B. Args: - pose_A (th.tensor): 4x4 matrix corresponding to the pose of C in frame A - pose_A_in_B (th.tensor): 4x4 matrix corresponding to the pose of A in frame B + pose_A (torch.tensor): 4x4 matrix corresponding to the pose of C in frame A + pose_A_in_B (torch.tensor): 4x4 matrix corresponding to the pose of A in frame B Returns: - th.tensor: 4x4 matrix corresponding to the pose of C in frame B + torch.tensor: 4x4 matrix corresponding to the pose of C in frame B """ # pose of A in B takes a point in A and transforms it to a point in C. @@ -712,20 +716,20 @@ def pose_in_A_to_pose_in_B(pose_A, pose_A_in_B): # pose of C in B = pose of A in B * pose of C in A # take a point in C, transform it to A, then to B # T_B^C = T_A^C * T_B^A - return th.matmul(pose_A_in_B, pose_A) + return torch.matmul(pose_A_in_B, pose_A) -@th.jit.script +@torch.compile def pose_inv(pose_mat): """ Computes the inverse of a homogeneous matrix corresponding to the pose of some frame B in frame A. The inverse is the pose of frame A in frame B. Args: - pose_mat (th.tensor): 4x4 matrix for the pose to inverse + pose_mat (torch.tensor): 4x4 matrix for the pose to inverse Returns: - th.tensor: 4x4 matrix for the inverse pose + torch.tensor: 4x4 matrix for the inverse pose """ # Note, the inverse of a pose matrix is the following @@ -738,14 +742,14 @@ def pose_inv(pose_mat): # -t in the original frame, which is -R-1*t in the new frame, and then rotate back by # R-1 to align the axis again. - pose_inv = th.zeros((4, 4)) + pose_inv = torch.zeros((4, 4)) pose_inv[:3, :3] = pose_mat[:3, :3].T pose_inv[:3, 3] = -pose_inv[:3, :3] @ pose_mat[:3, 3] pose_inv[3, 3] = 1.0 return pose_inv -@th.jit.script +@torch.compile def pose_transform(pos1, quat1, pos0, quat0): """ Conducts forward transform from pose (pos0, quat0) to pose (pos1, quat1): @@ -760,8 +764,8 @@ def pose_transform(pos1, quat1, pos0, quat0): Returns: 2-tuple: - - (th.tensor) (x,y,z) position array in cartesian coordinates - - (th.tensor) (x,y,z,w) orientation array in quaternion form + - (torch.tensor) (x,y,z) position array in cartesian coordinates + - (torch.tensor) (x,y,z,w) orientation array in quaternion form """ # Get poses mat0 = pose2mat((pos0, quat0)) @@ -771,7 +775,7 @@ def pose_transform(pos1, quat1, pos0, quat0): return mat2pose(mat1 @ mat0) -@th.jit.script +@torch.compile def invert_pose_transform(pos, quat): """ Inverts a pose transform @@ -782,8 +786,8 @@ def invert_pose_transform(pos, quat): Returns: 2-tuple: - - (th.tensor) (x,y,z) position array in cartesian coordinates - - (th.tensor) (x,y,z,w) orientation array in quaternion form + - (torch.tensor) (x,y,z) position array in cartesian coordinates + - (torch.tensor) (x,y,z,w) orientation array in quaternion form """ # Get pose mat = pose2mat((pos, quat)) @@ -792,7 +796,7 @@ def invert_pose_transform(pos, quat): return mat2pose(pose_inv(mat)) -@th.jit.script +@torch.compile def relative_pose_transform(pos1, quat1, pos0, quat0): """ Computes relative forward transform from pose (pos0, quat0) to pose (pos1, quat1), i.e.: solves: @@ -807,8 +811,8 @@ def relative_pose_transform(pos1, quat1, pos0, quat0): Returns: 2-tuple: - - (th.tensor) (x,y,z) position array in cartesian coordinates - - (th.tensor) (x,y,z,w) orientation array in quaternion form + - (torch.tensor) (x,y,z) position array in cartesian coordinates + - (torch.tensor) (x,y,z,w) orientation array in quaternion form """ # Get poses mat0 = pose2mat((pos0, quat0)) @@ -818,44 +822,44 @@ def relative_pose_transform(pos1, quat1, pos0, quat0): return mat2pose(pose_inv(mat0) @ mat1) -@th.jit.script +@torch.compile def _skew_symmetric_translation(pos_A_in_B): """ Helper function to get a skew symmetric translation matrix for converting quantities between frames. Args: - pos_A_in_B (th.tensor): (x,y,z) position of A in frame B + pos_A_in_B (torch.tensor): (x,y,z) position of A in frame B Returns: - th.tensor: 3x3 skew symmetric translation matrix + torch.tensor: 3x3 skew symmetric translation matrix """ - return th.tensor( + return torch.tensor( [ [0.0, -pos_A_in_B[2].item(), pos_A_in_B[1].item()], [pos_A_in_B[2].item(), 0.0, -pos_A_in_B[0].item()], [-pos_A_in_B[1].item(), pos_A_in_B[0].item(), 0.0], ], - dtype=th.float32, + dtype=torch.float32, device=pos_A_in_B.device, ) -@th.jit.script +@torch.compile def vel_in_A_to_vel_in_B(vel_A, ang_vel_A, pose_A_in_B): """ Converts linear and angular velocity of a point in frame A to the equivalent in frame B. Args: - vel_A (th.tensor): (vx,vy,vz) linear velocity in A - ang_vel_A (th.tensor): (wx,wy,wz) angular velocity in A - pose_A_in_B (th.tensor): 4x4 matrix corresponding to the pose of A in frame B + vel_A (torch.tensor): (vx,vy,vz) linear velocity in A + ang_vel_A (torch.tensor): (wx,wy,wz) angular velocity in A + pose_A_in_B (torch.tensor): 4x4 matrix corresponding to the pose of A in frame B Returns: 2-tuple: - - (th.tensor) (vx,vy,vz) linear velocities in frame B - - (th.tensor) (wx,wy,wz) angular velocities in frame B + - (torch.tensor) (vx,vy,vz) linear velocities in frame B + - (torch.tensor) (wx,wy,wz) angular velocities in frame B """ pos_A_in_B = pose_A_in_B[:3, 3] rot_A_in_B = pose_A_in_B[:3, :3] @@ -865,21 +869,21 @@ def vel_in_A_to_vel_in_B(vel_A, ang_vel_A, pose_A_in_B): return vel_B, ang_vel_B -@th.jit.script +@torch.compile def force_in_A_to_force_in_B(force_A, torque_A, pose_A_in_B): """ Converts linear and rotational force at a point in frame A to the equivalent in frame B. Args: - force_A (th.tensor): (fx,fy,fz) linear force in A - torque_A (th.tensor): (tx,ty,tz) rotational force (moment) in A - pose_A_in_B (th.tensor): 4x4 matrix corresponding to the pose of A in frame B + force_A (torch.tensor): (fx,fy,fz) linear force in A + torque_A (torch.tensor): (tx,ty,tz) rotational force (moment) in A + pose_A_in_B (torch.tensor): 4x4 matrix corresponding to the pose of A in frame B Returns: 2-tuple: - - (th.tensor) (fx,fy,fz) linear forces in frame B - - (th.tensor) (tx,ty,tz) moments in frame B + - (torch.tensor) (fx,fy,fz) linear forces in frame B + - (torch.tensor) (tx,ty,tz) moments in frame B """ pos_A_in_B = pose_A_in_B[:3, 3] rot_A_in_B = pose_A_in_B[:3, :3] @@ -889,31 +893,31 @@ def force_in_A_to_force_in_B(force_A, torque_A, pose_A_in_B): return force_B, torque_B -@th.jit.script -def rotation_matrix(angle: float, direction: th.Tensor) -> th.Tensor: +@torch.compile +def rotation_matrix(angle: float, direction: torch.Tensor) -> torch.Tensor: """ Returns a 3x3 rotation matrix to rotate about the given axis. Args: angle (float): Magnitude of rotation in radians - direction (th.Tensor): (ax,ay,az) axis about which to rotate + direction (torch.Tensor): (ax,ay,az) axis about which to rotate Returns: - th.Tensor: 3x3 rotation matrix + torch.Tensor: 3x3 rotation matrix """ - sina = th.sin(th.tensor(angle, dtype=th.float32)) - cosa = th.cos(th.tensor(angle, dtype=th.float32)) + sina = torch.sin(torch.tensor(angle, dtype=torch.float32)) + cosa = torch.cos(torch.tensor(angle, dtype=torch.float32)) - direction = direction / th.norm(direction) # Normalize direction vector + direction = direction / torch.norm(direction) # Normalize direction vector # Create rotation matrix - R = th.eye(3, dtype=th.float32, device=direction.device) + R = torch.eye(3, dtype=torch.float32, device=direction.device) R *= cosa - R += th.outer(direction, direction) * (1.0 - cosa) + R += torch.outer(direction, direction) * (1.0 - cosa) direction *= sina # Create the skew-symmetric matrix - skew_matrix = th.zeros(3, 3, dtype=th.float32, device=direction.device) + skew_matrix = torch.zeros(3, 3, dtype=torch.float32, device=direction.device) skew_matrix[0, 1] = -direction[2] skew_matrix[0, 2] = direction[1] skew_matrix[1, 0] = direction[2] @@ -926,32 +930,32 @@ def rotation_matrix(angle: float, direction: th.Tensor) -> th.Tensor: return R -@th.jit.script -def transformation_matrix(angle: float, direction: th.Tensor, point: Optional[th.Tensor] = None) -> th.Tensor: +@torch.compile +def transformation_matrix(angle: float, direction: torch.Tensor, point: Optional[torch.Tensor] = None) -> torch.Tensor: """ Returns a 4x4 homogeneous transformation matrix to rotate about axis defined by point and direction. Args: angle (float): Magnitude of rotation in radians - direction (th.Tensor): (ax,ay,az) axis about which to rotate - point (Optional[th.Tensor]): If specified, is the (x,y,z) point about which the rotation will occur + direction (torch.Tensor): (ax,ay,az) axis about which to rotate + point (Optional[torch.Tensor]): If specified, is the (x,y,z) point about which the rotation will occur Returns: - th.Tensor: 4x4 homogeneous transformation matrix + torch.Tensor: 4x4 homogeneous transformation matrix """ R = rotation_matrix(angle, direction) - M = th.eye(4, dtype=th.float32, device=direction.device) + M = torch.eye(4, dtype=torch.float32, device=direction.device) M[:3, :3] = R if point is not None: # Rotation not about origin - point = point.to(dtype=th.float32) + point = point.to(dtype=torch.float32) M[:3, 3] = point - R @ point return M -@th.jit.script +@torch.compile def clip_translation(dpos, limit): """ Limits a translation (delta position) to a specified limit @@ -965,14 +969,14 @@ def clip_translation(dpos, limit): Returns: 2-tuple: - - (th.tensor) Clipped translation (same dimension as inputs) + - (torch.tensor) Clipped translation (same dimension as inputs) - (bool) whether the value was clipped or not """ - input_norm = th.norm(dpos) + input_norm = torch.norm(dpos) return (dpos * limit / input_norm, True) if input_norm > limit else (dpos, False) -@th.jit.script +@torch.compile def clip_rotation(quat, limit): """ Limits a (delta) rotation to a specified limit @@ -980,19 +984,19 @@ def clip_rotation(quat, limit): Converts rotation to axis-angle, clips, then re-converts back into quaternion Args: - quat (th.tensor): (x,y,z,w) rotation being clipped + quat (torch.tensor): (x,y,z,w) rotation being clipped limit (float): Value to limit rotation by -- magnitude (scalar, in radians) Returns: 2-tuple: - - (th.tensor) Clipped rotation quaternion (x, y, z, w) + - (torch.tensor) Clipped rotation quaternion (x, y, z, w) - (bool) whether the value was clipped or not """ clipped = False # First, normalize the quaternion - quat = quat / th.norm(quat) + quat = quat / torch.norm(quat) den = math.sqrt(max(1 - quat[3] * quat[3], 0)) if den == 0: @@ -1007,35 +1011,35 @@ def clip_rotation(quat, limit): # Clip rotation if necessary and return clipped quat if abs(a) > limit: - a = limit * th.sign(a) / 2 + a = limit * torch.sign(a) / 2 sa = math.sin(a) ca = math.cos(a) - quat = th.tensor([x * sa, y * sa, z * sa, ca]) + quat = torch.tensor([x * sa, y * sa, z * sa, ca]) clipped = True return quat, clipped -@th.jit.script +@torch.compile def make_pose(translation, rotation): """ Makes a homogeneous pose matrix from a translation vector and a rotation matrix. Args: - translation (th.tensor): (x,y,z) translation value - rotation (th.tensor): a 3x3 matrix representing rotation + translation (torch.tensor): (x,y,z) translation value + rotation (torch.tensor): a 3x3 matrix representing rotation Returns: - pose (th.tensor): a 4x4 homogeneous matrix + pose (torch.tensor): a 4x4 homogeneous matrix """ - pose = th.zeros((4, 4)) + pose = torch.zeros((4, 4)) pose[:3, :3] = rotation pose[:3, 3] = translation pose[3, 3] = 1.0 return pose -@th.jit.script +@torch.compile def get_orientation_error(desired, current): """ This function calculates a 3-dimensional orientation error vector, where inputs are quaternions @@ -1053,20 +1057,20 @@ def get_orientation_error(desired, current): cc = quat_conjugate(current) q_r = quat_multiply(desired, cc) - return (q_r[:, 0:3] * th.sign(q_r[:, 3]).unsqueeze(-1)).reshape(list(input_shape) + [3]) + return (q_r[:, 0:3] * torch.sign(q_r[:, 3]).unsqueeze(-1)).reshape(list(input_shape) + [3]) -@th.jit.script -def get_orientation_diff_in_radian(orn0: th.Tensor, orn1: th.Tensor) -> th.Tensor: +@torch.compile +def get_orientation_diff_in_radian(orn0: torch.Tensor, orn1: torch.Tensor) -> torch.Tensor: """ Returns the difference between two quaternion orientations in radians. Args: - orn0 (th.Tensor): (x, y, z, w) quaternion - orn1 (th.Tensor): (x, y, z, w) quaternion + orn0 (torch.Tensor): (x, y, z, w) quaternion + orn1 (torch.Tensor): (x, y, z, w) quaternion Returns: - orn_diff (th.Tensor): orientation difference in radians + orn_diff (torch.Tensor): orientation difference in radians """ # Compute the difference quaternion diff_quat = quat_distance(orn0, orn1) @@ -1075,10 +1079,10 @@ def get_orientation_diff_in_radian(orn0: th.Tensor, orn1: th.Tensor) -> th.Tenso axis_angle = quat2axisangle(diff_quat) # The magnitude of the axis-angle vector is the rotation angle - return th.norm(axis_angle) + return torch.norm(axis_angle) -@th.jit.script +@torch.compile def get_pose_error(target_pose, current_pose): """ Computes the error corresponding to target pose - current pose as a 6-dim vector. @@ -1086,13 +1090,13 @@ def get_pose_error(target_pose, current_pose): correspond to the rotational error. Args: - target_pose (th.tensor): a 4x4 homogenous matrix for the target pose - current_pose (th.tensor): a 4x4 homogenous matrix for the current pose + target_pose (torch.tensor): a 4x4 homogenous matrix for the target pose + current_pose (torch.tensor): a 4x4 homogenous matrix for the current pose Returns: - th.tensor: 6-dim pose error. + torch.tensor: 6-dim pose error. """ - error = th.zeros(6) + error = torch.zeros(6) # compute translational error target_pos = target_pose[:3, 3] @@ -1106,55 +1110,55 @@ def get_pose_error(target_pose, current_pose): r1d = target_pose[:3, 0] r2d = target_pose[:3, 1] r3d = target_pose[:3, 2] - rot_err = 0.5 * (th.linalg.cross(r1, r1d) + th.linalg.cross(r2, r2d) + th.linalg.cross(r3, r3d)) + rot_err = 0.5 * (torch.linalg.cross(r1, r1d) + torch.linalg.cross(r2, r2d) + torch.linalg.cross(r3, r3d)) error[:3] = pos_err error[3:] = rot_err return error -@th.jit.script +@torch.compile def matrix_inverse(matrix): """ Helper function to have an efficient matrix inversion function. Args: - matrix (th.tensor): 2d-array representing a matrix + matrix (torch.tensor): 2d-array representing a matrix Returns: - th.tensor: 2d-array representing the matrix inverse + torch.tensor: 2d-array representing the matrix inverse """ - return th.linalg.inv_ex(matrix).inverse + return torch.linalg.inv_ex(matrix).inverse -@th.jit.script +@torch.compile def vecs2axisangle(vec0, vec1): """ Converts the angle from unnormalized 3D vectors @vec0 to @vec1 into an axis-angle representation of the angle Args: - vec0 (th.tensor): (..., 3) (x,y,z) 3D vector, possibly unnormalized - vec1 (th.tensor): (..., 3) (x,y,z) 3D vector, possibly unnormalized + vec0 (torch.tensor): (..., 3) (x,y,z) 3D vector, possibly unnormalized + vec1 (torch.tensor): (..., 3) (x,y,z) 3D vector, possibly unnormalized """ # Normalize vectors vec0 = normalize(vec0, dim=-1) vec1 = normalize(vec1, dim=-1) # Get cross product for direction of angle, and multiply by arcos of the dot product which is the angle - return th.linalg.cross(vec0, vec1) * th.arccos((vec0 * vec1).sum(-1, keepdim=True)) + return torch.linalg.cross(vec0, vec1) * torch.arccos((vec0 * vec1).sum(-1, keepdim=True)) -@th.jit.script -def vecs2quat(vec0: th.Tensor, vec1: th.Tensor, normalized: bool = False) -> th.Tensor: +@torch.compile +def vecs2quat(vec0: torch.Tensor, vec1: torch.Tensor, normalized: bool = False) -> torch.Tensor: """ Converts the angle from unnormalized 3D vectors @vec0 to @vec1 into a quaternion representation of the angle Args: - vec0 (th.Tensor): (..., 3) (x,y,z) 3D vector, possibly unnormalized - vec1 (th.Tensor): (..., 3) (x,y,z) 3D vector, possibly unnormalized + vec0 (torch.Tensor): (..., 3) (x,y,z) 3D vector, possibly unnormalized + vec1 (torch.Tensor): (..., 3) (x,y,z) 3D vector, possibly unnormalized normalized (bool): If True, @vec0 and @vec1 are assumed to already be normalized and we will skip the normalization step (more efficient) Returns: - th.Tensor: (..., 4) Normalized quaternion representing the rotation from vec0 to vec1 + torch.Tensor: (..., 4) Normalized quaternion representing the rotation from vec0 to vec1 """ # Normalize vectors if requested if not normalized: @@ -1162,34 +1166,34 @@ def vecs2quat(vec0: th.Tensor, vec1: th.Tensor, normalized: bool = False) -> th. vec1 = normalize(vec1, dim=-1) # Half-way Quaternion Solution -- see https://stackoverflow.com/a/11741520 - cos_theta = th.sum(vec0 * vec1, dim=-1, keepdim=True) + cos_theta = torch.sum(vec0 * vec1, dim=-1, keepdim=True) # Create a tensor for the case where cos_theta == -1 batch_shape = vec0.shape[:-1] - fallback = th.zeros(batch_shape + (4,), device=vec0.device, dtype=vec0.dtype) + fallback = torch.zeros(batch_shape + (4,), device=vec0.device, dtype=vec0.dtype) fallback[..., 0] = 1.0 # Compute the quaternion - quat_unnormalized = th.where( + quat_unnormalized = torch.where( cos_theta == -1, fallback, - th.cat([th.linalg.cross(vec0, vec1), 1 + cos_theta], dim=-1), + torch.cat([torch.linalg.cross(vec0, vec1), 1 + cos_theta], dim=-1), ) - return quat_unnormalized / th.norm(quat_unnormalized, dim=-1, keepdim=True) + return quat_unnormalized / torch.norm(quat_unnormalized, dim=-1, keepdim=True) -@th.jit.script -def align_vector_sets(vec_set1: th.Tensor, vec_set2: th.Tensor) -> th.Tensor: +@torch.compile +def align_vector_sets(vec_set1: torch.Tensor, vec_set2: torch.Tensor) -> torch.Tensor: """ Computes a single quaternion representing the rotation that best aligns vec_set1 to vec_set2. Args: - vec_set1 (th.Tensor): (N, 3) tensor of N 3D vectors - vec_set2 (th.Tensor): (N, 3) tensor of N 3D vectors + vec_set1 (torch.Tensor): (N, 3) tensor of N 3D vectors + vec_set2 (torch.Tensor): (N, 3) tensor of N 3D vectors Returns: - th.Tensor: (4,) Normalized quaternion representing the overall rotation + torch.Tensor: (4,) Normalized quaternion representing the overall rotation """ # Compute the cross-covariance matrix H = vec_set2.T @ vec_set1 @@ -1202,31 +1206,31 @@ def align_vector_sets(vec_set1: th.Tensor, vec_set2: th.Tensor) -> th.Tensor: z = H[0, 1] - H[1, 0] # Construct the quaternion - quat = th.stack([x, y, z, w]) + quat = torch.stack([x, y, z, w]) # Handle the case where w is close to zero if quat[3] < 1e-4: quat[3] = 0 - max_idx = th.argmax(quat[:3].abs()) + 1 + max_idx = torch.argmax(quat[:3].abs()) + 1 quat[max_idx] = 1 # Normalize the quaternion - quat = quat / (th.norm(quat) + 1e-8) # Add epsilon to avoid division by zero + quat = quat / (torch.norm(quat) + 1e-8) # Add epsilon to avoid division by zero return quat -@th.jit.script -def l2_distance(v1: th.Tensor, v2: th.Tensor) -> th.Tensor: +@torch.compile +def l2_distance(v1: torch.Tensor, v2: torch.Tensor) -> torch.Tensor: """Returns the L2 distance between vector v1 and v2.""" - return th.norm(v1 - v2) + return torch.norm(v1 - v2) -@th.jit.script +@torch.compile def cartesian_to_polar(x, y): """Convert cartesian coordinate to polar coordinate""" - rho = th.sqrt(x**2 + y**2) - phi = th.arctan2(y, x) + rho = torch.sqrt(x**2 + y**2) + phi = torch.arctan2(y, x) return rho, phi @@ -1238,8 +1242,8 @@ def rad2deg(rad): return rad * 180.0 / math.pi -@th.jit.script -def check_quat_right_angle(quat: th.Tensor, atol: float = 5e-2) -> th.Tensor: +@torch.compile +def check_quat_right_angle(quat: torch.Tensor, atol: float = 5e-2) -> torch.Tensor: """ Check by making sure the quaternion is some permutation of +/- (1, 0, 0, 0), +/- (0.707, 0.707, 0, 0), or +/- (0.5, 0.5, 0.5, 0.5) @@ -1247,25 +1251,25 @@ def check_quat_right_angle(quat: th.Tensor, atol: float = 5e-2) -> th.Tensor: So we check the L1-norm of the absolute value of the orientation as a proxy for verifying these values Args: - quat (th.Tensor): (x,y,z,w) quaternion orientation to check + quat (torch.Tensor): (x,y,z,w) quaternion orientation to check atol (float): Absolute tolerance permitted Returns: - th.Tensor: Boolean tensor indicating whether the quaternion is a right angle or not + torch.Tensor: Boolean tensor indicating whether the quaternion is a right angle or not """ - l1_norm = th.abs(quat).sum(dim=-1) - reference_norms = th.tensor([1.0, 1.414, 2.0], device=quat.device, dtype=quat.dtype) - return th.any(th.abs(l1_norm.unsqueeze(-1) - reference_norms) < atol, dim=-1) + l1_norm = torch.abs(quat).sum(dim=-1) + reference_norms = torch.tensor([1.0, 1.414, 2.0], device=quat.device, dtype=quat.dtype) + return torch.any(torch.abs(l1_norm.unsqueeze(-1) - reference_norms) < atol, dim=-1) -@th.jit.script +@torch.compile def z_angle_from_quat(quat): """Get the angle around the Z axis produced by the quaternion.""" - rotated_X_axis = quat_apply(quat, th.tensor([1, 0, 0], dtype=th.float32)) - return th.arctan2(rotated_X_axis[1], rotated_X_axis[0]) + rotated_X_axis = quat_apply(quat, torch.tensor([1, 0, 0], dtype=torch.float32)) + return torch.arctan2(rotated_X_axis[1], rotated_X_axis[0]) -@th.jit.script +@torch.compile def integer_spiral_coordinates(n: int) -> Tuple[int, int]: """A function to map integers to 2D coordinates in a spiral pattern around the origin.""" # Map integers from Z to Z^2 in a spiral pattern around the origin. @@ -1278,8 +1282,8 @@ def integer_spiral_coordinates(n: int) -> Tuple[int, int]: return int(x), int(y) -@th.jit.script -def random_quaternion(num_quaternions: int = 1) -> th.Tensor: +@torch.compile +def random_quaternion(num_quaternions: int = 1) -> torch.Tensor: """ Generate random rotation quaternions, uniformly distributed over SO(3). @@ -1287,24 +1291,24 @@ def random_quaternion(num_quaternions: int = 1) -> th.Tensor: num_quaternions: int, number of quaternions to generate (default: 1) Returns: - th.Tensor: A tensor of shape (num_quaternions, 4) containing random unit quaternions. + torch.Tensor: A tensor of shape (num_quaternions, 4) containing random unit quaternions. """ # Generate four random numbers between 0 and 1 - rand = th.rand(num_quaternions, 4) + rand = torch.rand(num_quaternions, 4) # Use the formula from Ken Shoemake's "Uniform Random Rotations" - r1 = th.sqrt(1.0 - rand[:, 0]) - r2 = th.sqrt(rand[:, 0]) - t1 = 2 * th.pi * rand[:, 1] - t2 = 2 * th.pi * rand[:, 2] + r1 = torch.sqrt(1.0 - rand[:, 0]) + r2 = torch.sqrt(rand[:, 0]) + t1 = 2 * torch.pi * rand[:, 1] + t2 = 2 * torch.pi * rand[:, 2] - quaternions = th.stack([r1 * th.sin(t1), r1 * th.cos(t1), r2 * th.sin(t2), r2 * th.cos(t2)], dim=1) + quaternions = torch.stack([r1 * torch.sin(t1), r1 * torch.cos(t1), r2 * torch.sin(t2), r2 * torch.cos(t2)], dim=1) return quaternions -@th.jit.script -def transform_points(points: th.Tensor, matrix: th.Tensor, translate: bool = True) -> th.Tensor: +@torch.compile +def transform_points(points: torch.Tensor, matrix: torch.Tensor, translate: bool = True) -> torch.Tensor: """ Returns points rotated by a homogeneous transformation matrix. @@ -1328,27 +1332,27 @@ def transform_points(points: th.Tensor, matrix: th.Tensor, translate: bool = Tru count, dim = points.shape # Check if the matrix is close to an identity matrix - identity = th.eye(dim + 1, device=points.device) - if th.abs(matrix - identity[: dim + 1, : dim + 1]).max() < 1e-8: + identity = torch.eye(dim + 1, device=points.device) + if torch.abs(matrix - identity[: dim + 1, : dim + 1]).max() < 1e-8: return points.clone().contiguous() if translate: - stack = th.cat((points, th.ones(count, 1, device=points.device)), dim=1) - return th.mm(matrix, stack.t()).t()[:, :dim] + stack = torch.cat((points, torch.ones(count, 1, device=points.device)), dim=1) + return torch.mm(matrix, stack.t()).t()[:, :dim] else: - return th.mm(matrix[:dim, :dim], points.t()).t() + return torch.mm(matrix[:dim, :dim], points.t()).t() -@th.jit.script -def quaternions_close(q1: th.Tensor, q2: th.Tensor, atol: float = 1e-3) -> bool: +@torch.compile +def quaternions_close(q1: torch.Tensor, q2: torch.Tensor, atol: float = 1e-3) -> bool: """ Whether two quaternions represent the same rotation, allowing for the possibility that one is the negative of the other. Arguments: - q1: th.Tensor + q1: torch.Tensor First quaternion - q2: th.Tensor + q2: torch.Tensor Second quaternion atol: float Absolute tolerance for comparison @@ -1357,4 +1361,4 @@ def quaternions_close(q1: th.Tensor, q2: th.Tensor, atol: float = 1e-3) -> bool: bool Whether the quaternions are close """ - return th.allclose(q1, q2, atol=atol) or th.allclose(q1, -q2, atol=atol) + return torch.allclose(q1, q2, atol=atol) or torch.allclose(q1, -q2, atol=atol)