diff --git a/xbout/boutdataarray.py b/xbout/boutdataarray.py index b018cd91..2a2c5ed6 100644 --- a/xbout/boutdataarray.py +++ b/xbout/boutdataarray.py @@ -1072,6 +1072,12 @@ def pcolormesh(self, ax=None, **kwargs): """ return plotfuncs.plot2d_wrapper(self.data, xr.plot.pcolormesh, ax=ax, **kwargs) + def polygon(self, ax=None, **kwargs): + """ + Colour-plot of a radial-poloidal slice on the R-Z plane using polygons + """ + return plotfuncs.plot2d_polygon(self.data, ax=ax, **kwargs) + def plot_regions(self, ax=None, **kwargs): """ Plot the regions into which xBOUT splits radial-poloidal arrays to handle diff --git a/xbout/geometries.py b/xbout/geometries.py index 7a7f8ad7..2299f54b 100644 --- a/xbout/geometries.py +++ b/xbout/geometries.py @@ -381,6 +381,14 @@ def add_toroidal_geometry_coords(ds, *, coordinates=None, grid=None): "total_poloidal_distance", "zShift", "zShift_ylow", + "Rxy_corners", # Lower left corners + "Rxy_lower_right_corners", + "Rxy_upper_left_corners", + "Rxy_upper_right_corners", + "Zxy_corners", # Lower left corners + "Zxy_lower_right_corners", + "Zxy_upper_left_corners", + "Zxy_upper_right_corners", ], ) @@ -420,6 +428,24 @@ def add_toroidal_geometry_coords(ds, *, coordinates=None, grid=None): else: ds = ds.set_coords(("Rxy", "Zxy")) + # Add cell corners as coordinates for polygon plotting + if "Rxy_lower_right_corners" in ds: + ds = ds.rename( + Rxy_corners="Rxy_lower_left_corners", Zxy_corners="Zxy_lower_left_corners" + ) + ds = ds.set_coords( + ( + "Rxy_lower_left_corners", + "Rxy_lower_right_corners", + "Rxy_upper_left_corners", + "Rxy_upper_right_corners", + "Zxy_lower_left_corners", + "Zxy_lower_right_corners", + "Zxy_upper_left_corners", + "Zxy_upper_right_corners", + ) + ) + # Rename zShift_ylow if it was added from grid file, to be consistent with name if # it was added from dump file if "zShift_CELL_YLOW" in ds and "zShift_ylow" in ds: diff --git a/xbout/plotting/animate.py b/xbout/plotting/animate.py index 68536a97..e223983f 100644 --- a/xbout/plotting/animate.py +++ b/xbout/plotting/animate.py @@ -97,6 +97,7 @@ def animate_poloidal( cax=None, animate_over=None, separatrix=True, + separatrix_kwargs=dict(), targets=True, add_limiter_hatching=True, cmap=None, @@ -130,6 +131,8 @@ def animate_poloidal( Dimension over which to animate, defaults to the time dimension separatrix : bool, optional Add dashed lines showing separatrices + separatrix_kwargs : dict, optional + Options to pass to the separatrix plotter (e.g. line color) targets : bool, optional Draw solid lines at the target surfaces add_limiter_hatching : bool, optional @@ -277,7 +280,7 @@ def animate_poloidal( targets = False if separatrix: - plot_separatrices(da_regions, ax, x=x, y=y) + plot_separatrices(da_regions, ax, x=x, y=y, **separatrix_kwargs) if targets: plot_targets(da_regions, ax, x=x, y=y, hatching=add_limiter_hatching) diff --git a/xbout/plotting/plotfuncs.py b/xbout/plotting/plotfuncs.py index e115d434..db0a1e67 100644 --- a/xbout/plotting/plotfuncs.py +++ b/xbout/plotting/plotfuncs.py @@ -1,6 +1,7 @@ from collections.abc import Sequence import matplotlib import matplotlib.pyplot as plt +from mpl_toolkits.axes_grid1 import make_axes_locatable import numpy as np from pathlib import Path from tempfile import TemporaryDirectory @@ -752,9 +753,9 @@ def create_or_update_plot(plot_objects=None, tind=None, this_save_as=None): X, Y, Z, scalars=data, vmin=vmin, vmax=vmax, **kwargs ) else: - plot_objects[ - region_name + str(i) - ].mlab_source.scalars = data + plot_objects[region_name + str(i)].mlab_source.scalars = ( + data + ) if mayavi_view is not None: mlab.view(*mayavi_view) @@ -849,3 +850,181 @@ def animation_func(): plt.show() else: raise ValueError(f"Unrecognised plot3d() 'engine' argument: {engine}") + + +def plot2d_polygon( + da, + ax=None, + cax=None, + cmap="viridis", + norm=None, + logscale=False, + antialias=False, + vmin=None, + vmax=None, + extend="neither", + add_colorbar=True, + colorbar_label=None, + separatrix=True, + separatrix_kwargs={"color": "white", "linestyle": "-", "linewidth": 2}, + targets=False, + add_limiter_hatching=False, + grid_only=False, + linewidth=0, + linecolor="black", +): + """ + Nice looking 2D plots which have no visual artifacts around the X-point. + + Parameters + ---------- + da : xarray.DataArray + A 2D (x,y) DataArray of data to plot + ax : Axes, optional + Axes to plot on. If not provided, will make its own. + cax : Axes, optional + Axes to plot colorbar on. If not provided, will plot on the same axes as the plot. + cmap : str or matplotlib.colors.Colormap, default "viridis" + Colormap to use for the plot + norm : matplotlib.colors.Normalize, optional + Normalization to use for the color scale + logscale : bool, default False + If True, use a symlog color scale + antialias : bool, default False + Enables antialiasing. Note: this also shows mesh cell edges - it's unclear how to disable this. + vmin : float, optional + Minimum value for the color scale + vmax : float, optional + Maximum value for the color scale + extend : str, optional, default "neither" + Extend the colorbar. Options are "neither", "both", "min", "max" + add_colorbar : bool, default True + Enable colorbar in figure? + colorbar_label : str, optional + Label for the colorbar + separatrix : bool, default True + Add lines showing separatrices + separatrix_kwargs : dict + Keyword arguments to pass custom style to the separatrices plot + targets : bool, default True + Draw solid lines at the target surfaces + add_limiter_hatching : bool, default True + Draw hatched areas at the targets + grid_only : bool, default False + Only plot the grid, not the data. This sets all the polygons to have a white face. + linewidth : float, default 0 + Width of the gridlines on cell edges + linecolor : str, default "black" + Color of the gridlines on cell edges + """ + + if ax is None: + fig, ax = plt.subplots(figsize=(3, 6), dpi=120) + else: + fig = ax.get_figure() + + if cax is None: + cax = ax + + if vmin is None: + vmin = np.nanmin(da.values) + + if vmax is None: + vmax = np.nanmax(da.max().values) + + if colorbar_label == None: + if "short_name" in da.attrs: + colorbar_label = da.attrs["short_name"] + elif da.name != None: + colorbar_label = da.name + else: + colorbar_label = "" + + if "units" in da.attrs: + colorbar_label += f" [{da.attrs['units']}]" + + if "Rxy_lower_right_corners" in da.coords: + r_nodes = [ + "R", + "Rxy_lower_left_corners", + "Rxy_lower_right_corners", + "Rxy_upper_left_corners", + "Rxy_upper_right_corners", + ] + z_nodes = [ + "Z", + "Zxy_lower_left_corners", + "Zxy_lower_right_corners", + "Zxy_upper_left_corners", + "Zxy_upper_right_corners", + ] + cell_r = np.concatenate( + [np.expand_dims(da[x], axis=2) for x in r_nodes], axis=2 + ) + cell_z = np.concatenate( + [np.expand_dims(da[x], axis=2) for x in z_nodes], axis=2 + ) + else: + raise Exception("Cell corners not present in mesh, cannot do polygon plot") + + Nx = len(cell_r) + Ny = len(cell_r[0]) + patches = [] + + # https://matplotlib.org/2.0.2/examples/api/patch_collection.html + + idx = [np.array([1, 2, 4, 3, 1])] + patches = [] + for i in range(Nx): + for j in range(Ny): + p = matplotlib.patches.Polygon( + np.concatenate((cell_r[i][j][tuple(idx)], cell_z[i][j][tuple(idx)])) + .reshape(2, 5) + .T, + fill=False, + closed=True, + facecolor=None, + ) + patches.append(p) + + norm = _create_norm(logscale, norm, vmin, vmax) + + if grid_only is True: + cmap = matplotlib.colors.ListedColormap(["white"]) + colors = da.data.flatten() + polys = matplotlib.collections.PatchCollection( + patches, + alpha=1, + norm=norm, + cmap=cmap, + antialiaseds=antialias, + edgecolors=linecolor, + linewidths=linewidth, + joinstyle="bevel", + ) + + polys.set_array(colors) + + if add_colorbar: + # This produces a "foolproof" colorbar which + # is always the height of the plot + # From https://joseph-long.com/writing/colorbars/ + divider = make_axes_locatable(ax) + cax = divider.append_axes("right", size="5%", pad=0.05) + fig.colorbar(polys, cax=cax, label=colorbar_label, extend=extend) + cax.grid(which="both", visible=False) + + ax.add_collection(polys) + + ax.set_aspect("equal", adjustable="box") + ax.set_xlabel("R [m]") + ax.set_ylabel("Z [m]") + ax.set_ylim(cell_z.min(), cell_z.max()) + ax.set_xlim(cell_r.min(), cell_r.max()) + ax.set_title(da.name) + + if separatrix: + plot_separatrices(da, ax, x="R", y="Z", **separatrix_kwargs) + + if targets: + plot_targets(da, ax, x="R", y="Z", hatching=add_limiter_hatching) diff --git a/xbout/plotting/utils.py b/xbout/plotting/utils.py index db70b74a..8a8d3c1a 100644 --- a/xbout/plotting/utils.py +++ b/xbout/plotting/utils.py @@ -78,8 +78,10 @@ def _is_core_only(da): return ix1 >= nx and ix2 >= nx -def plot_separatrices(da, ax, *, x="R", y="Z"): - """Plot separatrices""" +def plot_separatrices(da, ax, *, x="R", y="Z", **kwargs): + """ + Plot separatrices. Kwargs are passed to ax.plot(). + """ if not isinstance(da, dict): da_regions = _decompose_regions(da) @@ -116,7 +118,13 @@ def plot_separatrices(da, ax, *, x="R", y="Z"): y_sep = 0.5 * ( da_region[y].isel(**{xcoord: 0}) + da_inner[y].isel(**{xcoord: -1}) ) - ax.plot(x_sep, y_sep, "k--") + default_style = {"color": "black", "linestyle": "--"} + if any(x for x in kwargs if x in ["c", "ls"]): + raise ValueError( + "When passing separatrix plot style kwargs, use 'color' and 'linestyle' instead lf 'c' and 'ls'" + ) + style = {**default_style, **kwargs} + ax.plot(x_sep, y_sep, **style) def plot_targets(da, ax, *, x="R", y="Z", hatching=True): diff --git a/xbout/tests/test_against_collect.py b/xbout/tests/test_against_collect.py index 5f22cf97..a1bf3da3 100644 --- a/xbout/tests/test_against_collect.py +++ b/xbout/tests/test_against_collect.py @@ -220,5 +220,4 @@ def test_new_collect_indexing_slice(self, tmp_path_factory): @pytest.mark.skip -class test_speed_against_old_collect: - ... +class test_speed_against_old_collect: ... diff --git a/xbout/tests/test_load.py b/xbout/tests/test_load.py index d8766236..bb4c917e 100644 --- a/xbout/tests/test_load.py +++ b/xbout/tests/test_load.py @@ -472,8 +472,7 @@ def test_combine_along_y(self, tmp_path_factory, bout_xyt_example_files): xrt.assert_identical(actual, fake) @pytest.mark.skip - def test_combine_along_t(self): - ... + def test_combine_along_t(self): ... @pytest.mark.parametrize( "bout_v5,metric_3D", [(False, False), (True, False), (True, True)] @@ -623,8 +622,7 @@ def test_drop_vars(self, tmp_path_factory, bout_xyt_example_files): assert "n" in ds.keys() @pytest.mark.skip - def test_combine_along_tx(self): - ... + def test_combine_along_tx(self): ... def test_restarts(self): datapath = Path(__file__).parent.joinpath( diff --git a/xbout/utils.py b/xbout/utils.py index 32be7edc..66dd9593 100644 --- a/xbout/utils.py +++ b/xbout/utils.py @@ -167,12 +167,16 @@ def _1d_coord_from_spacing(spacing, dim, ds=None, *, origin_at=None): ) point_to_use = { - spacing.metadata["bout_xdim"]: spacing.metadata.get("MXG", 0) - if spacing.metadata["keep_xboundaries"] - else 0, - spacing.metadata["bout_ydim"]: spacing.metadata.get("MYG", 0) - if spacing.metadata["keep_yboundaries"] - else 0, + spacing.metadata["bout_xdim"]: ( + spacing.metadata.get("MXG", 0) + if spacing.metadata["keep_xboundaries"] + else 0 + ), + spacing.metadata["bout_ydim"]: ( + spacing.metadata.get("MYG", 0) + if spacing.metadata["keep_yboundaries"] + else 0 + ), spacing.metadata["bout_zdim"]: spacing.metadata.get("MZG", 0), }