diff --git a/notebooks/explore_dset2.ipynb b/notebooks/explore_dset2.ipynb index b339013..b1ab0fb 100644 --- a/notebooks/explore_dset2.ipynb +++ b/notebooks/explore_dset2.ipynb @@ -47,33 +47,6 @@ ")" ] }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "data = dset[0]" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "data[\"ignore_collisions\"]" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "data.keys()" - ] - }, { "cell_type": "code", "execution_count": null, @@ -401,13 +374,23 @@ "outputs": [], "source": [ "# Get a mapping from handle id to handle name.\n", - "handle_mapping = load_handle_mapping(\"/data/rlbench10/\", task_name, 0)\n", + "task_name = \"put_money_in_safe\"\n", + "handle_mapping = load_handle_mapping(\"/data/rlbench10_collisions/\", task_name, 0)\n", "rev_handle_mapping = {v: k for k, v in handle_mapping.items()}\n", "\n", "q_id = 100\n", "rev_handle_mapping[q_id]" ] }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "handle_mapping" + ] + }, { "cell_type": "code", "execution_count": null, @@ -510,7 +493,7 @@ "N_DEMOS = 10\n", "# Create a dataset for that phase.\n", "dset = RLBenchPlacementDataset(\n", - " dataset_root=\"/data/rlbench10/\",\n", + " dataset_root=\"/data/rlbench10_collisions/\",\n", " task_name=\"take_umbrella_out_of_umbrella_stand\",\n", " demos=range(N_DEMOS),\n", " phase=phase,\n", @@ -571,12 +554,12 @@ "\n", "# task_name = \"pick_and_lift\"\n", "# task_name = \"pick_up_cup\"\n", - "# task_name = \"put_knife_on_chopping_board\"\n", + "task_name = \"put_knife_on_chopping_board\"\n", "# task_name = \"put_money_in_safe\"\n", "# task_name = \"push_button\"\n", "# task_name = \"reach_target\"\n", "# task_name = \"slide_block_to_target\"\n", - "task_name = \"stack_wine\"\n", + "# task_name = \"stack_wine\"\n", "# task_name = \"take_money_out_safe\"\n", "# task_name = \"take_umbrella_out_of_umbrella_stand\"\n", "\n", @@ -586,7 +569,7 @@ "for ix, phase in enumerate(TASK_DICT[task_name][\"phase_order\"]):\n", " print(f\"Phase: {phase}\")\n", " dset = RLBenchPlacementDataset(\n", - " dataset_root=\"/data/rlbench10/\",\n", + " dataset_root=\"/data/rlbench10_collisions/\",\n", " task_name=task_name,\n", " demos=[0],\n", " phase=phase,\n", @@ -594,6 +577,8 @@ " use_first_as_init_keyframe=False,\n", " anchor_mode=\"background_robot_removed\",\n", " action_mode=\"gripper_and_object\",\n", + " include_wrist_cam=True,\n", + " gripper_in_first_phase=True,\n", " )\n", "\n", " data = dset[0]\n", diff --git a/src/rpad/rlbench_utils/placement_dataset.py b/src/rpad/rlbench_utils/placement_dataset.py index 3762215..4f79346 100644 --- a/src/rpad/rlbench_utils/placement_dataset.py +++ b/src/rpad/rlbench_utils/placement_dataset.py @@ -141,7 +141,11 @@ def load_state_pos_dict( BACKGROUND_NAMES = [ "ResizableFloor_5_25_visibleElement", + "Wall1", + "Wall2", "Wall3", + "Wall4", + "Roof", "diningTable_visible", "workspace", ] @@ -213,7 +217,10 @@ def get_anchor_points( names = BACKGROUND_NAMES + ROBOT_NONGRIPPER_NAMES # If it's the first phase, we also omit the gripper. - if phase == TASK_DICT[task_name]["phase_order"][0] and gripper_in_first_phase: + if ( + phase == TASK_DICT[task_name]["phase_order"][0] + and not gripper_in_first_phase + ): names += GRIPPER_OBJ_NAMES return filter_out_names(rgb, point_cloud, mask, handle_mapping, names)