Skip to content

Commit

Permalink
Merge pull request #729 from StanfordVL/rl-experiments-merge
Browse files (browse the repository at this point in the history)
Merge experimental RL updates into Hang's PR
  • Loading branch information
cgokmen authored May 21, 2024
2 parents f4f153e + 4d9102f commit d767b34
Show file tree
Hide file tree
Showing 17 changed files with 418 additions and 197 deletions.
42 changes: 31 additions & 11 deletions omnigibson/action_primitives/starter_semantic_action_primitives.py
Original file line number Diff line number Diff line change
Expand Up @@ -77,15 +77,15 @@
m.LOW_PRECISION_ANGLE_THRESHOLD = 0.2

m.TIAGO_TORSO_FIXED = False
m.JOINT_POS_DIFF_THRESHOLD = 0.005
m.JOINT_POS_DIFF_THRESHOLD = 0.01
m.JOINT_CONTROL_MIN_ACTION = 0.0
m.MAX_ALLOWED_JOINT_ERROR_FOR_LINEAR_MOTION = np.deg2rad(45)

log = create_module_logger(module_name=__name__)


def indented_print(msg, *args, **kwargs):
log.debug(" " * len(inspect.stack()) + str(msg), *args, **kwargs)
print(" " * len(inspect.stack()) + str(msg), *args, **kwargs)


class RobotCopy:
Expand Down Expand Up @@ -691,9 +691,11 @@ def _grasp(self, obj):
)

# Open the hand first
indented_print("Opening hand before grasping")
yield from self._execute_release()

# Allow grasping from suboptimal extents if we've tried enough times.
indented_print("Sampling grasp pose")
grasp_poses = get_grasp_poses_for_object_sticky(obj)
grasp_pose, object_direction = random.choice(grasp_poses)

Expand All @@ -702,10 +704,14 @@ def _grasp(self, obj):
approach_pose = (approach_pos, grasp_pose[1])

# If the grasp pose is too far, navigate.
indented_print("Navigating to grasp pose if needed")
yield from self._navigate_if_needed(obj, pose_on_obj=grasp_pose)

indented_print("Moving hand to grasp pose")
yield from self._move_hand(grasp_pose)

# We can pre-grasp in sticky grasping mode.
indented_print("Pregrasp squeeze")
yield from self._execute_grasp()

# Since the grasp pose is slightly off the object, we want to move towards the object, around 5cm.
Expand All @@ -717,15 +723,19 @@ def _grasp(self, obj):
empty_action = self._empty_action()
yield self._postprocess_action(empty_action)

indented_print("Checking grasp")
if self._get_obj_in_hand() is None:
raise ActionPrimitiveError(
ActionPrimitiveError.Reason.POST_CONDITION_ERROR,
"Grasp completed, but no object detected in hand after executing grasp",
{"target object": obj.name},
)

indented_print("Moving hand back")
yield from self._reset_hand()

indented_print("Done with grasp")

if self._get_obj_in_hand() != obj:
raise ActionPrimitiveError(
ActionPrimitiveError.Reason.POST_CONDITION_ERROR,
Expand Down Expand Up @@ -1055,27 +1065,25 @@ def _move_hand_direct_joint(self, joint_pos, stop_on_contact=False, ignore_failu
controller_name = f"arm_{self.arm}"
use_delta = self.robot._controllers[controller_name].use_delta_commands

action = self._empty_action()
controller_name = "arm_{}".format(self.arm)

action[self.robot.controller_action_idx[controller_name]] = joint_pos
# Store the previous eef pose for checking if we got stuck
prev_eef_pos = np.zeros(3)

for _ in range(m.MAX_STEPS_FOR_HAND_MOVE_JOINT):
current_joint_pos = self.robot.get_joint_positions()[self._manipulation_control_idx]
diff_joint_pos = np.array(current_joint_pos) - np.array(joint_pos)
diff_joint_pos = np.array(joint_pos) - np.array(current_joint_pos)
if np.max(np.abs(diff_joint_pos)) < m.JOINT_POS_DIFF_THRESHOLD:
return
if stop_on_contact and detect_robot_collision_in_sim(self.robot, ignore_obj_in_hand=False):
return
if np.max(np.abs(self.robot.get_eef_position(self.arm) - prev_eef_pos)) < 0.0001:
raise ActionPrimitiveError(
ActionPrimitiveError.Reason.EXECUTION_ERROR, f"Hand got stuck during execution."
)
# We're stuck!
break

action = self._empty_action()
if use_delta:
# Convert actions to delta.
action[self.robot.controller_action_idx[controller_name]] = diff_joint_pos
else:
action[self.robot.controller_action_idx[controller_name]] = joint_pos

prev_eef_pos = self.robot.get_eef_position(self.arm)
yield self._postprocess_action(action)
Expand Down Expand Up @@ -1286,6 +1294,12 @@ def _execute_grasp(self):
np.array or None: Action array for one step for the robot to grasp, or None if it's done grasping
"""
for _ in range(m.MAX_STEPS_FOR_GRASP_OR_RELEASE):
joint_position = self.robot.get_joint_positions()[self.robot.gripper_control_idx[self.arm]]
joint_lower_limit = self.robot.joint_lower_limits[self.robot.gripper_control_idx[self.arm]]

if np.allclose(joint_position, joint_lower_limit, atol=0.01):
break

action = self._empty_action()
controller_name = "gripper_{}".format(self.arm)
action[self.robot.controller_action_idx[controller_name]] = -1.0
Expand All @@ -1299,6 +1313,12 @@ def _execute_release(self):
np.array or None: Action array for one step for the robot to release, or None if it's done releasing
"""
for _ in range(m.MAX_STEPS_FOR_GRASP_OR_RELEASE):
joint_position = self.robot.get_joint_positions()[self.robot.gripper_control_idx[self.arm]]
joint_upper_limit = self.robot.joint_upper_limits[self.robot.gripper_control_idx[self.arm]]

if np.allclose(joint_position, joint_upper_limit, atol=0.01):
break

action = self._empty_action()
controller_name = "gripper_{}".format(self.arm)
action[self.robot.controller_action_idx[controller_name]] = 1.0
Expand Down
136 changes: 72 additions & 64 deletions omnigibson/envs/env_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@ class Environment(gym.Env, GymObservable, Recreatable):
Core environment class that handles loading scene, robot(s), and task, following OpenAI Gym interface.
"""

def __init__(self, configs):
def __init__(self, configs, in_vec_env=False):
"""
Args:
configs (str or dict or list of str or dict): config_file path(s) or raw config dictionaries.
Expand All @@ -50,6 +50,9 @@ def __init__(self, configs):
self.render_mode = "rgb_array"
self.metadata = {"render.modes": ["rgb_array"]}

# Store if we are part of a vec env
self.in_vec_env = in_vec_env

# Convert config file(s) into a single parsed dict
configs = configs if isinstance(configs, list) or isinstance(configs, tuple) else [configs]

Expand Down Expand Up @@ -102,6 +105,11 @@ def __init__(self, configs):
# Load this environment
self.load()

# If we are not in a vec env, we can play ourselves. Otherwise we wait for the vec env to play.
if not self.in_vec_env:
og.sim.play()
self.post_play_load()

def reload(self, configs, overwrite_old=True):
"""
Reload using another set of config file(s).
Expand Down Expand Up @@ -213,7 +221,6 @@ def _load_scene(self):
"""
Load the scene and robot specified in the config file.
"""
og.sim.stop()
assert og.sim.is_stopped(), "Simulator must be stopped before loading scene!"

# Create the scene from our scene config
Expand Down Expand Up @@ -259,13 +266,6 @@ def _load_robots(self):
if robot._dummy is not None:
robot._dummy.load(self.scene)

if len(self.robots_config) > 0:
# Auto-initialize all robots
og.sim.play()
self.scene.reset()
self.scene.update_initial_state()
og.sim.stop()

assert og.sim.is_stopped(), "Simulator must be stopped after loading robots!"

def _load_objects(self):
Expand All @@ -290,13 +290,6 @@ def _load_objects(self):
self.scene.add_object(obj)
obj.set_local_pose(position=position, orientation=orientation)

if len(self.objects_config) > 0:
# Auto-initialize all objects
og.sim.play()
self.scene.reset()
self.scene.update_initial_state()
og.sim.stop()

assert og.sim.is_stopped(), "Simulator must be stopped after loading objects!"

def _load_external_sensors(self):
Expand All @@ -314,7 +307,7 @@ def _load_external_sensors(self):
sensor_config["name"] = f"external_sensor{i}"
# Determine prim path if not specified
if "prim_path" not in sensor_config:
sensor_config["prim_path"] = f"/World/{sensor_config['name']}"
sensor_config["relative_prim_path"] = f"/{sensor_config['name']}"
# Pop the desired position and orientation
local_position, local_orientation = sensor_config.pop("local_position", None), sensor_config.pop(
"local_orientation", None
Expand Down Expand Up @@ -411,7 +404,9 @@ def load(self):
self._load_task()
self._load_external_sensors()

og.sim.play()
def post_play_load(self):
# Save the state
self.scene.update_initial_state()

# Load the obs / action spaces
self.load_observation_space()
Expand Down Expand Up @@ -554,6 +549,17 @@ def _post_step(self, action):
info["last_observation"] = obs
obs = self.reset()

# Hacky way to check for time limit info to split terminated and truncated
terminated = False
truncated = False
for tc, tc_data in info["done"]["termination_conditions"].items():
if tc_data["done"]:
if tc == "timeout":
truncated = True
else:
terminated = True
assert (terminated or truncated) == done, "Terminated and truncated must match done!"

# Increment step
self._current_step += 1
return obs, reward, done, info
Expand Down Expand Up @@ -604,7 +610,8 @@ def render(self):
og.sim.render()

# Grab the rendered image from each of the rgb sensors, concatenate along dim 1
rgb_images = [sensor.get_obs()["rgb"] for sensor in rgb_sensors]
# TODO: sensor.get_obs() returns a tuple whose first element is the observation dict; should it?
rgb_images = [sensor.get_obs()[0]["rgb"] for sensor in rgb_sensors]
return np.concatenate(rgb_images, axis=1)[:, :, :3]

def _reset_variables(self):
Expand All @@ -614,7 +621,7 @@ def _reset_variables(self):
self._current_episode += 1
self._current_step = 0

def reset(self, **kwargs):
def reset(self, get_obs=True, **kwargs):
"""
Reset episode.
"""
Expand All @@ -624,50 +631,51 @@ def reset(self, **kwargs):
# Reset internal variables
self._reset_variables()

# Run a single simulator step to make sure we can grab updated observations
og.sim.step()

# Grab and return observations
obs, _ = self.get_obs()

if self._loaded:
# Sanity check to make sure received observations match expected observation space
check_obs = recursively_generate_compatible_dict(dic=obs)
if not self.observation_space.contains(check_obs):
exp_obs = dict()
for key, value in recursively_generate_flat_dict(dic=self.observation_space).items():
exp_obs[key] = ("obs_space", key, value.dtype, value.shape)
real_obs = dict()
for key, value in recursively_generate_flat_dict(dic=check_obs).items():
if isinstance(value, np.ndarray):
real_obs[key] = ("obs", key, value.dtype, value.shape)
else:
real_obs[key] = ("obs", key, type(value), "()")

exp_keys = set(exp_obs.keys())
real_keys = set(real_obs.keys())
shared_keys = exp_keys.intersection(real_keys)
missing_keys = exp_keys - real_keys
extra_keys = real_keys - exp_keys

if missing_keys:
log.error("MISSING OBSERVATION KEYS:")
log.error(missing_keys)
if extra_keys:
log.error("EXTRA OBSERVATION KEYS:")
log.error(extra_keys)

mismatched_keys = []
for k in shared_keys:
if exp_obs[k][2:] != real_obs[k][2:]: # Compare dtypes and shapes
mismatched_keys.append(k)
log.error(f"MISMATCHED OBSERVATION FOR KEY '{k}':")
log.error(f"Expected: {exp_obs[k]}")
log.error(f"Received: {real_obs[k]}")

raise ValueError("Observation space does not match returned observations!")

return obs
if get_obs:
# Run a single simulator step to make sure we can grab updated observations
og.sim.step()

# Grab and return observations
obs, _ = self.get_obs()

if self._loaded:
# Sanity check to make sure received observations match expected observation space
check_obs = recursively_generate_compatible_dict(dic=obs)
if not self.observation_space.contains(check_obs):
exp_obs = dict()
for key, value in recursively_generate_flat_dict(dic=self.observation_space).items():
exp_obs[key] = ("obs_space", key, value.dtype, value.shape)
real_obs = dict()
for key, value in recursively_generate_flat_dict(dic=check_obs).items():
if isinstance(value, np.ndarray):
real_obs[key] = ("obs", key, value.dtype, value.shape)
else:
real_obs[key] = ("obs", key, type(value), "()")

exp_keys = set(exp_obs.keys())
real_keys = set(real_obs.keys())
shared_keys = exp_keys.intersection(real_keys)
missing_keys = exp_keys - real_keys
extra_keys = real_keys - exp_keys

if missing_keys:
log.error("MISSING OBSERVATION KEYS:")
log.error(missing_keys)
if extra_keys:
log.error("EXTRA OBSERVATION KEYS:")
log.error(extra_keys)

mismatched_keys = []
for k in shared_keys:
if exp_obs[k][2:] != real_obs[k][2:]: # Compare dtypes and shapes
mismatched_keys.append(k)
log.error(f"MISMATCHED OBSERVATION FOR KEY '{k}':")
log.error(f"Expected: {exp_obs[k]}")
log.error(f"Received: {real_obs[k]}")

raise ValueError("Observation space does not match returned observations!")

return obs, {}

@property
def episode_steps(self):
Expand Down
Loading

0 comments on commit d767b34

Please sign in to comment.