
Merge pull request #1033 from StanfordVL/fix/transform_utils
fix torch compiled transform utils
ChengshuLi authored Nov 22, 2024
2 parents 4d38122 + 5eaaa0c commit 075267e
Showing 5 changed files with 19 additions and 16 deletions.
2 changes: 1 addition & 1 deletion omnigibson/configs/tiago_primitives.yaml
@@ -51,13 +51,13 @@ robots:
name: JointController
arm_left:
name: JointController
subsume_controllers: [trunk]
motor_type: position
command_input_limits: null
command_output_limits: null
use_delta_commands: false
arm_right:
name: JointController
subsume_controllers: [trunk]
motor_type: position
command_input_limits: null
command_output_limits: null
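For orientation, a small illustrative sketch (the joint indices below are hypothetical, not Tiago's actual ones) of what subsume_controllers: [trunk] implies downstream: the arm controller that subsumes the trunk also takes over the trunk's joints, so its dof_idx gets the trunk's indices prepended, exactly as in the _load_controllers hunk below.

    import torch as th

    # Hypothetical joint indices, for illustration only.
    trunk_dof_idx = th.tensor([2])
    arm_left_dof_idx = th.tensor([5, 6, 7])

    # Mirrors the merge step in _load_controllers: the subsuming controller's dof_idx
    # becomes the subsumed controller's indices followed by its own.
    merged_dof_idx = th.concatenate([trunk_dof_idx, arm_left_dof_idx])
    print(merged_dof_idx)  # tensor([2, 5, 6, 7])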
11 changes: 7 additions & 4 deletions omnigibson/objects/controllable_object.py
@@ -226,6 +226,9 @@ def _load_controllers(self):
# Generate the controller config
self._controller_config = self._generate_controller_config(custom_config=self._controller_config)

+ # We copy the controller config here because we add/remove some keys in-place that shouldn't persist
+ _controller_config = deepcopy(self._controller_config)

# Store dof idx mapping to dof name
self.dof_names_ordered = list(self._joints.keys())

@@ -237,8 +240,8 @@ def _load_controllers(self):
subsume_names = set()
for name in self._raw_controller_order:
# Make sure we have the valid controller name specified
- assert_valid_key(key=name, valid_keys=self._controller_config, name="controller name")
- cfg = self._controller_config[name]
+ assert_valid_key(key=name, valid_keys=_controller_config, name="controller name")
+ cfg = _controller_config[name]
subsume_controllers = cfg.pop("subsume_controllers", [])
# If this controller subsumes other controllers, it cannot be subsumed by another controller
# (i.e.: we don't allow nested / cyclical subsuming)
@@ -262,11 +265,11 @@
# If this controller is subsumed by another controller, simply skip it
if name in subsume_names:
continue
- cfg = self._controller_config[name]
+ cfg = _controller_config[name]
# If we subsume other controllers, prepend the subsumed controllers' dof idxs to this controller's idxs
if name in controller_subsumes:
for subsumed_name in controller_subsumes[name]:
- subsumed_cfg = self._controller_config[subsumed_name]
+ subsumed_cfg = _controller_config[subsumed_name]
cfg["dof_idx"] = th.concatenate([subsumed_cfg["dof_idx"], cfg["dof_idx"]])

# If we're using normalized action space, override the inputs for all controllers
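The deepcopy added at the top of _load_controllers is what makes the later cfg.pop("subsume_controllers", []) safe: popping from the persistent self._controller_config would strip the key for good, so loading the controllers a second time would silently lose the subsumption info. A minimal sketch of that pitfall (the config dict here is illustrative, not the real one):

    from copy import deepcopy

    persistent_config = {
        "arm_left": {"name": "JointController", "subsume_controllers": ["trunk"]},
    }

    # Work on a throwaway copy so pop() cannot leak back into the stored config.
    working_config = deepcopy(persistent_config)
    subsumed = working_config["arm_left"].pop("subsume_controllers", [])

    print(subsumed)                       # ['trunk']
    print(persistent_config["arm_left"])  # still contains 'subsume_controllers'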
4 changes: 2 additions & 2 deletions omnigibson/utils/python_utils.py
@@ -151,9 +151,9 @@ def create_object_from_init_info(init_info):

def safe_equal(a, b):
if isinstance(a, th.Tensor) and isinstance(b, th.Tensor):
- return (a == b).all().item()
+ return a.shape == b.shape and (a == b).all().item()
elif isinstance(a, (list, tuple)) and isinstance(b, (list, tuple)):
- return all(safe_equal(a_item, b_item) for a_item, b_item in zip(a, b))
+ return len(a) == len(b) and all(safe_equal(a_item, b_item) for a_item, b_item in zip(a, b))
else:
return a == b

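The extra shape and length guards matter because elementwise comparison alone can be misleading: tensors of different but broadcastable shapes compare equal elementwise (and non-broadcastable ones raise), while zip() silently drops the trailing items of the longer sequence. A small demonstration using the helper exactly as patched above:

    import torch as th

    def safe_equal(a, b):
        # Same logic as the patched helper: compare shape/length before values.
        if isinstance(a, th.Tensor) and isinstance(b, th.Tensor):
            return a.shape == b.shape and (a == b).all().item()
        elif isinstance(a, (list, tuple)) and isinstance(b, (list, tuple)):
            return len(a) == len(b) and all(safe_equal(x, y) for x, y in zip(a, b))
        else:
            return a == b

    print(safe_equal(th.ones(3), th.ones(1, 3)))  # False; the old version returned True via broadcasting
    print(safe_equal((1, 2), (1, 2, 3)))          # False; the old version returned True because zip() stops early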
16 changes: 8 additions & 8 deletions omnigibson/utils/transform_utils.py
@@ -273,8 +273,8 @@ def quat_slerp(quat0, quat1, frac, shortestpath=True, eps=1.0e-15):
# type: (Tensor, Tensor, Tensor, bool, float) -> Tensor
# reshape quaternion
quat_shape = quat0.shape
- quat0 = unit_vector(quat0.reshape(-1, 4), dim=-1)
- quat1 = unit_vector(quat1.reshape(-1, 4), dim=-1)
+ quat0 = unit_vector(quat0.reshape(-1, 4), dim=-1, out=None)
+ quat1 = unit_vector(quat1.reshape(-1, 4), dim=-1, out=None)

# Check for endpoint cases
where_start = frac <= 0.0
@@ -481,8 +481,8 @@ def vec2quat(vec: torch.Tensor, up: torch.Tensor = torch.tensor([0.0, 0.0, 1.0])
if up.dim() == 1:
up = up.unsqueeze(0)

- vec_n = torch.nn.functional.normalize(vec, dim=-1)
- up_n = torch.nn.functional.normalize(up, dim=-1)
+ vec_n = normalize(vec, dim=-1, eps=1e-10)
+ up_n = normalize(up, dim=-1, eps=1e-10)

s_n = torch.cross(up_n, vec_n, dim=-1)
u_n = torch.cross(vec_n, s_n, dim=-1)
@@ -1141,8 +1141,8 @@ def vecs2axisangle(vec0, vec1):
vec1 (torch.tensor): (..., 3) (x,y,z) 3D vector, possibly unnormalized
"""
# Normalize vectors
- vec0 = normalize(vec0, dim=-1)
- vec1 = normalize(vec1, dim=-1)
+ vec0 = normalize(vec0, dim=-1, eps=1e-10)
+ vec1 = normalize(vec1, dim=-1, eps=1e-10)

# Get cross product for direction of angle, and multiply by arcos of the dot product which is the angle
return torch.linalg.cross(vec0, vec1) * torch.arccos((vec0 * vec1).sum(-1, keepdim=True))
@@ -1162,8 +1162,8 @@ def vecs2quat(vec0: torch.Tensor, vec1: torch.Tensor, normalized: bool = False)
"""
# Normalize vectors if requested
if not normalized:
- vec0 = normalize(vec0, dim=-1)
- vec1 = normalize(vec1, dim=-1)
+ vec0 = normalize(vec0, dim=-1, eps=1e-10)
+ vec1 = normalize(vec1, dim=-1, eps=1e-10)

# Half-way Quaternion Solution -- see https://stackoverflow.com/a/11741520
cos_theta = torch.sum(vec0 * vec1, dim=-1, keepdim=True)
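The common thread in these hunks is spelling out optional arguments (out=None, eps=1e-10) and calling the module's own normalize helper instead of torch.nn.functional.normalize, which fits the commit title: scripted or compiled functions are stricter about call signatures than eager Python, so being explicit at every call site is the safest way to keep them compiling. A standalone sketch of that style, using an illustrative helper rather than OmniGibson's actual normalize:

    import torch

    @torch.jit.script
    def normalize(v: torch.Tensor, dim: int = -1, eps: float = 1e-10) -> torch.Tensor:
        # L2-normalize along dim, flooring the norm at eps to avoid division by zero.
        norm = torch.sqrt(torch.sum(v * v, dim=dim, keepdim=True))
        return v / torch.clamp(norm, min=eps)

    vec = torch.tensor([[3.0, 0.0, 4.0]])
    # Arguments passed explicitly, mirroring the updated call sites in the diff.
    print(normalize(vec, dim=-1, eps=1e-10))  # tensor([[0.6000, 0.0000, 0.8000]])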
2 changes: 1 addition & 1 deletion tests/test_robot_states_flatcache.py
@@ -61,7 +61,7 @@ def camera_pose_test(flatcache):
relative_pose_transform(sensor_world_pos, sensor_world_ori, robot_world_pos, robot_world_ori)
)

- sensor_world_pos_gt = th.tensor([150.1620, 149.9999, 101.2193])
+ sensor_world_pos_gt = th.tensor([150.1628, 149.9993, 101.3773])
sensor_world_ori_gt = th.tensor([-0.2952, 0.2959, 0.6427, -0.6421])

assert th.allclose(sensor_world_pos, sensor_world_pos_gt, atol=1e-3)
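The ground-truth sensor position was regenerated here; the assertion and its atol=1e-3 tolerance are untouched. For reference, how that tolerance treats the old versus the new values:

    import torch as th

    old_gt = th.tensor([150.1620, 149.9999, 101.2193])
    new_gt = th.tensor([150.1628, 149.9993, 101.3773])

    # The z component differs by ~0.158, far outside atol=1e-3, so the old value no longer passes.
    print(th.allclose(old_gt, new_gt, atol=1e-3))         # False
    # Drift within the tolerance still passes.
    print(th.allclose(new_gt + 5e-4, new_gt, atol=1e-3))  # True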
