Skip to content

Commit

Permalink
Merge pull request #341 from StanfordVL/fix/obs-space-compatibility
Browse files Browse the repository at this point in the history
Fix obs space compatibility
  • Loading branch information
cremebrule authored Dec 21, 2023
2 parents 08aee08 + a2ccfac commit 473960d
Show file tree
Hide file tree
Showing 8 changed files with 84 additions and 33 deletions.
41 changes: 31 additions & 10 deletions omnigibson/envs/env_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
from omnigibson.tasks import REGISTERED_TASKS
from omnigibson.scenes import REGISTERED_SCENES
from omnigibson.sensors import create_sensor
from omnigibson.utils.gym_utils import GymObservable, recursively_generate_flat_dict
from omnigibson.utils.gym_utils import GymObservable, recursively_generate_flat_dict, recursively_generate_compatible_dict
from omnigibson.utils.config_utils import parse_config
from omnigibson.utils.ui_utils import create_module_logger
from omnigibson.utils.python_utils import assert_valid_key, merge_nested_dicts, create_class_from_registry_and_config,\
Expand Down Expand Up @@ -548,15 +548,36 @@ def reset(self):
# Grab and return observations
obs = self.get_obs()

if self._loaded and not self.observation_space.contains(obs):
# Flatten obs, and print out all keys and values
log.error("OBSERVATION SPACE:")
for key, value in recursively_generate_flat_dict(dic=self.observation_space).items():
log.error(("obs_space", key, value.dtype, value.shape))
log.error("ACTUAL OBSERVATIONS:")
for key, value in recursively_generate_flat_dict(dic=obs).items():
log.error(("obs", key, value.dtype, value.shape))
raise ValueError("Observation space does not match returned observations!")
if self._loaded:
# Sanity check to make sure received observations match expected observation space
check_obs = recursively_generate_compatible_dict(dic=obs)
if not self.observation_space.contains(check_obs):
exp_obs = dict()
for key, value in recursively_generate_flat_dict(dic=self.observation_space).items():
exp_obs[key] = ("obs_space", key, value.dtype, value.shape)
real_obs = dict()
for key, value in recursively_generate_flat_dict(dic=check_obs).items():
if isinstance(value, np.ndarray):
real_obs[key] = ("obs", key, value.dtype, value.shape)
else:
real_obs[key] = ("obs", key, type(value), "()")

exp_keys = set(exp_obs.keys())
real_keys = set(real_obs.keys())
shared_keys = exp_keys.intersection(real_keys)
missing_keys = exp_keys - real_keys
extra_keys = real_keys - exp_keys

log.error("MISSING OBSERVATION KEYS:")
log.error(missing_keys)
log.error("EXTRA OBSERVATION KEYS:")
log.error(extra_keys)
log.error("SHARED OBSERVATION KEY DTYPES AND SHAPES:")
for k in shared_keys:
log.error(exp_obs[k])
log.error(real_obs[k])

raise ValueError("Observation space does not match returned observations!")

return obs

Expand Down
6 changes: 5 additions & 1 deletion omnigibson/examples/robots/robot_control_example.py
Original file line number Diff line number Diff line change
Expand Up @@ -91,7 +91,7 @@ def main(random_selection=False, headless=False, short_exec=False):
robot0_cfg["action_normalize"] = True

# Compile config
cfg = dict(scene=scene_cfg, robots=[robot0_cfg])
cfg = dict(env=env_cfg, scene=scene_cfg, robots=[robot0_cfg])

# Create the environment
env = og.Environment(configs=cfg)
Expand All @@ -110,6 +110,10 @@ def main(random_selection=False, headless=False, short_exec=False):
controller_config = {component: {"name": name} for component, name in controller_choices.items()}
robot.reload_controllers(controller_config=controller_config)

# Because the controllers have been updated, we need to update the initial state so the correct controller state
# is preserved
env.scene.update_initial_state()

# Update the simulator's viewer camera's pose so it points towards the robot
og.sim.viewer_camera.set_position_orientation(
position=np.array([1.46949, -3.97358, 2.21529]),
Expand Down
2 changes: 1 addition & 1 deletion omnigibson/object_states/open_state.py
Original file line number Diff line number Diff line change
Expand Up @@ -86,7 +86,7 @@ def _get_relevant_joints(obj):
# 1 means the open direction corresponds to positive joint angle change and -1 means the opposite
default_joint_directions = [1] * len(default_relevant_joints)

if not hasattr(obj, "metadata"):
if not hasattr(obj, "metadata") or obj.metadata is None:
log.debug("No openable joint metadata found for object %s" % obj.name)
return default_both_sides, default_relevant_joints, default_joint_directions

Expand Down
2 changes: 1 addition & 1 deletion omnigibson/prims/entity_prim.py
Original file line number Diff line number Diff line change
Expand Up @@ -121,7 +121,7 @@ def _post_load(self):

# Prepare the articulation view.
if self.n_joints > 0:
self._articulation_view_direct = ArticulationView(self._prim_path + "/base_link")
self._articulation_view_direct = ArticulationView(f"{self._prim_path}/{self.root_link_name}")

# Set visual only flag
# This automatically handles setting collisions / gravity appropriately per-link
Expand Down
5 changes: 3 additions & 2 deletions omnigibson/prims/rigid_prim.py
Original file line number Diff line number Diff line change
Expand Up @@ -162,8 +162,9 @@ def update_meshes(self):
if prim.GetPrimTypeInfo().GetTypeName() in GEOM_TYPES:
mesh_name, mesh_path = prim.GetName(), prim.GetPrimPath().__str__()
mesh_prim = get_prim_at_path(prim_path=mesh_path)
mesh_kwargs = {"prim_path": mesh_path, "name": f"{self._name}:{mesh_name}"}
if mesh_prim.HasAPI(UsdPhysics.CollisionAPI):
is_collision = mesh_prim.HasAPI(UsdPhysics.CollisionAPI)
mesh_kwargs = {"prim_path": mesh_path, "name": f"{self._name}:{'collision' if is_collision else 'visual'}_{mesh_name}"}
if is_collision:
mesh = CollisionGeomPrim(**mesh_kwargs)
# We also modify the collision mesh's contact and rest offsets, since omni's default values result
# in lightweight objects sometimes not triggering contacts correctly
Expand Down
34 changes: 17 additions & 17 deletions omnigibson/sensors/vision_sensor.py
Original file line number Diff line number Diff line change
Expand Up @@ -417,33 +417,33 @@ def _obs_space_mapping(self):
# Generate the complex space types for special modalities:
# {"bbox_2d_tight", "bbox_2d_loose", "bbox_3d", "camera"}
bbox_3d_space = gym.spaces.Sequence(space=gym.spaces.Tuple((
gym.spaces.Box(low=-np.inf, high=np.inf, shape=(), dtype=int), # uniqueId
gym.spaces.Box(low=-np.inf, high=np.inf, shape=(), dtype=np.int32), # uniqueId
gym.spaces.Text(min_length=1, max_length=50, charset=VALID_OMNI_CHARS), # name
gym.spaces.Text(min_length=1, max_length=50, charset=VALID_OMNI_CHARS), # semanticLabel
gym.spaces.Text(min_length=0, max_length=50, charset=VALID_OMNI_CHARS), # metadata
gym.spaces.Sequence(space=gym.spaces.Box(low=0, high=MAX_INSTANCE_COUNT, shape=(), dtype=np.uint)), # instanceIds
gym.spaces.Box(low=0, high=MAX_CLASS_COUNT, shape=(), dtype=np.uint), # semanticId
gym.spaces.Box(low=-np.inf, high=np.inf, shape=(), dtype=float), # x_min
gym.spaces.Box(low=-np.inf, high=np.inf, shape=(), dtype=float), # y_min
gym.spaces.Box(low=-np.inf, high=np.inf, shape=(), dtype=float), # z_min
gym.spaces.Box(low=-np.inf, high=np.inf, shape=(), dtype=float), # x_max
gym.spaces.Box(low=-np.inf, high=np.inf, shape=(), dtype=float), # y_max
gym.spaces.Box(low=-np.inf, high=np.inf, shape=(), dtype=float), # z_max
gym.spaces.Box(low=-np.inf, high=np.inf, shape=(4, 4), dtype=float), # transform
gym.spaces.Box(low=-np.inf, high=np.inf, shape=(8, 3), dtype=float), # corners
gym.spaces.Box(low=0, high=MAX_CLASS_COUNT, shape=(), dtype=np.uint32), # semanticId
gym.spaces.Box(low=-np.inf, high=np.inf, shape=(), dtype=np.int32), # x_min
gym.spaces.Box(low=-np.inf, high=np.inf, shape=(), dtype=np.int32), # y_min
gym.spaces.Box(low=-np.inf, high=np.inf, shape=(), dtype=np.int32), # z_min
gym.spaces.Box(low=-np.inf, high=np.inf, shape=(), dtype=np.int32), # x_max
gym.spaces.Box(low=-np.inf, high=np.inf, shape=(), dtype=np.int32), # y_max
gym.spaces.Box(low=-np.inf, high=np.inf, shape=(), dtype=np.int32), # z_max
gym.spaces.Box(low=-np.inf, high=np.inf, shape=(4, 4), dtype=np.float32), # transform
gym.spaces.Box(low=-np.inf, high=np.inf, shape=(8, 3), dtype=np.float32), # corners
)))

bbox_2d_space = gym.spaces.Sequence(space=gym.spaces.Tuple((
gym.spaces.Box(low=-np.inf, high=np.inf, shape=(), dtype=int), # uniqueId
gym.spaces.Box(low=-np.inf, high=np.inf, shape=(), dtype=np.int32), # uniqueId
gym.spaces.Text(min_length=1, max_length=50, charset=VALID_OMNI_CHARS), # name
gym.spaces.Text(min_length=1, max_length=50, charset=VALID_OMNI_CHARS), # semanticLabel
gym.spaces.Text(min_length=0, max_length=50, charset=VALID_OMNI_CHARS), # metadata
gym.spaces.Sequence(space=gym.spaces.Box(low=0, high=MAX_INSTANCE_COUNT, shape=(), dtype=np.uint)), # instanceIds
gym.spaces.Box(low=0, high=MAX_CLASS_COUNT, shape=(), dtype=np.uint), # semanticId
gym.spaces.Box(low=0, high=MAX_VIEWER_SIZE, shape=(), dtype=int), # x_min
gym.spaces.Box(low=0, high=MAX_VIEWER_SIZE, shape=(), dtype=int), # y_min
gym.spaces.Box(low=0, high=MAX_VIEWER_SIZE, shape=(), dtype=int), # x_max
gym.spaces.Box(low=0, high=MAX_VIEWER_SIZE, shape=(), dtype=int), # y_max
gym.spaces.Box(low=0, high=MAX_CLASS_COUNT, shape=(), dtype=np.uint32), # semanticId
gym.spaces.Box(low=0, high=MAX_VIEWER_SIZE, shape=(), dtype=np.int32), # x_min
gym.spaces.Box(low=0, high=MAX_VIEWER_SIZE, shape=(), dtype=np.int32), # y_min
gym.spaces.Box(low=0, high=MAX_VIEWER_SIZE, shape=(), dtype=np.int32), # x_max
gym.spaces.Box(low=0, high=MAX_VIEWER_SIZE, shape=(), dtype=np.int32), # y_max
)))

camera_space = gym.spaces.Dict(dict(
Expand All @@ -466,7 +466,7 @@ def _obs_space_mapping(self):
normal=((self.image_height, self.image_width, 3), -1.0, 1.0, np.float32),
seg_semantic=((self.image_height, self.image_width), 0, MAX_CLASS_COUNT, np.uint32),
seg_instance=((self.image_height, self.image_width), 0, MAX_INSTANCE_COUNT, np.uint32),
flow=((self.image_height, self.image_width, 3), -np.inf, np.inf, np.float32),
flow=((self.image_height, self.image_width, 4), -np.inf, np.inf, np.float32),
bbox_2d_tight=bbox_2d_space,
bbox_2d_loose=bbox_2d_space,
bbox_3d=bbox_3d_space,
Expand Down
25 changes: 25 additions & 0 deletions omnigibson/utils/gym_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,31 @@ def recursively_generate_flat_dict(dic, prefix=None):
return out


def recursively_generate_compatible_dict(dic):
    """
    Helper function to recursively iterate through a dictionary and cast values to the types necessary to be
    compatible with Gym spaces -- in particular, the Sequence and Tuple types for structured np.ndarray / np.void
    values in @dic

    Args:
        dic (dict): (Potentially nested) dictionary whose values should be converted into Gym-compatible types

    Returns:
        dict: Gym-compatible version of @dic. Nested dicts are converted recursively, structured arrays become
            lists of tuples, structured scalars (np.void or 0-d structured arrays) become a single tuple, and all
            other values are preserved as-is (same object, not copied)
    """
    out = dict()
    for k, v in dic.items():
        if isinstance(v, dict):
            # Recurse so that arbitrarily nested observation dicts are handled
            out[k] = recursively_generate_compatible_dict(dic=v)
        elif isinstance(v, np.ndarray) and len(v.dtype) > 0:
            # len(dtype) > 0 only holds for structured dtypes, i.e. dtypes with named fields
            if v.ndim == 0:
                # A 0-d structured array cannot be iterated over; unwrap its single record into one tuple
                out[k] = tuple(v[()])
            else:
                # Map to list of tuples
                out[k] = list(map(tuple, v))
        elif isinstance(v, np.void) and len(v.dtype) > 0:
            # A lone structured record maps to a single tuple of its fields
            out[k] = tuple(v)
        else:
            # Preserve the key-value pair
            out[k] = v

    return out


class GymObservable(metaclass=ABCMeta):
"""
Simple class interface for observable objects. These objects should implement a way to grab observations,
Expand Down
2 changes: 1 addition & 1 deletion omnigibson/utils/usd_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -130,7 +130,7 @@ def get_camera_params(viewport):
"horizontal_aperture": view_params["horizontal_aperture"],
"view_projection_matrix": view_proj_mat,
"resolution": {"width": view_params["width"], "height": view_params["height"]},
"clipping_range": view_params["clipping_range"],
"clipping_range": np.array(view_params["clipping_range"]),
}


Expand Down

0 comments on commit 473960d

Please sign in to comment.