Skip to content

Commit

Permalink
clean up stuff, add some fixes
Browse files Browse the repository at this point in the history
  • Loading branch information
codekansas committed May 14, 2024
1 parent 2049904 commit 749d43e
Show file tree
Hide file tree
Showing 3 changed files with 72 additions and 131 deletions.
30 changes: 6 additions & 24 deletions kol/formats/mjcf.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,10 +43,7 @@ class Mesh:
scale: tuple[float, float, float] | None = None

def to_xml(self, root: ET.Element | None = None) -> ET.Element:
if root is None:
mesh = ET.Element("mesh")
else:
mesh = ET.SubElement(root, "mesh")
mesh = ET.Element("mesh") if root is None else ET.SubElement(root, "mesh")
mesh.set("name", self.name)
mesh.set("file", self.file)
if self.scale is not None:
Expand All @@ -66,10 +63,7 @@ class Joint:
stiffness: float

def to_xml(self, root: ET.Element | None = None) -> ET.Element:
if root is None:
joint = ET.Element("joint")
else:
joint = ET.SubElement(root, "joint")
joint = ET.Element("joint") if root is None else ET.SubElement(root, "joint")
joint.set("name", self.name)
joint.set("type", self.type)
if self.pos is not None:
Expand Down Expand Up @@ -99,10 +93,7 @@ class Geom:
quat: tuple[float, float, float, float] | None = None

def to_xml(self, root: ET.Element | None = None) -> ET.Element:
if root is None:
geom = ET.Element("geom")
else:
geom = ET.SubElement(root, "geom")
geom = ET.Element("geom") if root is None else ET.SubElement(root, "geom")
geom.set("mesh", self.mesh)
geom.set("type", self.type)
geom.set("rgba", " ".join(map(str, self.rgba)))
Expand All @@ -124,10 +115,7 @@ class Body:
# inertial: Inertial = None

def to_xml(self, root: ET.Element | None = None) -> ET.Element:
if root is None:
body = ET.Element("body")
else:
body = ET.SubElement(root, "body")
body = ET.Element("body") if root is None else ET.SubElement(root, "body")
body.set("name", self.name)
if self.pos is not None:
body.set("pos", " ".join(map(str, self.pos)))
Expand All @@ -146,10 +134,7 @@ class Option:
viscosity: float

def to_xml(self, root: ET.Element | None = None) -> ET.Element:
if root is None:
option = ET.Element("option")
else:
option = ET.SubElement(root, "option")
option = ET.Element("option") if root is None else ET.SubElement(root, "option")
option.set("timestep", str(self.timestep))
option.set("viscosity", str(self.viscosity))
return option
Expand All @@ -162,10 +147,7 @@ class Actuator:
ctrlrange: tuple[float, float]

def to_xml(self, root: ET.Element | None = None) -> ET.Element:
if root is None:
actuator = ET.Element("actuator")
else:
actuator = ET.SubElement(root, "actuator")
actuator = ET.Element("actuator") if root is None else ET.SubElement(root, "actuator")
actuator.set("name", self.name)
actuator.set("joint", self.joint)
actuator.set("ctrlrange", " ".join(map(str, self.ctrlrange)))
Expand Down
136 changes: 65 additions & 71 deletions kol/formats/urdf.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,13 +29,11 @@ class Origin:
xyz: tuple[float, float, float]
rpy: tuple[float, float, float]

def to_xml(self, root: ET.Element) -> ET.Element:
return ET.SubElement(
root,
"origin",
xyz=" ".join(format_number(v) for v in self.xyz),
rpy=" ".join(format_number(v) for v in self.rpy),
)
def to_xml(self, root: ET.Element | None = None) -> ET.Element:
origin = ET.Element("origin") if root is None else ET.SubElement(root, "origin")
origin.set("xyz", " ".join(format_number(v) for v in self.xyz))
origin.set("rpy", " ".join(format_number(v) for v in self.rpy))
return origin

@classmethod
def from_matrix(cls, matrix: np.matrix) -> Self:
Expand All @@ -54,8 +52,10 @@ def zero_origin(cls) -> Self:
class Axis:
xyz: tuple[float, float, float]

def to_xml(self, root: ET.Element) -> ET.Element:
return ET.SubElement(root, "axis", xyz=" ".join(format_number(v) for v in self.xyz))
def to_xml(self, root: ET.Element | None = None) -> ET.Element:
axis = ET.Element("axis") if root is None else ET.SubElement(root, "axis")
axis.set("xyz", " ".join(format_number(v) for v in self.xyz))
return axis


@dataclass
Expand All @@ -65,15 +65,13 @@ class JointLimits:
lower: float # radians for revolute, meters for prismatic
upper: float # radians for revolute, meters for prismatic

def to_xml(self, root: ET.Element) -> ET.Element:
return ET.SubElement(
root,
"limit",
effort=format_number(self.effort),
velocity=format_number(self.velocity),
lower=format_number(self.lower),
upper=format_number(self.upper),
)
def to_xml(self, root: ET.Element | None = None) -> ET.Element:
limit = ET.Element("limit") if root is None else ET.SubElement(root, "limit")
limit.set("effort", format_number(self.effort))
limit.set("velocity", format_number(self.velocity))
limit.set("lower", format_number(self.lower))
limit.set("upper", format_number(self.upper))
return limit


@dataclass
Expand All @@ -82,28 +80,24 @@ class JointMimic:
multiplier: float = 1.0
offset: float = 0.0

def to_xml(self, root: ET.Element) -> ET.Element:
return ET.SubElement(
root,
"mimic",
joint=self.joint,
multiplier=format_number(self.multiplier),
offset=format_number(self.offset),
)
def to_xml(self, root: ET.Element | None = None) -> ET.Element:
mimic = ET.Element("mimic") if root is None else ET.SubElement(root, "mimic")
mimic.set("joint", self.joint)
mimic.set("multiplier", format_number(self.multiplier))
mimic.set("offset", format_number(self.offset))
return mimic


@dataclass
class JointDynamics:
damping: float
friction: float

def to_xml(self, root: ET.Element) -> ET.Element:
return ET.SubElement(
root,
"dynamics",
damping=format_number(self.damping),
friction=format_number(self.friction),
)
def to_xml(self, root: ET.Element | None = None) -> ET.Element:
joint = ET.Element("dynamics") if root is None else ET.SubElement(root, "dynamics")
joint.set("damping", format_number(self.damping))
joint.set("friction", format_number(self.friction))
return joint


@dataclass
Expand All @@ -113,8 +107,10 @@ class BaseJoint(ABC):
child: str
origin: Origin

def to_xml(self, root: ET.Element) -> ET.Element:
joint = ET.SubElement(root, "joint", name=self.name, type=self.joint_type())
def to_xml(self, root: ET.Element | None = None) -> ET.Element:
joint = ET.Element("joint") if root is None else ET.SubElement(root, "joint")
joint.set("name", self.name)
joint.set("type", self.joint_type())
self.origin.to_xml(joint)
ET.SubElement(joint, "parent", link=self.parent)
ET.SubElement(joint, "child", link=self.child)
Expand All @@ -131,7 +127,7 @@ class RevoluteJoint(BaseJoint):
dynamics: JointDynamics | None = None # N*m*s/rad for damping, N*m for friction
mimic: JointMimic | None = None

def to_xml(self, root: ET.Element) -> ET.Element:
def to_xml(self, root: ET.Element | None = None) -> ET.Element:
joint = super().to_xml(root)
self.limits.to_xml(joint)
self.axis.to_xml(joint)
Expand All @@ -149,7 +145,7 @@ def joint_type(self) -> str:
class ContinuousJoint(BaseJoint):
mimic: JointMimic | None = None

def to_xml(self, root: ET.Element) -> ET.Element:
def to_xml(self, root: ET.Element | None = None) -> ET.Element:
joint = super().to_xml(root)
if self.mimic is not None:
self.mimic.to_xml(joint)
Expand All @@ -166,7 +162,7 @@ class PrismaticJoint(BaseJoint):
dynamics: JointDynamics | None = None # N*s/m for damping, N for friction
mimic: JointMimic | None = None

def to_xml(self, root: ET.Element) -> ET.Element:
def to_xml(self, root: ET.Element | None = None) -> ET.Element:
joint = super().to_xml(root)
self.limits.to_xml(joint)
self.axis.to_xml(joint)
Expand All @@ -182,7 +178,7 @@ def joint_type(self) -> str:

@dataclass
class FixedJoint(BaseJoint):
def to_xml(self, root: ET.Element) -> ET.Element:
def to_xml(self, root: ET.Element | None = None) -> ET.Element:
joint = super().to_xml(root)
return joint

Expand All @@ -194,7 +190,7 @@ def joint_type(self) -> str:
class FloatingJoint(BaseJoint):
mimic: JointMimic | None = None

def to_xml(self, root: ET.Element) -> ET.Element:
def to_xml(self, root: ET.Element | None = None) -> ET.Element:
joint = super().to_xml(root)
if self.mimic is not None:
self.mimic.to_xml(joint)
Expand All @@ -210,7 +206,7 @@ class PlanarJoint(BaseJoint):
axis: Axis # The surface normal
mimic: JointMimic | None = None

def to_xml(self, root: ET.Element) -> ET.Element:
def to_xml(self, root: ET.Element | None = None) -> ET.Element:
joint = super().to_xml(root)
self.limits.to_xml(joint)
self.axis.to_xml(joint)
Expand All @@ -231,17 +227,15 @@ class Inertia:
ixz: float
iyz: float

def to_xml(self, root: ET.Element) -> ET.Element:
return ET.SubElement(
root,
"inertia",
ixx=format_number(self.ixx),
iyy=format_number(self.iyy),
izz=format_number(self.izz),
ixy=format_number(self.ixy),
ixz=format_number(self.ixz),
iyz=format_number(self.iyz),
)
def to_xml(self, root: ET.Element | None = None) -> ET.Element:
inertia = ET.Element("inertia") if root is None else ET.SubElement(root, "inertia")
inertia.set("ixx", format_number(self.ixx))
inertia.set("iyy", format_number(self.iyy))
inertia.set("izz", format_number(self.izz))
inertia.set("ixy", format_number(self.ixy))
inertia.set("ixz", format_number(self.ixz))
inertia.set("iyz", format_number(self.iyz))
return inertia


@dataclass
Expand All @@ -250,8 +244,8 @@ class InertialLink:
inertia: Inertia
origin: Origin

def to_xml(self, root: ET.Element) -> ET.Element:
inertial = ET.SubElement(root, "inertial")
def to_xml(self, root: ET.Element | None = None) -> ET.Element:
inertial = ET.Element("inertial") if root is None else ET.SubElement(root, "inertial")
ET.SubElement(inertial, "mass", value=format_number(self.mass))
self.inertia.to_xml(inertial)
self.origin.to_xml(inertial)
Expand All @@ -261,15 +255,15 @@ def to_xml(self, root: ET.Element) -> ET.Element:
@dataclass
class BaseGeometry(ABC):
@abstractmethod
def to_xml(self, root: ET.Element) -> ET.Element: ...
def to_xml(self, root: ET.Element | None = None) -> ET.Element: ...


@dataclass
class BoxGeometry(BaseGeometry):
size: tuple[float, float, float]

def to_xml(self, root: ET.Element) -> ET.Element:
geometry = ET.SubElement(root, "geometry")
def to_xml(self, root: ET.Element | None = None) -> ET.Element:
geometry = ET.Element("geometry") if root is None else ET.SubElement(root, "geometry")
ET.SubElement(geometry, "box", size=" ".join(format_number(v) for v in self.size))
return geometry

Expand All @@ -279,8 +273,8 @@ class CylinderGeometry(BaseGeometry):
radius: float
length: float

def to_xml(self, root: ET.Element) -> ET.Element:
geometry = ET.SubElement(root, "geometry")
def to_xml(self, root: ET.Element | None = None) -> ET.Element:
geometry = ET.Element("geometry") if root is None else ET.SubElement(root, "geometry")
ET.SubElement(
geometry,
"cylinder",
Expand All @@ -294,8 +288,8 @@ def to_xml(self, root: ET.Element) -> ET.Element:
class SphereGeometry(BaseGeometry):
radius: float

def to_xml(self, root: ET.Element) -> ET.Element:
geometry = ET.SubElement(root, "geometry")
def to_xml(self, root: ET.Element | None = None) -> ET.Element:
geometry = ET.Element("geometry") if root is None else ET.SubElement(root, "geometry")
ET.SubElement(geometry, "sphere", radius=format_number(self.radius))
return geometry

Expand All @@ -304,8 +298,8 @@ def to_xml(self, root: ET.Element) -> ET.Element:
class MeshGeometry(BaseGeometry):
filename: str

def to_xml(self, root: ET.Element) -> ET.Element:
geometry = ET.SubElement(root, "geometry")
def to_xml(self, root: ET.Element | None = None) -> ET.Element:
geometry = ET.Element("geometry") if root is None else ET.SubElement(root, "geometry")
ET.SubElement(geometry, "mesh", filename=self.filename)
return geometry

Expand All @@ -318,8 +312,8 @@ class Material:
name: str
color: list[float]

def to_xml(self, root: ET.Element) -> ET.Element:
material = ET.SubElement(root, "material", name=self.name)
def to_xml(self, root: ET.Element | None = None) -> ET.Element:
material = ET.Element("material") if root is None else ET.SubElement(root, "material")
ET.SubElement(material, "color", rgba=" ".join(format_number(v) for v in self.color))
return material

Expand Down Expand Up @@ -348,8 +342,8 @@ class VisualLink:
geometry: BaseGeometry
material: Material

def to_xml(self, root: ET.Element) -> ET.Element:
visual = ET.SubElement(root, "visual")
def to_xml(self, root: ET.Element | None = None) -> ET.Element:
visual = ET.Element("visual") if root is None else ET.SubElement(root, "visual")
self.origin.to_xml(visual)
self.geometry.to_xml(visual)
self.material.to_xml(visual)
Expand All @@ -361,8 +355,8 @@ class CollisionLink:
origin: Origin
geometry: BaseGeometry

def to_xml(self, root: ET.Element) -> ET.Element:
collision = ET.SubElement(root, "collision")
def to_xml(self, root: ET.Element | None = None) -> ET.Element:
collision = ET.Element("collision") if root is None else ET.SubElement(root, "collision")
self.origin.to_xml(collision)
self.geometry.to_xml(collision)
return collision
Expand All @@ -375,8 +369,8 @@ class Link:
collision: CollisionLink | None = None
inertial: InertialLink | None = None

def to_xml(self, root: ET.Element) -> ET.Element:
link = ET.SubElement(root, "link", name=self.name)
def to_xml(self, root: ET.Element | None = None) -> ET.Element:
link = ET.Element("link") if root is None else ET.SubElement(root, "link")
if self.visual is not None:
self.visual.to_xml(link)
if self.collision is not None:
Expand Down
Loading

0 comments on commit 749d43e

Please sign in to comment.