Skip to content

Commit

Permalink
add in background too
Browse files Browse the repository at this point in the history
  • Loading branch information
beneisner committed May 22, 2024
1 parent 11722e4 commit 84f84f0
Show file tree
Hide file tree
Showing 2 changed files with 98 additions and 56 deletions.
113 changes: 58 additions & 55 deletions notebooks/explore_dset2.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -374,7 +374,8 @@
"outputs": [],
"source": [
"# Get a mapping from handle id to handle name.\n",
"task_name = \"put_money_in_safe\"\n",
"# task_name = \"put_money_in_safe\"\n",
"task_name = \"put_knife_on_chopping_board\"\n",
"handle_mapping = load_handle_mapping(\"/data/rlbench10_collisions/\", task_name, 0)\n",
"rev_handle_mapping = {v: k for k, v in handle_mapping.items()}\n",
"\n",
Expand All @@ -388,7 +389,7 @@
"metadata": {},
"outputs": [],
"source": [
"handle_mapping"
"set(handle_mapping.keys())"
]
},
{
Expand Down Expand Up @@ -563,63 +564,65 @@
"# task_name = \"take_money_out_safe\"\n",
"# task_name = \"take_umbrella_out_of_umbrella_stand\"\n",
"\n",
"n_phases = len(TASK_DICT[task_name][\"phase_order\"])\n",
"fig = make_subplots(rows=1, cols=n_phases, specs=[[{\"type\": \"scene\"}] * n_phases])\n",
"for i in range(4):\n",
"\n",
"for ix, phase in enumerate(TASK_DICT[task_name][\"phase_order\"]):\n",
" print(f\"Phase: {phase}\")\n",
" dset = RLBenchPlacementDataset(\n",
" dataset_root=\"/data/rlbench10_collisions/\",\n",
" task_name=task_name,\n",
" demos=[0],\n",
" phase=phase,\n",
" debugging=False,\n",
" use_first_as_init_keyframe=False,\n",
" anchor_mode=\"background_robot_removed\",\n",
" action_mode=\"gripper_and_object\",\n",
" include_wrist_cam=True,\n",
" gripper_in_first_phase=True,\n",
" )\n",
" n_phases = len(TASK_DICT[task_name][\"phase_order\"])\n",
" fig = make_subplots(rows=1, cols=n_phases, specs=[[{\"type\": \"scene\"}] * n_phases])\n",
"\n",
" data = dset[0]\n",
"\n",
" # Plot segmentation with segmentation_fig\n",
"\n",
" print(list(data.keys()))\n",
"\n",
" anchor_pc = data[\"init_anchor_pc\"]\n",
" # Randomly downsample the anchor point cloud.\n",
" n_pts = anchor_pc.shape[0]\n",
" if n_pts > 1000:\n",
" anchor_pc = anchor_pc[np.random.permutation(n_pts)[:1000]]\n",
"\n",
" points = torch.cat(\n",
" [\n",
" data[\"init_action_pc\"],\n",
" anchor_pc,\n",
" data[\"key_action_pc\"],\n",
" ]\n",
" )\n",
" print(points.shape)\n",
" seg = torch.cat(\n",
" [\n",
" torch.zeros(data[\"init_action_pc\"].shape[0]),\n",
" torch.ones(anchor_pc.shape[0]),\n",
" 2 * torch.ones(data[\"key_action_pc\"].shape[0]),\n",
" ]\n",
" )\n",
" for ix, phase in enumerate(TASK_DICT[task_name][\"phase_order\"]):\n",
" print(f\"Phase: {phase}\")\n",
" dset = RLBenchPlacementDataset(\n",
" dataset_root=\"/data/rlbench10_collisions/\",\n",
" task_name=task_name,\n",
" demos=range(100),\n",
" phase=phase,\n",
" debugging=False,\n",
" use_first_as_init_keyframe=False,\n",
" anchor_mode=\"background_robot_removed\",\n",
" action_mode=\"gripper_and_object\",\n",
" include_wrist_cam=True,\n",
" gripper_in_first_phase=True,\n",
" )\n",
"\n",
" fig = segmentation_fig_rc(\n",
" points,\n",
" seg.int(),\n",
" labelmap={0: \"init_action\", 1: \"init_anchor\", 2: \"key_action\"},\n",
" fig=fig,\n",
" row=1,\n",
" column=ix+1,\n",
" n_col=n_phases,\n",
" )\n",
" data = dset[i]\n",
"\n",
"fig.show()\n",
" # Plot segmentation with segmentation_fig\n",
"\n",
" print(list(data.keys()))\n",
"\n",
" anchor_pc = data[\"init_anchor_pc\"]\n",
" # Randomly downsample the anchor point cloud.\n",
" n_pts = anchor_pc.shape[0]\n",
" if n_pts > 1000:\n",
" anchor_pc = anchor_pc[np.random.permutation(n_pts)[:1000]]\n",
"\n",
" points = torch.cat(\n",
" [\n",
" data[\"init_action_pc\"],\n",
" anchor_pc,\n",
" data[\"key_action_pc\"],\n",
" ]\n",
" )\n",
" print(points.shape)\n",
" seg = torch.cat(\n",
" [\n",
" torch.zeros(data[\"init_action_pc\"].shape[0]),\n",
" torch.ones(anchor_pc.shape[0]),\n",
" 2 * torch.ones(data[\"key_action_pc\"].shape[0]),\n",
" ]\n",
" )\n",
"\n",
" fig = segmentation_fig_rc(\n",
" points,\n",
" seg.int(),\n",
" labelmap={0: \"init_action\", 1: \"init_anchor\", 2: \"key_action\"},\n",
" fig=fig,\n",
" row=1,\n",
" column=ix+1,\n",
" n_col=n_phases,\n",
" )\n",
"\n",
" fig.show()\n",
" "
]
},
Expand Down
41 changes: 40 additions & 1 deletion src/rpad/rlbench_utils/placement_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -140,13 +140,51 @@ def load_state_pos_dict(


BACKGROUND_NAMES = [
"DefaultCamera",
"DefaultLightA",
"DefaultLightB",
"DefaultLightC",
"DefaultLightD",
"DefaultLights",
"DefaultNXViewCamera",
"DefaultNYViewCamera",
"DefaultNZViewCamera",
"DefaultXViewCamera",
"DefaultYViewCamera",
"DefaultZViewCamera",
"Dummy",
"Floor",
"FloorAnchor",
"ResizableFloor_5_25",
"ResizableFloor_5_25_element",
"ResizableFloor_5_25_visibleElement",
"Roof",
"Wall1",
"Wall2",
"Wall3",
"Wall4",
"Roof",
"XYZCameraProxy",
"boundary",
"cam_cinematic_base",
"cam_cinematic_placeholder",
"cam_front",
"cam_front_mask",
"cam_over_shoulder_left",
"cam_over_shoulder_left_mask",
"cam_over_shoulder_right",
"cam_over_shoulder_right_mask",
"cam_overhead",
"cam_overhead_mask",
"cam_wrist",
"cam_wrist_mask",
"diningTable",
"diningTable_visible",
"remoteApi",
"success",
"waypoint0",
"waypoint1",
"waypoint2",
"waypoint3",
"workspace",
]

Expand All @@ -165,6 +203,7 @@ def load_state_pos_dict(
def filter_out_names(rgb, point_cloud, mask, handlemapping, names=BACKGROUND_NAMES):
# Get the indices of the background.
background_handles = [handlemapping[name] for name in names]
background_handles.append(65535) # It's -1, cast as uint16.
background_indices = np.isin(mask, background_handles).reshape((-1))

# Get the indices of the foreground.
Expand Down

0 comments on commit 84f84f0

Please sign in to comment.