diff --git a/doc/docs/Python_Tutorials/Basics.md b/doc/docs/Python_Tutorials/Basics.md index 0b6df296b..7a5e40de1 100644 --- a/doc/docs/Python_Tutorials/Basics.md +++ b/doc/docs/Python_Tutorials/Basics.md @@ -1192,13 +1192,7 @@ sim = mp.Simulation(resolution=50, cell_size=cell_size, geometry=geometry) -sim.init_sim() - -eps_data = sim.get_epsilon() - -from mayavi import mlab -s = mlab.contour3d(eps_data, colormap="YlGnBu") -mlab.show() +sim.plot3D() ``` ![](../images/prism_epsilon.png#center) diff --git a/doc/docs/images/prism_epsilon.png b/doc/docs/images/prism_epsilon.png index 2a3fbfe32..0e85cb51f 100644 Binary files a/doc/docs/images/prism_epsilon.png and b/doc/docs/images/prism_epsilon.png differ diff --git a/python/simulation.py b/python/simulation.py index 2f3491fc4..093e5baa0 100644 --- a/python/simulation.py +++ b/python/simulation.py @@ -4808,14 +4808,25 @@ def plot_fields(self, **kwargs): return vis.plot_fields(self, **kwargs) - def plot3D(self): + def plot3D( + self, save_to_image: bool = False, image_name: str = "sim.png", **kwargs + ): """ - Uses Mayavi to render a 3D simulation domain. The simulation object must be 3D. + Uses vispy to render a 3D scene of the simulation object. The simulation object must be 3D. Can also be embedded in Jupyter notebooks. + + Args: + save_to_image: if True, saves the image to a file + image_name: the name of the image file to save to + + kwargs: Camera settings. + scale_factor: float, camera zoom factor + azimuth: float, azimuthal angle in degrees + elevation: float, elevation angle in degrees """ import meep.visualization as vis - return vis.plot3D(self) + return vis.plot3D(self, save_to_image, image_name, **kwargs) def visualize_chunks(self): """ diff --git a/python/visualization.py b/python/visualization.py index 8fe1ac21c..58d13e350 100644 --- a/python/visualization.py +++ b/python/visualization.py @@ -16,6 +16,7 @@ from matplotlib.figure import Figure from typing import Callable, Union, Any, Tuple, List, Optional + # ------------------------------------------------------- # # Visualization # ------------------------------------------------------- # @@ -398,7 +399,7 @@ def sort_points(xy): ax.plot( [a.y for a in intersection], [a.z for a in intersection], - **line_args + **line_args, ) return ax # Plot XZ @@ -406,7 +407,7 @@ def sort_points(xy): ax.plot( [a.x for a in intersection], [a.z for a in intersection], - **line_args + **line_args, ) return ax # Plot XY @@ -414,7 +415,7 @@ def sort_points(xy): ax.plot( [a.x for a in intersection], [a.y for a in intersection], - **line_args + **line_args, ) return ax else: @@ -988,27 +989,176 @@ def plot2D( return ax -def plot3D(sim: Simulation): - from mayavi import mlab +def plot3D(sim, save_to_image: bool = False, image_name: str = "sim.png", **kwargs): + from vispy.scene.visuals import Box, Mesh + from vispy.scene import SceneCanvas, transforms - if sim.dimensions < 3: - raise ValueError("Simulation must have 3 dimensions to visualize 3D") + try: + from skimage.measure import marching_cubes + except: + from skimage.measure import marching_cubes_lewiner as marching_cubes + from vispy.visuals.filters import ShadingFilter - xmin, xmax, ymin, ymax, zmin, zmax = box_vertices( - sim.geometry_center, sim.cell_size + # Set canvas + canvas = SceneCanvas(keys="interactive", bgcolor="white") + + view = canvas.central_widget.add_view() + view.camera = "turntable" + + # Get domain measurements + sim_center, sim_size = sim.geometry_center, sim.cell_size + + xmin, xmax, ymin, ymax, zmin, zmax = mp.visualization.box_vertices( + sim_center, sim_size, sim.is_cylindrical ) - Nx = int(sim.cell_size.x * sim.resolution) + 1 - Ny = int(sim.cell_size.y * sim.resolution) + 1 - Nz = int(sim.cell_size.z * sim.resolution) + 1 + grid_resolution = sim.resolution + + Nx = int((xmax - xmin) * grid_resolution + 1) + Ny = int((ymax - ymin) * grid_resolution + 1) + Nz = int((zmax - zmin) * grid_resolution + 1) xtics = np.linspace(xmin, xmax, Nx) ytics = np.linspace(ymin, ymax, Ny) ztics = np.linspace(zmin, zmax, Nz) - eps_data = sim.get_epsilon_grid(xtics, ytics, ztics) - s = mlab.contour3d(eps_data, colormap="YlGnBu") - return s + # Get eps for geometry + eps_data = np.round(np.real(sim.get_epsilon_grid(xtics, ytics, ztics)), 2) + + unique = np.unique(np.abs(eps_data)).tolist() + + # Remove background material + unique.remove(np.round(np.abs(np.asarray(sim.default_material.epsilon_diag)), 2)[0]) + + mesh_midpoint = (sim_size[0] / 2, sim_size[1] / 2, sim_size[2] / 2) + + light_dir = (0, 0, -1, 0) + + # Build geometry + for i, eps in enumerate(unique): + eps_ = np.array(eps_data.flatten() == eps).astype(int).reshape(eps_data.shape) + marching_cube = marching_cubes( + eps_, + 0.99, + spacing=(sim.cell_size.x / Nx, sim.cell_size.y / Ny, sim.cell_size.z / Nz), + ) + vertices, faces = marching_cube[0], marching_cube[1] + + mesh = Mesh( + vertices, + faces, + color=( + 1 - ((i + 1) / len(unique)), + 1 - ((i + 1) / len(unique)), + 1 - ((i + 1) / len(unique)), + 0.8, + ), + ) + + mesh.transform = transforms.MatrixTransform() + mesh.transform.translate(np.asarray(sim.geometry_center)) + shading_filter = ShadingFilter(shininess=100) + shading_filter.light_dir = light_dir[:3] + mesh.attach(shading_filter) + view.add(mesh) + + # Build source + thickness = ( + sim.boundary_layers[0].thickness if not len(sim.boundary_layers) < 1 else 0 + ) + for source in sim.sources: + size = tuple(source.size) + source_box = Box( + *size, + color=(1, 0, 0, 1), # red + ) + center = list(source.center) + source_box.transform = transforms.MatrixTransform() + source_box.transform.translate(np.asarray(mesh_midpoint)) + source_box.transform.translate(center) + source_box.transform.translate(tuple(sim.geometry_center)) + view.add(source_box) + + # Build monitors + for mon in sim.dft_objects: + for reg in mon.regions: + size = list(reg.size) + monitor_box = Box( + *size, + color=(0, 0, 1, 1), # blue + ) + center = list(reg.center) + monitor_box.transform = transforms.MatrixTransform() + vector = [0, 0, 0] + vector[reg.direction] = 1 + vector = mp.Vector3(*vector) + monitor_box.transform.translate(tuple(mesh_midpoint)) + monitor_box.transform.translate(center) + monitor_box.transform.translate(tuple(sim.geometry_center)) + view.add(monitor_box) + + # Build boundaries + for box_center_top in [ + np.add(mesh_midpoint, (0, 0, sim_size[2] / 2 - thickness / 2)), + np.subtract(mesh_midpoint, (0, 0, sim_size[2] / 2 - thickness / 2)), + ]: + box = _build_3d_pml(sim_size[0], sim_size[1], thickness, box_center_top) + view.add(box) + + for box_center_right in [ + np.add(mesh_midpoint, (sim_size[0] / 2 - thickness / 2, 0, 0)), + np.subtract(mesh_midpoint, (sim_size[0] / 2 - thickness / 2, 0, 0)), + ]: + box = _build_3d_pml(thickness, sim_size[1], sim_size[2], box_center_right) + view.add(box) + + for box_center_front in [ + np.add(mesh_midpoint, (0, sim_size[1] / 2 - thickness / 2, 0)), + np.subtract(mesh_midpoint, (0, sim_size[1] / 2 - thickness / 2, 0)), + ]: + box = _build_3d_pml(sim_size[0], thickness, sim_size[2], box_center_front) + view.add(box) + + # Camera options + view.camera.center = mesh_midpoint + view.camera.scale_factor = getattr( + kwargs, "scale_factor", 2 * np.linalg.norm(sim_size) + ) + view.camera.elevation = getattr(kwargs, "elevation", 10) + view.camera.azimuth = getattr(kwargs, "azimuth", 45) + view.camera.transform.imap(light_dir) + + # Plot or save + if save_to_image: + image = canvas.render() + import imageio + + imageio.imwrite(image_name, image) + + return + + canvas.show(run=True) + + +def _build_3d_pml(x: float, y: float, thickness: float, translate: tuple): + from vispy.scene.visuals import Box + from vispy.scene import transforms + from vispy.visuals.filters import WireframeFilter + + box = Box( + x, + y, + thickness, + color=(0, 1, 0, 0.2), # green but transparent + # color=None, + ) + box.transform = transforms.MatrixTransform() + box.transform.rotate(90, (1, 0, 0)) + box.transform.translate(translate) + wireframe_filter = WireframeFilter(width=2) + box.mesh.attach(wireframe_filter) + + return box def visualize_chunks(sim: Simulation): @@ -1446,7 +1596,7 @@ def to_jshtml(self, fps: int) -> JS_Animation: Nframes=Nframes, fill_frames=fill_frames, interval=interval, - **mode_dict + **mode_dict, ) return JS_Animation(html_string)