Skip to content

Commit

Permalink
Merge pull request #280 from boutproject/polygon-plots
Browse files Browse the repository at this point in the history
Polygonal 2D poloidal plots
  • Loading branch information
bendudson authored Jun 25, 2024
2 parents d02fe5a + ba35f46 commit d8b79ee
Show file tree
Hide file tree
Showing 8 changed files with 242 additions and 19 deletions.
6 changes: 6 additions & 0 deletions xbout/boutdataarray.py
Original file line number Diff line number Diff line change
Expand Up @@ -1072,6 +1072,12 @@ def pcolormesh(self, ax=None, **kwargs):
"""
return plotfuncs.plot2d_wrapper(self.data, xr.plot.pcolormesh, ax=ax, **kwargs)

def polygon(self, ax=None, **kwargs):
"""
Colour-plot of a radial-poloidal slice on the R-Z plane using polygons
"""
return plotfuncs.plot2d_polygon(self.data, ax=ax, **kwargs)

def plot_regions(self, ax=None, **kwargs):
"""
Plot the regions into which xBOUT splits radial-poloidal arrays to handle
Expand Down
26 changes: 26 additions & 0 deletions xbout/geometries.py
Original file line number Diff line number Diff line change
Expand Up @@ -381,6 +381,14 @@ def add_toroidal_geometry_coords(ds, *, coordinates=None, grid=None):
"total_poloidal_distance",
"zShift",
"zShift_ylow",
"Rxy_corners", # Lower left corners
"Rxy_lower_right_corners",
"Rxy_upper_left_corners",
"Rxy_upper_right_corners",
"Zxy_corners", # Lower left corners
"Zxy_lower_right_corners",
"Zxy_upper_left_corners",
"Zxy_upper_right_corners",
],
)

Expand Down Expand Up @@ -420,6 +428,24 @@ def add_toroidal_geometry_coords(ds, *, coordinates=None, grid=None):
else:
ds = ds.set_coords(("Rxy", "Zxy"))

# Add cell corners as coordinates for polygon plotting
if "Rxy_lower_right_corners" in ds:
ds = ds.rename(
Rxy_corners="Rxy_lower_left_corners", Zxy_corners="Zxy_lower_left_corners"
)
ds = ds.set_coords(
(
"Rxy_lower_left_corners",
"Rxy_lower_right_corners",
"Rxy_upper_left_corners",
"Rxy_upper_right_corners",
"Zxy_lower_left_corners",
"Zxy_lower_right_corners",
"Zxy_upper_left_corners",
"Zxy_upper_right_corners",
)
)

# Rename zShift_ylow if it was added from grid file, to be consistent with name if
# it was added from dump file
if "zShift_CELL_YLOW" in ds and "zShift_ylow" in ds:
Expand Down
5 changes: 4 additions & 1 deletion xbout/plotting/animate.py
Original file line number Diff line number Diff line change
Expand Up @@ -97,6 +97,7 @@ def animate_poloidal(
cax=None,
animate_over=None,
separatrix=True,
separatrix_kwargs=dict(),
targets=True,
add_limiter_hatching=True,
cmap=None,
Expand Down Expand Up @@ -130,6 +131,8 @@ def animate_poloidal(
Dimension over which to animate, defaults to the time dimension
separatrix : bool, optional
Add dashed lines showing separatrices
separatrix_kwargs : dict, optional
Options to pass to the separatrix plotter (e.g. line color)
targets : bool, optional
Draw solid lines at the target surfaces
add_limiter_hatching : bool, optional
Expand Down Expand Up @@ -277,7 +280,7 @@ def animate_poloidal(
targets = False

if separatrix:
plot_separatrices(da_regions, ax, x=x, y=y)
plot_separatrices(da_regions, ax, x=x, y=y, **separatrix_kwargs)

if targets:
plot_targets(da_regions, ax, x=x, y=y, hatching=add_limiter_hatching)
Expand Down
185 changes: 182 additions & 3 deletions xbout/plotting/plotfuncs.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
from collections.abc import Sequence
import matplotlib
import matplotlib.pyplot as plt
from mpl_toolkits.axes_grid1 import make_axes_locatable
import numpy as np
from pathlib import Path
from tempfile import TemporaryDirectory
Expand Down Expand Up @@ -752,9 +753,9 @@ def create_or_update_plot(plot_objects=None, tind=None, this_save_as=None):
X, Y, Z, scalars=data, vmin=vmin, vmax=vmax, **kwargs
)
else:
plot_objects[
region_name + str(i)
].mlab_source.scalars = data
plot_objects[region_name + str(i)].mlab_source.scalars = (
data
)

if mayavi_view is not None:
mlab.view(*mayavi_view)
Expand Down Expand Up @@ -849,3 +850,181 @@ def animation_func():
plt.show()
else:
raise ValueError(f"Unrecognised plot3d() 'engine' argument: {engine}")


def plot2d_polygon(
da,
ax=None,
cax=None,
cmap="viridis",
norm=None,
logscale=False,
antialias=False,
vmin=None,
vmax=None,
extend="neither",
add_colorbar=True,
colorbar_label=None,
separatrix=True,
separatrix_kwargs={"color": "white", "linestyle": "-", "linewidth": 2},
targets=False,
add_limiter_hatching=False,
grid_only=False,
linewidth=0,
linecolor="black",
):
"""
Nice looking 2D plots which have no visual artifacts around the X-point.
Parameters
----------
da : xarray.DataArray
A 2D (x,y) DataArray of data to plot
ax : Axes, optional
Axes to plot on. If not provided, will make its own.
cax : Axes, optional
Axes to plot colorbar on. If not provided, will plot on the same axes as the plot.
cmap : str or matplotlib.colors.Colormap, default "viridis"
Colormap to use for the plot
norm : matplotlib.colors.Normalize, optional
Normalization to use for the color scale
logscale : bool, default False
If True, use a symlog color scale
antialias : bool, default False
Enables antialiasing. Note: this also shows mesh cell edges - it's unclear how to disable this.
vmin : float, optional
Minimum value for the color scale
vmax : float, optional
Maximum value for the color scale
extend : str, optional, default "neither"
Extend the colorbar. Options are "neither", "both", "min", "max"
add_colorbar : bool, default True
Enable colorbar in figure?
colorbar_label : str, optional
Label for the colorbar
separatrix : bool, default True
Add lines showing separatrices
separatrix_kwargs : dict
Keyword arguments to pass custom style to the separatrices plot
targets : bool, default True
Draw solid lines at the target surfaces
add_limiter_hatching : bool, default True
Draw hatched areas at the targets
grid_only : bool, default False
Only plot the grid, not the data. This sets all the polygons to have a white face.
linewidth : float, default 0
Width of the gridlines on cell edges
linecolor : str, default "black"
Color of the gridlines on cell edges
"""

if ax is None:
fig, ax = plt.subplots(figsize=(3, 6), dpi=120)
else:
fig = ax.get_figure()

if cax is None:
cax = ax

if vmin is None:
vmin = np.nanmin(da.values)

if vmax is None:
vmax = np.nanmax(da.max().values)

if colorbar_label == None:
if "short_name" in da.attrs:
colorbar_label = da.attrs["short_name"]
elif da.name != None:
colorbar_label = da.name
else:
colorbar_label = ""

if "units" in da.attrs:
colorbar_label += f" [{da.attrs['units']}]"

if "Rxy_lower_right_corners" in da.coords:
r_nodes = [
"R",
"Rxy_lower_left_corners",
"Rxy_lower_right_corners",
"Rxy_upper_left_corners",
"Rxy_upper_right_corners",
]
z_nodes = [
"Z",
"Zxy_lower_left_corners",
"Zxy_lower_right_corners",
"Zxy_upper_left_corners",
"Zxy_upper_right_corners",
]
cell_r = np.concatenate(
[np.expand_dims(da[x], axis=2) for x in r_nodes], axis=2
)
cell_z = np.concatenate(
[np.expand_dims(da[x], axis=2) for x in z_nodes], axis=2
)
else:
raise Exception("Cell corners not present in mesh, cannot do polygon plot")

Nx = len(cell_r)
Ny = len(cell_r[0])
patches = []

# https://matplotlib.org/2.0.2/examples/api/patch_collection.html

idx = [np.array([1, 2, 4, 3, 1])]
patches = []
for i in range(Nx):
for j in range(Ny):
p = matplotlib.patches.Polygon(
np.concatenate((cell_r[i][j][tuple(idx)], cell_z[i][j][tuple(idx)]))
.reshape(2, 5)
.T,
fill=False,
closed=True,
facecolor=None,
)
patches.append(p)

norm = _create_norm(logscale, norm, vmin, vmax)

if grid_only is True:
cmap = matplotlib.colors.ListedColormap(["white"])
colors = da.data.flatten()
polys = matplotlib.collections.PatchCollection(
patches,
alpha=1,
norm=norm,
cmap=cmap,
antialiaseds=antialias,
edgecolors=linecolor,
linewidths=linewidth,
joinstyle="bevel",
)

polys.set_array(colors)

if add_colorbar:
# This produces a "foolproof" colorbar which
# is always the height of the plot
# From https://joseph-long.com/writing/colorbars/
divider = make_axes_locatable(ax)
cax = divider.append_axes("right", size="5%", pad=0.05)
fig.colorbar(polys, cax=cax, label=colorbar_label, extend=extend)
cax.grid(which="both", visible=False)

ax.add_collection(polys)

ax.set_aspect("equal", adjustable="box")
ax.set_xlabel("R [m]")
ax.set_ylabel("Z [m]")
ax.set_ylim(cell_z.min(), cell_z.max())
ax.set_xlim(cell_r.min(), cell_r.max())
ax.set_title(da.name)

if separatrix:
plot_separatrices(da, ax, x="R", y="Z", **separatrix_kwargs)

if targets:
plot_targets(da, ax, x="R", y="Z", hatching=add_limiter_hatching)
14 changes: 11 additions & 3 deletions xbout/plotting/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -78,8 +78,10 @@ def _is_core_only(da):
return ix1 >= nx and ix2 >= nx


def plot_separatrices(da, ax, *, x="R", y="Z"):
"""Plot separatrices"""
def plot_separatrices(da, ax, *, x="R", y="Z", **kwargs):
"""
Plot separatrices. Kwargs are passed to ax.plot().
"""

if not isinstance(da, dict):
da_regions = _decompose_regions(da)
Expand Down Expand Up @@ -116,7 +118,13 @@ def plot_separatrices(da, ax, *, x="R", y="Z"):
y_sep = 0.5 * (
da_region[y].isel(**{xcoord: 0}) + da_inner[y].isel(**{xcoord: -1})
)
ax.plot(x_sep, y_sep, "k--")
default_style = {"color": "black", "linestyle": "--"}
if any(x for x in kwargs if x in ["c", "ls"]):
raise ValueError(
"When passing separatrix plot style kwargs, use 'color' and 'linestyle' instead lf 'c' and 'ls'"
)
style = {**default_style, **kwargs}
ax.plot(x_sep, y_sep, **style)


def plot_targets(da, ax, *, x="R", y="Z", hatching=True):
Expand Down
3 changes: 1 addition & 2 deletions xbout/tests/test_against_collect.py
Original file line number Diff line number Diff line change
Expand Up @@ -220,5 +220,4 @@ def test_new_collect_indexing_slice(self, tmp_path_factory):


@pytest.mark.skip
class test_speed_against_old_collect:
...
class test_speed_against_old_collect: ...
6 changes: 2 additions & 4 deletions xbout/tests/test_load.py
Original file line number Diff line number Diff line change
Expand Up @@ -472,8 +472,7 @@ def test_combine_along_y(self, tmp_path_factory, bout_xyt_example_files):
xrt.assert_identical(actual, fake)

@pytest.mark.skip
def test_combine_along_t(self):
...
def test_combine_along_t(self): ...

@pytest.mark.parametrize(
"bout_v5,metric_3D", [(False, False), (True, False), (True, True)]
Expand Down Expand Up @@ -623,8 +622,7 @@ def test_drop_vars(self, tmp_path_factory, bout_xyt_example_files):
assert "n" in ds.keys()

@pytest.mark.skip
def test_combine_along_tx(self):
...
def test_combine_along_tx(self): ...

def test_restarts(self):
datapath = Path(__file__).parent.joinpath(
Expand Down
16 changes: 10 additions & 6 deletions xbout/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -167,12 +167,16 @@ def _1d_coord_from_spacing(spacing, dim, ds=None, *, origin_at=None):
)

point_to_use = {
spacing.metadata["bout_xdim"]: spacing.metadata.get("MXG", 0)
if spacing.metadata["keep_xboundaries"]
else 0,
spacing.metadata["bout_ydim"]: spacing.metadata.get("MYG", 0)
if spacing.metadata["keep_yboundaries"]
else 0,
spacing.metadata["bout_xdim"]: (
spacing.metadata.get("MXG", 0)
if spacing.metadata["keep_xboundaries"]
else 0
),
spacing.metadata["bout_ydim"]: (
spacing.metadata.get("MYG", 0)
if spacing.metadata["keep_yboundaries"]
else 0
),
spacing.metadata["bout_zdim"]: spacing.metadata.get("MZG", 0),
}

Expand Down

0 comments on commit d8b79ee

Please sign in to comment.