diff --git a/pytransform3d/mesh_loader.py b/pytransform3d/mesh_loader.py new file mode 100644 index 000000000..b42a26330 --- /dev/null +++ b/pytransform3d/mesh_loader.py @@ -0,0 +1,150 @@ +"""Common interface to load meshes.""" +import abc + +import numpy as np + + +def load_mesh(filename): + """Load mesh from file. + + This feature relies on optional dependencies. It can use trimesh or + Open3D to load meshes. If both are not available, it will fail. + Furthermore, some mesh formats require additional dependencies. For + example, loading collada files ('.dae' file ending) requires pycollada + and trimesh. + + Parameters + ---------- + filename : str + File in which the mesh is stored. + + Returns + ------- + mesh : MeshBase + Mesh instance. + """ + mesh = _Trimesh(filename) + loader_available = mesh.load() + + if not loader_available: # pragma: no cover + mesh = _Open3DMesh(filename) + loader_available = mesh.load() + + if not loader_available: # pragma: no cover + raise ImportError( + "Could not load mesh from '%s'. Please install one of the " + "optional dependencies 'trimesh' or 'open3d'." % filename) + + return mesh + + +class MeshBase(abc.ABC): + """Abstract base class of meshes. + + Parameters + ---------- + filename : str + File in which the mesh is stored. + """ + def __init__(self, filename): + self.filename = filename + + @abc.abstractmethod + def load(self): + """Load mesh from file. + + Returns + ------- + loader_available : bool + Is the mesh loader available? + """ + + @abc.abstractmethod + def convex_hull(self): + """Compute convex hull of mesh.""" + + @abc.abstractmethod + def get_open3d_mesh(self): + """Return Open3D mesh. + + Returns + ------- + mesh : open3d.geometry.TriangleMesh + Open3D mesh. + """ + + @property + @abc.abstractmethod + def vertices(self): + """Vertices.""" + + @property + @abc.abstractmethod + def triangles(self): + """Triangles.""" + + +class _Trimesh(MeshBase): + def __init__(self, filename): + super(_Trimesh, self).__init__(filename) + self.mesh = None + + def load(self): + try: + import trimesh + except ImportError: + return False + obj = trimesh.load(self.filename) + if isinstance(obj, trimesh.Scene): # pragma: no cover + # Special case in which we load a collada file that contains + # multiple meshes. We might lose textures. This is excluded + # from testing as it would add another dependency. + obj = obj.dump().sum() + self.mesh = obj + return True + + def convex_hull(self): + self.mesh = self.mesh.convex_hull + + def get_open3d_mesh(self): # pragma: no cover + import open3d + return open3d.geometry.TriangleMesh( + open3d.utility.Vector3dVector(self.vertices), + open3d.utility.Vector3iVector(self.triangles)) + + @property + def vertices(self): + return self.mesh.vertices + + @property + def triangles(self): + return self.mesh.faces + + +class _Open3DMesh(MeshBase): # pragma: no cover + def __init__(self, filename): + super(_Open3DMesh, self).__init__(filename) + self.mesh = None + + def load(self): + try: + import open3d + except ImportError: + return False + self.mesh = open3d.io.read_triangle_mesh(self.filename) + return True + + def convex_hull(self): + assert self.mesh is not None + self.mesh = self.mesh.compute_convex_hull()[0] + + def get_open3d_mesh(self): + return self.mesh + + @property + def vertices(self): + return np.asarray(self.mesh.vertices) + + @property + def triangles(self): + return np.asarray(self.mesh.triangles) diff --git a/pytransform3d/mesh_loader.pyi b/pytransform3d/mesh_loader.pyi new file mode 100644 index 000000000..3eab08060 --- /dev/null +++ b/pytransform3d/mesh_loader.pyi @@ -0,0 +1,30 @@ +import abc +from typing import Any + +import numpy.typing as npt + + +class MeshBase(abc.ABC): + filename: str + + def __init__(self, filename: str): ... + + @abc.abstractmethod + def load(self) -> bool: ... + + @abc.abstractmethod + def convex_hull(self): ... + + @abc.abstractmethod + def get_open3d_mesh(self) -> Any: ... + + @property + @abc.abstractmethod + def vertices(self) -> npt.ArrayLike: ... + + @property + @abc.abstractmethod + def triangles(self) -> npt.ArrayLike: ... + + +def load_mesh(filename: str) -> MeshBase: ... diff --git a/pytransform3d/plot_utils/_plot_functions.py b/pytransform3d/plot_utils/_plot_functions.py index 171fe860c..9ce1a5c87 100644 --- a/pytransform3d/plot_utils/_plot_functions.py +++ b/pytransform3d/plot_utils/_plot_functions.py @@ -6,6 +6,7 @@ from ._artists import Arrow3D from ..transformations import transform, vectors_to_points from ..rotations import unitx, unitz, perpendicular_to_vectors, norm_vector +from ..mesh_loader import load_mesh def plot_box(ax=None, size=np.ones(3), A2B=np.eye(4), ax_s=1, wireframe=True, @@ -316,8 +317,8 @@ def plot_mesh(ax=None, filename=None, A2B=np.eye(4), convex_hull=False, alpha=1.0, color="k"): """Plot mesh. - Note that this function requires the additional library 'trimesh'. - It will print a warning if trimesh is not available. + Note that this function requires the additional library to load meshes + such as trimesh or open3d. Parameters ---------- @@ -364,20 +365,14 @@ def plot_mesh(ax=None, filename=None, A2B=np.eye(4), "package directory.") return ax - try: - import trimesh - except ImportError: - warnings.warn( - "Cannot display mesh. Library 'trimesh' not installed.") - return ax - - mesh = trimesh.load(filename) + mesh = load_mesh(filename) if convex_hull: - mesh = mesh.convex_hull + mesh.convex_hull() + vertices = mesh.vertices * s vertices = np.hstack((vertices, np.ones((len(vertices), 1)))) vertices = transform(A2B, vertices)[:, :3] - vectors = np.array([vertices[[i, j, k]] for i, j, k in mesh.faces]) + vectors = np.array([vertices[[i, j, k]] for i, j, k in mesh.triangles]) if wireframe: surface = Line3DCollection(vectors) surface.set_color(color) diff --git a/pytransform3d/test/test_mesh_loader.py b/pytransform3d/test/test_mesh_loader.py new file mode 100644 index 000000000..59395831d --- /dev/null +++ b/pytransform3d/test/test_mesh_loader.py @@ -0,0 +1,54 @@ +from pytransform3d import mesh_loader + +import pytest + + +def test_trimesh(): + mesh = mesh_loader._Trimesh("test/test_data/cone.stl") + loader_available = mesh.load() + if not loader_available: + pytest.skip("trimesh is required for this test") + + assert len(mesh.vertices) == 64 + assert len(mesh.triangles) == 124 + + mesh.convex_hull() + + assert len(mesh.vertices) == 64 + + +def test_open3d(): + mesh = mesh_loader._Open3DMesh("test/test_data/cone.stl") + loader_available = mesh.load() + if not loader_available: + pytest.skip("open3d is required for this test") + + assert len(mesh.vertices) == 295 + assert len(mesh.triangles) == 124 + + o3d_mesh = mesh.get_open3d_mesh() + assert len(o3d_mesh.vertices) == 295 + + mesh.convex_hull() + + assert len(mesh.vertices) == 64 + + +def test_trimesh_with_open3d(): + mesh = mesh_loader._Trimesh("test/test_data/cone.stl") + loader_available = mesh.load() + if not loader_available: + pytest.skip("trimesh is required for this test") + try: + o3d_mesh = mesh.get_open3d_mesh() + except ImportError: + pytest.skip("open3d is required for this test") + assert len(o3d_mesh.vertices) == 64 + + +def test_interface(): + try: + mesh = mesh_loader.load_mesh("test/test_data/cone.stl") + assert len(mesh.triangles) == 124 + except ImportError: + pytest.skip("trimesh or open3d are required for this test") diff --git a/pytransform3d/visualizer/_artists.py b/pytransform3d/visualizer/_artists.py index 3aaec1b97..6219d6098 100644 --- a/pytransform3d/visualizer/_artists.py +++ b/pytransform3d/visualizer/_artists.py @@ -6,6 +6,7 @@ from .. import rotations as pr from .. import transformations as pt from .. import urdf +from .. import mesh_loader class Artist: @@ -554,11 +555,11 @@ class Mesh(Artist): """ def __init__(self, filename, A2B=np.eye(4), s=np.ones(3), c=None, convex_hull=False): - mesh = o3d.io.read_triangle_mesh(filename) + mesh = mesh_loader.load_mesh(filename) if convex_hull: - self.mesh = mesh.compute_convex_hull()[0] - else: - self.mesh = mesh + mesh.convex_hull() + + self.mesh = mesh.get_open3d_mesh() self.mesh.vertices = o3d.utility.Vector3dVector( np.asarray(self.mesh.vertices) * s) self.mesh.compute_vertex_normals() diff --git a/requirements.txt b/requirements.txt index ba9132165..85caf81c7 100644 --- a/requirements.txt +++ b/requirements.txt @@ -3,5 +3,6 @@ scipy matplotlib lxml trimesh +pycollada pydot open3d diff --git a/setup.py b/setup.py index c380cc6db..a3a5e695d 100644 --- a/setup.py +++ b/setup.py @@ -26,7 +26,7 @@ packages=find_packages(), install_requires=["numpy", "scipy", "matplotlib", "lxml"], extras_require={ - "all": ["pydot", "trimesh", "open3d"], + "all": ["pydot", "trimesh", "pycollada", "open3d"], "doc": ["numpydoc", "sphinx", "sphinx-gallery", "sphinx-bootstrap-theme"], "test": ["pytest", "pytest-cov"]