Skip to content

Commit

Permalink
Merge pull request #932 from StanfordVL/mini-fix
Browse files Browse the repository at this point in the history
Instance Segmentation Refactor
  • Loading branch information
cgokmen authored Oct 4, 2024
2 parents 6a56272 + ef9d37a commit 70b4ce7
Show file tree
Hide file tree
Showing 5 changed files with 108 additions and 78 deletions.
98 changes: 73 additions & 25 deletions omnigibson/sensors/vision_sensor.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
import omnigibson as og
import omnigibson.lazy as lazy
from omnigibson.sensors.sensor_base import BaseSensor
from omnigibson.systems.system_base import get_all_system_names
from omnigibson.utils.constants import (
MAX_CLASS_COUNT,
MAX_INSTANCE_COUNT,
Expand Down Expand Up @@ -353,9 +354,7 @@ def _preprocess_semantic_labels(self, id_to_labels):
if "," in replicator_mapping[key]:
# If there are multiple class names, grab the one that is a registered system
# This happens with MacroVisual particles, e.g. {"11": {"class": "breakfast_table,stain"}}
categories = [
cat for cat in replicator_mapping[key].split(",") if cat in self.scene.available_systems.keys()
]
categories = [cat for cat in replicator_mapping[key].split(",") if cat in get_all_system_names()]
assert (
len(categories) == 1
), "There should be exactly one category that belongs to scene.system_registry"
Expand Down Expand Up @@ -415,30 +414,79 @@ def _remap_instance_segmentation(self, img, id_to_labels, semantic_img, semantic
if value in ["BACKGROUND", "UNLABELLED"]:
value = value.lower()
elif "/" in value:
prim_name = value.split("/")[-1]
# Hacky way to get the particles of MacroVisual/PhysicalParticleSystem
# Remap instance segmentation and instance segmentation ID labels to system name
if "Particle" in prim_name:
category_name = prim_name.split("Particle")[0]
assert (
category_name in self.scene.available_systems.keys()
), f"System name {category_name} is not in the registered systems!"
value = category_name
else:
# Remap instance segmentation labels to object name
if not id:
# value is the prim path of the object
if og.sim.floor_plane is not None and value == og.sim.floor_plane.prim_path:
value = "groundPlane"
else:
# Instance Segmentation
if not id:
# Case 1: This is the ground plane
if og.sim.floor_plane is not None and value == og.sim.floor_plane.prim_path:
value = "groundPlane"
else:
# Case 2: Check if this is an object, e.g. '/World/scene_0/breakfast_table', '/World/scene_0/dishtowel'
obj = None
if self.scene is not None:
# If this is a camera within a scene, we check the object registry of the scene
obj = self.scene.object_registry("prim_path", value)
# Remap instance segmentation labels from prim path to object name
assert obj is not None, f"Object with prim path {value} cannot be found in objct registry!"
else:
# If this is the viewer camera, we check each object registry
for scene in og.sim.scenes:
obj = scene.object_registry("prim_path", value)
if obj:
break
if obj is not None:
# This is an object, so we remap the instance segmentation label to the object name
value = obj.name

# Keep the instance segmentation ID labels intact (prim paths of visual meshes)
else:
pass
# Case 3: Check if this is a particle system
else:
# This is a particle system
path_split = value.split("/")
prim_name = path_split[-1]
system_matched = False
# Case 3.1: Filter out macro particle systems
# e.g. '/World/scene_0/diced__apple/particles/diced__appleParticle0', '/World/scene_0/breakfast_table/base_link/stainParticle0'
if "Particle" in prim_name:
macro_system_name = prim_name.split("Particle")[0]
if macro_system_name in get_all_system_names():
system_matched = True
value = macro_system_name
# Case 3.2: Filter out micro particle systems
# e.g. '/World/scene_0/water/waterInstancer0/prototype0_1', '/World/scene_0/white_rice/white_riceInstancer0/prototype0'
else:
# If anything in path_split has "Instancer" in it, we know it's a micro particle system
for path in path_split:
if "Instancer" in path:
# This is a micro particle system
system_matched = True
value = path.split("Instancer")[0]
break
# Case 4: If nothing matched, we label it as unlabelled
if not system_matched:
value = "unlabelled"
# Instance ID Segmentation
else:
# The only thing we do here is for micro particle system, we clean its name
# e.g. a raw path looks like '/World/scene_0/water/waterInstancer0/prototype0.proto0_prototype0_id0'
# we clean it to '/World/scene_0/water/waterInstancer0/prototype0'
# Case 1: This is a micro particle system
# e.g. '/World/scene_0/water/waterInstancer0/prototype0.proto0_prototype0_id0', '/World/scene_0/white_rice/white_riceInstancer0/prototype0.proto0_prototype0_id0'
if "Instancer" in value and "." in value:
# This is a micro particle system
value = value[: value.rfind(".")]
# Case 2: For everything else, we keep the name as is
"""
e.g.
{
'54': '/World/scene_0/water/waterInstancer0/prototype0.proto0_prototype0_id0',
'60': '/World/scene_0/water/waterInstancer0/prototype0.proto0_prototype0_id0',
'30': '/World/scene_0/breakfast_table/base_link/stainParticle1',
'27': '/World/scene_0/diced__apple/particles/diced__appleParticle0',
'58': '/World/scene_0/white_rice/white_riceInstancer0/prototype0.proto0_prototype0_id0',
'64': '/World/scene_0/white_rice/white_riceInstancer0/prototype0.proto0_prototype0_id0',
'40': '/World/scene_0/diced__apple/particles/diced__appleParticle1',
'48': '/World/scene_0/breakfast_table/base_link/stainParticle0',
'1': '/World/ground_plane/geom',
'19': '/World/scene_0/dishtowel/base_link_cloth',
'6': '/World/scene_0/breakfast_table/base_link/visuals'
}
"""
else:
# TODO: This is a temporary fix unexpected labels e.g. INVALID introduced in new Isaac Sim versions
value = "unlabelled"
Expand Down
2 changes: 2 additions & 0 deletions omnigibson/systems/system_base.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import json
import os
from functools import cache

import torch as th

Expand Down Expand Up @@ -957,6 +958,7 @@ def generate_particles_on_object(
return success


@cache
def get_all_system_names():
"""
Gets all available systems from the OmniGibson dataset
Expand Down
3 changes: 1 addition & 2 deletions tests/test_data_collection.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,10 +35,9 @@ def test_data_collect_and_playback():
}

if og.sim is None:
# Make sure GPU dynamics are enabled (GPU dynamics needed for cloth) and no flatcache
# Make sure GPU dynamics are enabled (GPU dynamics needed for cloth)
gm.ENABLE_OBJECT_STATES = True
gm.USE_GPU_DYNAMICS = True
gm.ENABLE_FLATCACHE = True
gm.ENABLE_TRANSITION_RULES = False
else:
# Make sure sim is stopped
Expand Down
3 changes: 1 addition & 2 deletions tests/test_envs.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,10 +27,9 @@ def task_tester(task_type):
}

if og.sim is None:
# Make sure GPU dynamics are enabled (GPU dynamics needed for cloth) and no flatcache
# Make sure GPU dynamics are enabled (GPU dynamics needed for cloth)
gm.ENABLE_OBJECT_STATES = True
gm.USE_GPU_DYNAMICS = True
gm.ENABLE_FLATCACHE = True
gm.ENABLE_TRANSITION_RULES = False
else:
# Make sure sim is stopped
Expand Down
80 changes: 31 additions & 49 deletions tests/test_sensors.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,14 +16,12 @@ def test_segmentation_modalities(env):
robot = env.scene.robots[0]
place_obj_on_floor_plane(breakfast_table)
dishtowel.set_position_orientation(position=[-0.4, 0.0, 0.55], orientation=[0, 0, 0, 1])
robot.set_position_orientation(
position=[0.0, 0.8, 0.0], orientation=T.euler2quat(th.tensor([0, 0, -math.pi / 2], dtype=th.float32))
)
robot.reset()

og.sim.viewer_camera.set_position_orientation(position=[-0.0017, -0.1072, 1.4969], orientation=[0.0, 0.0, 0.0, 1.0])

modalities_required = ["seg_semantic", "seg_instance", "seg_instance_id"]
for modality in modalities_required:
robot.add_obs_modality(modality)
og.sim.viewer_camera.add_modality(modality)

systems = [env.scene.get_system(system_name) for system_name in SYSTEM_EXAMPLES.keys()]
for i, system in enumerate(systems):
Expand All @@ -37,17 +35,14 @@ def test_segmentation_modalities(env):
system.generate_group_particles(
group=system.get_group_name(breakfast_table),
positions=[pos, pos + th.tensor([0.1, 0.0, 0.0])],
link_prim_paths=[breakfast_table.root_link.prim_path],
link_prim_paths=[breakfast_table.root_link.prim_path] * 2,
)

og.sim.step()
for _ in range(3):
og.sim.render()

sensors = [s for s in robot.sensors.values() if isinstance(s, VisionSensor)]
assert len(sensors) > 0
vision_sensor = sensors[0]
all_observation, all_info = vision_sensor.get_obs()
all_observation, all_info = og.sim.viewer_camera.get_obs()

seg_semantic = all_observation["seg_semantic"]
seg_semantic_info = all_info["seg_semantic"]
Expand All @@ -57,7 +52,6 @@ def test_segmentation_modalities(env):
825831922: "floors",
884110082: "stain",
1949122937: "breakfast_table",
2814990211: "agent",
3051938632: "white_rice",
3330677804: "water",
4207839377: "dishtowel",
Expand All @@ -68,38 +62,31 @@ def test_segmentation_modalities(env):
seg_instance_info = all_info["seg_instance"]
assert set(int(x.item()) for x in th.unique(seg_instance)) == set(seg_instance_info.keys())
expected_dict = {
1: "unlabelled",
2: env.robots[0].name,
3: "groundPlane",
4: "dishtowel",
5: "breakfast_table",
6: "stain",
# 7: "water",
# 8: "white_rice",
9: "diced__apple",
2: "groundPlane",
3: "water",
4: "diced__apple",
5: "stain",
6: "white_rice",
7: "breakfast_table",
8: "dishtowel",
}
assert set(seg_instance_info.values()) == set(expected_dict.values())

seg_instance_id = all_observation["seg_instance_id"]
seg_instance_id_info = all_info["seg_instance_id"]
assert set(int(x.item()) for x in th.unique(seg_instance_id)) == set(seg_instance_id_info.keys())
expected_dict = {
3: f"/World/{env.robots[0].name}/gripper_link/visuals",
4: f"/World/{env.robots[0].name}/wrist_roll_link/visuals",
5: f"/World/{env.robots[0].name}/forearm_roll_link/visuals",
6: f"/World/{env.robots[0].name}/wrist_flex_link/visuals",
8: "/World/groundPlane/geom",
9: "/World/dishtowel/base_link_cloth",
10: f"/World/{env.robots[0].name}/r_gripper_finger_link/visuals",
11: f"/World/{env.robots[0].name}/l_gripper_finger_link/visuals",
12: "/World/breakfast_table/base_link/visuals",
13: "stain",
14: "white_rice",
15: "diced__apple",
16: "water",
1: "/World/ground_plane/geom",
2: "/World/scene_0/breakfast_table/base_link/visuals",
3: "/World/scene_0/dishtowel/base_link_cloth",
4: "/World/scene_0/water/waterInstancer0/prototype0",
5: "/World/scene_0/white_rice/white_riceInstancer0/prototype0",
6: "/World/scene_0/diced__apple/particles/diced__appleParticle1",
7: "/World/scene_0/breakfast_table/base_link/stainParticle1",
8: "/World/scene_0/breakfast_table/base_link/stainParticle0",
9: "/World/scene_0/diced__apple/particles/diced__appleParticle0",
}
# Temporarily disable this test because og_assets are outdated on CI machines
# assert set(seg_instance_id_info.values()) == set(expected_dict.values())
assert set(seg_instance_id_info.values()) == set(expected_dict.values())

for system in systems:
env.scene.clear_system(system.name)
Expand All @@ -112,34 +99,29 @@ def test_bbox_modalities(env):
robot = env.scene.robots[0]
place_obj_on_floor_plane(breakfast_table)
dishtowel.set_position_orientation(position=[-0.4, 0.0, 0.55], orientation=[0, 0, 0, 1])
robot.set_position_orientation(
position=[0, 0.8, 0.0], orientation=T.euler2quat(th.tensor([0, 0, -math.pi / 2], dtype=th.float32))
)
robot.reset()

og.sim.viewer_camera.set_position_orientation(position=[-0.0017, -0.1072, 1.4969], orientation=[0.0, 0.0, 0.0, 1.0])

modalities_required = ["bbox_2d_tight", "bbox_2d_loose", "bbox_3d"]
for modality in modalities_required:
robot.add_obs_modality(modality)
og.sim.viewer_camera.add_modality(modality)

og.sim.step()
for _ in range(3):
og.sim.render()

sensors = [s for s in robot.sensors.values() if isinstance(s, VisionSensor)]
assert len(sensors) > 0
vision_sensor = sensors[0]
all_observation, all_info = vision_sensor.get_obs()
all_observation, all_info = og.sim.viewer_camera.get_obs()

bbox_2d_tight = all_observation["bbox_2d_tight"]
bbox_2d_loose = all_observation["bbox_2d_loose"]
bbox_3d = all_observation["bbox_3d"]

assert len(bbox_2d_tight) == 4
assert len(bbox_2d_loose) == 4
assert len(bbox_3d) == 3
assert len(bbox_2d_tight) == 3
assert len(bbox_2d_loose) == 3
assert len(bbox_3d) == 2

bbox_2d_expected_objs = set(["floors", "agent", "breakfast_table", "dishtowel"])
bbox_3d_expected_objs = set(["agent", "breakfast_table", "dishtowel"])
bbox_2d_expected_objs = set(["floors", "breakfast_table", "dishtowel"])
bbox_3d_expected_objs = set(["breakfast_table", "dishtowel"])

bbox_2d_objs = set([semantic_class_id_to_name()[bbox[0]] for bbox in bbox_2d_tight])
bbox_3d_objs = set([semantic_class_id_to_name()[bbox[0]] for bbox in bbox_3d])
Expand Down

0 comments on commit 70b4ce7

Please sign in to comment.