Skip to content

Commit

Permalink
fix conversion
Browse files Browse the repository at this point in the history
  • Loading branch information
codekansas committed Feb 7, 2025
1 parent f0b8578 commit 65182a4
Show file tree
Hide file tree
Showing 3 changed files with 163 additions and 52 deletions.
139 changes: 87 additions & 52 deletions urdf2mjcf/convert.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,13 +13,11 @@
import colorlogging
from pydantic import BaseModel

from urdf2mjcf.postprocess.merge_fixed import remove_fixed_joints
from urdf2mjcf.utils import save_xml

logger = logging.getLogger(__name__)

ROOT_BODY_NAME = "root"
ROOT_SITE_NAME = f"{ROOT_BODY_NAME}_site"


class JointParam(BaseModel):
kp: float | None = None
Expand Down Expand Up @@ -432,6 +430,7 @@ def rpy_to_quat(rpy_str: str) -> str:

def add_sensors(
mjcf_root: ET.Element,
root_link_name: str,
imus: Sequence[ImuSensor] | None = None,
) -> None:
"""Add sensors to the MJCF model.
Expand All @@ -444,43 +443,43 @@ def add_sensors(
if sensor_elem is None:
sensor_elem = ET.SubElement(mjcf_root, "sensor")

# Adds sensors for global reference frame values.
ET.SubElement(
sensor_elem,
"framepos",
attrib={
"name": "base_link_pos",
"objtype": "site",
"objname": ROOT_SITE_NAME,
},
)
ET.SubElement(
sensor_elem,
"framequat",
attrib={
"name": "base_link_quat",
"objtype": "site",
"objname": ROOT_SITE_NAME,
},
)
ET.SubElement(
sensor_elem,
"framelinvel",
attrib={
"name": "base_link_vel",
"objtype": "site",
"objname": ROOT_SITE_NAME,
},
)
ET.SubElement(
sensor_elem,
"frameangvel",
attrib={
"name": "base_link_ang_vel",
"objtype": "site",
"objname": ROOT_SITE_NAME,
},
)
def add_base_sensors(link_name: str) -> None:
ET.SubElement(
sensor_elem,
"framepos",
attrib={
"name": "base_link_pos",
"objtype": "site",
"objname": link_name,
},
)
ET.SubElement(
sensor_elem,
"framequat",
attrib={
"name": "base_link_quat",
"objtype": "site",
"objname": link_name,
},
)
ET.SubElement(
sensor_elem,
"framelinvel",
attrib={
"name": "base_link_vel",
"objtype": "site",
"objname": link_name,
},
)
ET.SubElement(
sensor_elem,
"frameangvel",
attrib={
"name": "base_link_ang_vel",
"objtype": "site",
"objname": link_name,
},
)

if imus:
for imu in imus:
Expand Down Expand Up @@ -542,6 +541,12 @@ def add_sensors(
mag_attrib["noise"] = str(imu.mag_noise)
ET.SubElement(sensor_elem, "magnetometer", attrib=mag_attrib)

# Adds other sensors.
add_base_sensors(imu.link_name)

else:
add_base_sensors(root_link_name)


def convert_urdf_to_mjcf(
urdf_path: str | Path,
Expand Down Expand Up @@ -716,7 +721,7 @@ def build_body(
link_name: str,
joint: ET.Element | None = None,
actuator_joints: list[ParsedJointParams] = actuator_joints,
) -> ET.Element:
) -> ET.Element | None:
"""Recursively build a MJCF body element from a URDF link."""
link: ET.Element = link_map[link_name]

Expand All @@ -738,15 +743,18 @@ def build_body(
# Add joint element if this is not the root and the joint type is not fixed.
if joint is not None:
jtype: str = joint.attrib.get("type", "fixed")
if jtype != "fixed":

if jtype in ("revolute", "continuous", "prismatic"):
j_name: str = joint.attrib.get("name", link_name + "_joint")
j_attrib: dict[str, str] = {"name": j_name}

if jtype in ["revolute", "continuous"]:
j_attrib["type"] = "hinge"
elif jtype == "prismatic":
j_attrib["type"] = "slide"
else:
j_attrib["type"] = jtype
raise ValueError(f"Unsupported joint type: {jtype}")

limit = joint.find("limit")
if limit is not None:
lower_val = limit.attrib.get("lower")
Expand Down Expand Up @@ -891,11 +899,29 @@ def build_body(
if link_name in parent_map:
for child_name, child_joint in parent_map[link_name]:
child_body = build_body(child_name, child_joint, actuator_joints)
body.append(child_body)
if child_body is not None:
body.append(child_body)
return body

# Build the robot body hierarchy starting from the root link.
robot_body: ET.Element = build_body(root_link_name, None, actuator_joints)
robot_body = build_body(root_link_name, None, actuator_joints)
if robot_body is None:
raise ValueError("Failed to build robot body")

# Adds free joint to the root link.
ET.SubElement(
robot_body,
"joint",
attrib={"name": "floating_base", "type": "free"},
)

# Adds a site to the root link.
root_site_name = f"{root_link_name}_site"
ET.SubElement(
robot_body,
"site",
attrib={"name": root_site_name, "pos": "0 0 0", "quat": "1 0 0 0"},
)

# Automatically compute the base offset using the model's minimum z coordinate.
identity: list[list[float]] = [
Expand All @@ -908,12 +934,13 @@ def build_body(
computed_offset: float = -min_z
logger.info("Auto-detected base offset: %s (min z = %s)", computed_offset, min_z)

# Create a root body with a freejoint and an IMU site; the z position uses the computed offset.
root_body = ET.Element("body", attrib={"name": ROOT_BODY_NAME, "pos": f"0 0 {computed_offset}", "quat": "1 0 0 0"})
ET.SubElement(root_body, "freejoint", attrib={"name": ROOT_BODY_NAME})
ET.SubElement(root_body, "site", attrib={"name": ROOT_SITE_NAME, "pos": "0 0 0", "quat": "1 0 0 0"})
root_body.append(robot_body)
worldbody.append(root_body)
# Moves the robot body to the computed offset.
body_pos = robot_body.attrib.get("pos", "0 0 0")
body_pos = [float(x) for x in body_pos.split()]
body_pos[2] += computed_offset
robot_body.attrib["pos"] = " ".join(f"{x:.8f}" for x in body_pos)

worldbody.append(robot_body)

# Replace the actuator block with one that uses positional control.
actuator_elem = ET.SubElement(mjcf_root, "actuator")
Expand Down Expand Up @@ -947,7 +974,7 @@ def build_body(
add_worldbody_elements(mjcf_root)

# Add sensors after adding worldbody elements
add_sensors(mjcf_root, imus=metadata.imus)
add_sensors(mjcf_root, root_site_name, imus=metadata.imus)

# Add mesh assets to the asset section.
asset_elem: ET.Element | None = mjcf_root.find("asset")
Expand Down Expand Up @@ -1000,6 +1027,11 @@ def main() -> None:
type=str,
help="A JSON file containing conversion metadata (joint params and sensors).",
)
parser.add_argument(
"--merge-fixed",
action="store_true",
help="Merge fixed joints into their parent body.",
)

args = parser.parse_args()

Expand Down Expand Up @@ -1030,6 +1062,9 @@ def main() -> None:
metadata=metadata,
)

if args.merge_fixed:
remove_fixed_joints(args.output)


if __name__ == "__main__":
main()
Empty file.
76 changes: 76 additions & 0 deletions urdf2mjcf/postprocess/merge_fixed.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,76 @@
"""Defines a post-processing function that merges MJCF fixed joints into their parent body."""

import argparse
import logging
import xml.etree.ElementTree as ET
from pathlib import Path

logger = logging.getLogger(__name__)


def remove_fixed_joints(mjcf_path: str | Path) -> None:
"""Merges fixed joints into their parent body.
This function works by finding all of the body links which do not have a
joint element, and converting them from a body element to a site element.
Args:
mjcf_path: The path to the MJCF file to process.
"""
tree = ET.parse(mjcf_path)
root = tree.getroot()

# Find all body elements
worldbody = root.find(".//worldbody")
if worldbody is None:
return

bodies_to_merge = []

# Find all bodies that don't have joints
for parent_body in worldbody.findall(".//body"):
if parent_body.find("freejoint") is not None:
continue

for child_body in parent_body.findall("body"):
if child_body.find("joint") is not None:
continue

bodies_to_merge.append((parent_body, child_body))

for parent_body, child_body in bodies_to_merge:
parent_name = parent_body.attrib["name"]
child_name = child_body.attrib["name"]
logger.info("Merging body %s into %s", child_name, parent_name)

# Create a site element at the position of the merged body
site = ET.SubElement(parent_body, "site")
site.set("name", child_name)
for attr in ["pos", "quat", "euler"]:
if child_body.get(attr):
site.set(attr, child_body.get(attr))

# Transfer all child elements except inertial to the parent
for grandchild in child_body:
if grandchild.tag != "inertial":
child_body.remove(grandchild)
parent_body.append(grandchild)

# Remove the merged body
parent_body.remove(child_body)

# Save the modified XML
tree.write(mjcf_path, encoding="utf-8", xml_declaration=True)


def main() -> None:
parser = argparse.ArgumentParser()
parser.add_argument("mjcf_path", type=Path)
args = parser.parse_args()

remove_fixed_joints(args.mjcf_path)


if __name__ == "__main__":
# python -m urdf2mjcf.postprocess.merge_fixed
main()

0 comments on commit 65182a4

Please sign in to comment.