Skip to content

Commit

Permalink
factor stuff out
Browse files Browse the repository at this point in the history
  • Loading branch information
beneisner committed May 16, 2024
1 parent 657288a commit ed54f36
Showing 1 changed file with 95 additions and 60 deletions.
155 changes: 95 additions & 60 deletions src/rpad/rlbench_utils/placement_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -191,6 +191,72 @@ class AnchorMode(str, Enum):
SINGLE_OBJECT = "single_object"


def get_anchor_points(
anchor_mode: AnchorMode,
rgb,
point_cloud,
mask,
task_name,
phase,
use_from_simulator=False,
handle_mapping=None,
names_to_handles=None,
):
if anchor_mode == AnchorMode.RAW:
return rgb, point_cloud
elif anchor_mode == AnchorMode.BACKGROUND_REMOVED:
return filter_out_names(
rgb, point_cloud, mask, handle_mapping, BACKGROUND_NAMES
)
elif anchor_mode == AnchorMode.BACKGROUND_ROBOT_REMOVED:
return filter_out_names(
rgb,
point_cloud,
mask,
handle_mapping,
BACKGROUND_NAMES + ROBOT_NONGRIPPER_NAMES,
)
elif anchor_mode == AnchorMode.SINGLE_OBJECT:
if use_from_simulator:
return get_rgb_point_cloud_by_object_names(
rgb,
point_cloud,
mask,
TASK_DICT[task_name]["phase"][phase]["anchor_obj_names"],
)
else:
return get_rgb_point_cloud_by_object_handles(
rgb,
point_cloud,
mask,
names_to_handles[phase]["anchor_obj_names"],
)
else:
raise ValueError("Anchor mode must be one of the AnchorMode enum values.")


def get_action_points(
action_mode: ActionMode,
rgb,
point_cloud,
mask,
action_handles,
gripper_handles,
):
if action_mode == ActionMode.GRIPPER_AND_OBJECT:
action_handles = action_handles + gripper_handles
elif action_mode == ActionMode.OBJECT:
pass
else:
raise ValueError("Action mode must be one of the ActionMode enum values.")

action_rgb, action_point_cloud = get_rgb_point_cloud_by_object_handles(
rgb, point_cloud, mask, action_handles
)

return action_rgb, action_point_cloud


class RLBenchPlacementDataset(data.Dataset):
def __init__(
self,
Expand Down Expand Up @@ -299,7 +365,7 @@ def _load_keyframes(

keyframes = [demo[ix] for ix in keyframe_ixs]

return keyframes, demo[0]
return keyframes, demo[0] # type: ignore

# We also cache in memory, since all the transformations are the same.
# Saves a lot of time when loading the dataset, but don't have to worry
Expand Down Expand Up @@ -347,69 +413,38 @@ def __getitem__(self, index: int) -> Dict[str, torch.Tensor]:
# Find the first grasp instance
key_obs = keyframes[phase_ix]

if self.debugging:
raise ValueError("Debugging not implemented.")
return {
"keyframes": keyframe_ixs,
"demo": demo,
"initial_obs": initial_obs,
"key_obs": key_obs,
"init_front_rgb": torch.from_numpy(initial_obs.front_rgb),
"key_front_rgb": torch.from_numpy(key_obs.front_rgb),
"init_front_mask": torch.from_numpy(
initial_obs.front_mask.astype(np.int32)
),
"key_front_mask": torch.from_numpy(key_obs.front_mask.astype(np.int32)),
"phase": phase,
"phase_onehot": torch.from_numpy(phase_onehot),
}
action_handles = self.names_to_handles[phase]["action_obj_names"]

def _select_action_vals(rgb, point_cloud, mask):
return get_action_points(
self.action_mode,
rgb,
point_cloud,
mask,
action_handles,
self.gripper_handles,
)

def _select_anchor_vals(rgb, point_cloud, mask):
return get_anchor_points(
self.anchor_mode,
rgb,
point_cloud,
mask,
self.task_name,
phase,
use_from_simulator=False,
handle_mapping=self.handle_mapping,
names_to_handles=self.names_to_handles,
)

# Merge all the initial point clouds and masks into one.
init_rgb, init_point_cloud, init_mask = obs_to_rgb_point_cloud(initial_obs)

action_handles = self.names_to_handles[phase]["action_obj_names"]
if self.action_mode == ActionMode.GRIPPER_AND_OBJECT:
action_handles = action_handles + self.gripper_handles
elif self.action_mode == ActionMode.OBJECT:
pass
else:
raise ValueError("Action mode must be one of the ActionMode enum values.")

# Split the initial point cloud and rgb into action and anchor.
(
init_action_rgb,
init_action_point_cloud,
) = get_rgb_point_cloud_by_object_handles(
init_rgb, init_point_cloud, init_mask, action_handles
init_action_rgb, init_action_point_cloud = _select_action_vals(
init_rgb, init_point_cloud, init_mask
)

def _select_anchor_vals(rgb, point_cloud, mask):
if self.anchor_mode == AnchorMode.RAW:
return rgb, point_cloud
elif self.anchor_mode == AnchorMode.BACKGROUND_REMOVED:
return filter_out_names(
rgb, point_cloud, mask, self.handle_mapping, BACKGROUND_NAMES
)
elif self.anchor_mode == AnchorMode.BACKGROUND_ROBOT_REMOVED:
return filter_out_names(
rgb,
point_cloud,
mask,
self.handle_mapping,
BACKGROUND_NAMES + ROBOT_NONGRIPPER_NAMES,
)
elif self.anchor_mode == AnchorMode.SINGLE_OBJECT:
return get_rgb_point_cloud_by_object_handles(
rgb,
point_cloud,
mask,
self.names_to_handles[phase]["anchor_obj_names"],
)
else:
raise ValueError(
"Anchor mode must be one of the AnchorMode enum values."
)

init_anchor_rgb, init_anchor_point_cloud = _select_anchor_vals(
init_rgb, init_point_cloud, init_mask
)
Expand All @@ -418,8 +453,8 @@ def _select_anchor_vals(rgb, point_cloud, mask):
key_rgb, key_point_cloud, key_mask = obs_to_rgb_point_cloud(key_obs)

# Split the key point cloud and rgb into action and anchor.
key_action_rgb, key_action_point_cloud = get_rgb_point_cloud_by_object_handles(
key_rgb, key_point_cloud, key_mask, action_handles
key_action_rgb, key_action_point_cloud = _select_action_vals(
key_rgb, key_point_cloud, key_mask
)
key_anchor_rgb, key_anchor_point_cloud = _select_anchor_vals(
key_rgb, key_point_cloud, key_mask
Expand Down

0 comments on commit ed54f36

Please sign in to comment.