Skip to content

Commit

Permalink
Merge pull request #278 from boutproject/plotting-bugfix
Browse files Browse the repository at this point in the history
Plotting bugfix
  • Loading branch information
johnomotani authored Mar 16, 2023
2 parents 280a570 + 49ee207 commit 212b407
Show file tree
Hide file tree
Showing 11 changed files with 10 additions and 55 deletions.
1 change: 0 additions & 1 deletion xbout/boutdataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -1185,7 +1185,6 @@ def is_list(variable):
extend,
)
):

(
v,
ax,
Expand Down
3 changes: 0 additions & 3 deletions xbout/fastoutput.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,6 @@ def open_fastoutput(datapath="BOUT.fast.*.nc"):
# Iterate over all files, extracting DataArrays ready for combining
fo_data = []
for i, filepath in enumerate(filepaths):

fo = xr.open_dataset(filepath)

if i == 0:
Expand All @@ -27,9 +26,7 @@ def open_fastoutput(datapath="BOUT.fast.*.nc"):

# There might be no virtual probe in this region
if len(fo.data_vars) > 0:

for name, da in fo.items():

# Save the physical position (in index units)
da = da.expand_dims(x=1, y=1, z=1)
da = da.assign_coords(
Expand Down
2 changes: 0 additions & 2 deletions xbout/geometries.py
Original file line number Diff line number Diff line change
Expand Up @@ -349,7 +349,6 @@ def _add_vars_from_grid(ds, grid, variables, *, optional_variables=None):

@register_geometry("toroidal")
def add_toroidal_geometry_coords(ds, *, coordinates=None, grid=None):

coordinates = _set_default_toroidal_coordinates(coordinates, ds)

if ds.attrs.get("geometry", None) == "toroidal":
Expand Down Expand Up @@ -447,7 +446,6 @@ def add_toroidal_geometry_coords(ds, *, coordinates=None, grid=None):

@register_geometry("s-alpha")
def add_s_alpha_geometry_coords(ds, *, coordinates=None, grid=None):

coordinates = _set_default_toroidal_coordinates(coordinates, ds)

if set(coordinates.values()).issubset(set(ds.coords).union(ds.dims)):
Expand Down
1 change: 0 additions & 1 deletion xbout/load.py
Original file line number Diff line number Diff line change
Expand Up @@ -508,7 +508,6 @@ def collect(

# Convert indexing values to an isel suitable format
for dim, ind in zip(dims, inds):

if isinstance(ind, int):
indexer = [ind]
elif isinstance(ind, list):
Expand Down
11 changes: 9 additions & 2 deletions xbout/plotting/plotfuncs.py
Original file line number Diff line number Diff line change
Expand Up @@ -243,7 +243,14 @@ def plot2d_wrapper(
sm = plt.cm.ScalarMappable(norm=norm, cmap=cmap)
sm.set_array([])
cmap = sm.get_cmap()
fig.colorbar(sm, ax=ax, extend=extend)
cbar = fig.colorbar(sm, ax=ax, extend=extend)
if "long_name" in da.attrs:
cbar_label = da.long_name
else:
cbar_label = da.name
if "units" in da.attrs:
cbar_label += f" [{da.units}]"
cbar.ax.set_ylabel(cbar_label)

if method is xr.plot.pcolormesh:
if "infer_intervals" not in kwargs:
Expand Down Expand Up @@ -273,6 +280,7 @@ def plot2d_wrapper(
add_colorbar=False,
add_labels=add_label,
cmap=cmap,
norm=norm,
**kwargs,
)
for region, add_label in zip(da_regions.values(), add_labels)
Expand Down Expand Up @@ -605,7 +613,6 @@ def plot3d(
return

for region_name, da_region in _decompose_regions(da).items():

npsi, ntheta, nzeta = da_region.shape

if style == "surface":
Expand Down
3 changes: 0 additions & 3 deletions xbout/plotting/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,6 @@ def plot_separatrix(da, sep_pos, ax, radial_coord="x"):
# 2D domain needs to intersect the separatrix plane to be able to plot it
dims = da.dims
if radial_coord not in dims:

warnings.warn(
"Cannot plot separatrix as domain does not cross "
"separatrix, as it does not have a radial dimension",
Expand All @@ -63,7 +62,6 @@ def plot_separatrix(da, sep_pos, ax, radial_coord="x"):


def _decompose_regions(da):

if da.geometry == "fci":
return {region: da for region in da.bout._regions}
return {
Expand All @@ -73,7 +71,6 @@ def _decompose_regions(da):


def _is_core_only(da):

nx = da.metadata["nx"]
ix1 = da.metadata["ixseps1"]
ix2 = da.metadata["ixseps2"]
Expand Down
4 changes: 0 additions & 4 deletions xbout/tests/test_against_collect.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,6 @@

class TestAccuracyAgainstOldCollect:
def test_single_file(self, tmp_path_factory):

# Create temp directory for files
test_dir = tmp_path_factory.mktemp("test_data")

Expand All @@ -37,7 +36,6 @@ def test_single_file(self, tmp_path_factory):
npt.assert_equal(actual, expected)

def test_multiple_files_along_x(self, tmp_path_factory):

# Create temp directory for files
test_dir = tmp_path_factory.mktemp("test_data")

Expand Down Expand Up @@ -66,7 +64,6 @@ def test_multiple_files_along_x(self, tmp_path_factory):
npt.assert_equal(actual, expected)

def test_multiple_files_along_y(self, tmp_path_factory):

# Create temp directory for files
test_dir = tmp_path_factory.mktemp("test_data")

Expand Down Expand Up @@ -95,7 +92,6 @@ def test_multiple_files_along_y(self, tmp_path_factory):
npt.assert_equal(actual, expected)

def test_multiple_files_along_xy(self, tmp_path_factory):

# Create temp directory for files
test_dir = tmp_path_factory.mktemp("test_data")

Expand Down
33 changes: 1 addition & 32 deletions xbout/tests/test_animate.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,6 @@

@pytest.fixture
def create_test_file(tmp_path_factory):

# Create temp dir for output of animate1D/2D
save_dir = tmp_path_factory.mktemp("test_data")

Expand All @@ -37,7 +36,6 @@ class TestAnimate:
"""

def test_animate2D(self, create_test_file):

save_dir, ds = create_test_file

animation = ds["n"].isel(x=1).bout.animate2D(save_as="%s/testyz" % save_dir)
Expand Down Expand Up @@ -107,7 +105,6 @@ def test_animate2D_controls_arg(self, create_test_file, controls):
plt.close()

def test_animate1D(self, create_test_file):

save_dir, ds = create_test_file
animation = ds["n"].isel(y=2, z=0).bout.animate1D(save_as="%s/test" % save_dir)

Expand Down Expand Up @@ -152,7 +149,6 @@ def test_animate1D_controls_arg(self, create_test_file, controls):
plt.close()

def test_animate_list(self, create_test_file):

save_dir, ds = create_test_file

animation = ds.isel(z=3).bout.animate_list(
Expand All @@ -167,7 +163,6 @@ def test_animate_list(self, create_test_file):
plt.close()

def test_animate_list_1d_default(self, create_test_file):

save_dir, ds = create_test_file

animation = ds.isel(y=2, z=3).bout.animate_list(
Expand All @@ -182,7 +177,6 @@ def test_animate_list_1d_default(self, create_test_file):
plt.close()

def test_animate_list_1d_multiline(self, create_test_file):

save_dir, ds = create_test_file

animation = ds.isel(y=2, z=3).bout.animate_list(
Expand All @@ -196,21 +190,11 @@ def test_animate_list_1d_multiline(self, create_test_file):
assert isinstance(animation.blocks[3], Line)

# check there were actually 3 subplots
assert (
len(
[
x
for x in plt.gcf().get_axes()
if isinstance(x, matplotlib.axes.Subplot)
]
)
== 3
)
assert len([x for x in plt.gcf().get_axes() if x.get_xlabel() != ""]) == 3

plt.close()

def test_animate_list_animate_over(self, create_test_file):

save_dir, ds = create_test_file

animation = ds.isel(z=3).bout.animate_list(
Expand All @@ -225,7 +209,6 @@ def test_animate_list_animate_over(self, create_test_file):
plt.close()

def test_animate_list_save_as(self, create_test_file):

save_dir, ds = create_test_file

animation = ds.isel(z=3).bout.animate_list(
Expand All @@ -241,7 +224,6 @@ def test_animate_list_save_as(self, create_test_file):
plt.close()

def test_animate_list_fps(self, create_test_file):

save_dir, ds = create_test_file

animation = ds.isel(z=3).bout.animate_list(
Expand All @@ -257,7 +239,6 @@ def test_animate_list_fps(self, create_test_file):
plt.close()

def test_animate_list_nrows(self, create_test_file):

save_dir, ds = create_test_file

animation = ds.isel(z=3).bout.animate_list(
Expand All @@ -272,7 +253,6 @@ def test_animate_list_nrows(self, create_test_file):
plt.close()

def test_animate_list_ncols(self, create_test_file):

save_dir, ds = create_test_file

animation = ds.isel(z=3).bout.animate_list(
Expand All @@ -287,7 +267,6 @@ def test_animate_list_ncols(self, create_test_file):
plt.close()

def test_animate_list_not_enough_nrowsncols(self, create_test_file):

save_dir, ds = create_test_file

with pytest.raises(ValueError):
Expand All @@ -297,7 +276,6 @@ def test_animate_list_not_enough_nrowsncols(self, create_test_file):

@pytest.mark.skip(reason="test data for plot_poloidal needs more work")
def test_animate_list_poloidal_plot(self, create_test_file):

save_dir, ds = create_test_file

metadata = ds.metadata
Expand Down Expand Up @@ -340,7 +318,6 @@ def test_animate_list_poloidal_plot(self, create_test_file):
plt.close()

def test_animate_list_subplots_adjust(self, create_test_file):

save_dir, ds = create_test_file

with pytest.warns(UserWarning):
Expand All @@ -357,7 +334,6 @@ def test_animate_list_subplots_adjust(self, create_test_file):
plt.close()

def test_animate_list_vmin(self, create_test_file):

save_dir, ds = create_test_file

animation = ds.isel(z=3).bout.animate_list(
Expand All @@ -372,7 +348,6 @@ def test_animate_list_vmin(self, create_test_file):
plt.close()

def test_animate_list_vmin_list(self, create_test_file):

save_dir, ds = create_test_file

animation = ds.isel(z=3).bout.animate_list(
Expand All @@ -387,7 +362,6 @@ def test_animate_list_vmin_list(self, create_test_file):
plt.close()

def test_animate_list_vmax(self, create_test_file):

save_dir, ds = create_test_file

animation = ds.isel(z=3).bout.animate_list(
Expand All @@ -402,7 +376,6 @@ def test_animate_list_vmax(self, create_test_file):
plt.close()

def test_animate_list_vmax_list(self, create_test_file):

save_dir, ds = create_test_file

animation = ds.isel(z=3).bout.animate_list(
Expand All @@ -417,7 +390,6 @@ def test_animate_list_vmax_list(self, create_test_file):
plt.close()

def test_animate_list_logscale(self, create_test_file):

save_dir, ds = create_test_file

animation = ds.isel(z=3).bout.animate_list(
Expand All @@ -432,7 +404,6 @@ def test_animate_list_logscale(self, create_test_file):
plt.close()

def test_animate_list_logscale_float(self, create_test_file):

save_dir, ds = create_test_file

animation = ds.isel(z=3).bout.animate_list(
Expand All @@ -447,7 +418,6 @@ def test_animate_list_logscale_float(self, create_test_file):
plt.close()

def test_animate_list_logscale_list(self, create_test_file):

save_dir, ds = create_test_file

animation = ds.isel(z=3).bout.animate_list(
Expand All @@ -463,7 +433,6 @@ def test_animate_list_logscale_list(self, create_test_file):
plt.close()

def test_animate_list_titles_list(self, create_test_file):

save_dir, ds = create_test_file

animation = ds.isel(z=3).bout.animate_list(
Expand Down
3 changes: 0 additions & 3 deletions xbout/tests/test_boutdataarray.py
Original file line number Diff line number Diff line change
Expand Up @@ -209,7 +209,6 @@ def test_to_field_aligned(self, bout_xyt_example_files, nz, permute_dims):
"permute_dims", [False, pytest.param(True, marks=pytest.mark.long)]
)
def test_to_field_aligned_dask(self, bout_xyt_example_files, permute_dims):

nz = 6

dataset_list = bout_xyt_example_files(
Expand Down Expand Up @@ -1034,7 +1033,6 @@ def test_add_cartesian_coordinates(self, bout_xyt_example_files):
)

def test_ddx(self, bout_xyt_example_files):

nx = 64

dataset_list = bout_xyt_example_files(
Expand Down Expand Up @@ -1069,7 +1067,6 @@ def test_ddx(self, bout_xyt_example_files):
)

def test_ddy(self, bout_xyt_example_files):

ny = 64

dataset_list, gridfilepath = bout_xyt_example_files(
Expand Down
2 changes: 0 additions & 2 deletions xbout/tests/test_boutdataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -261,7 +261,6 @@ def test_to_field_aligned(self, bout_xyt_example_files, nz):
) # noqa: E501

def test_to_field_aligned_dask(self, bout_xyt_example_files):

nz = 6

dataset_list = bout_xyt_example_files(
Expand Down Expand Up @@ -2027,7 +2026,6 @@ def test_reload_all(self, tmp_path_factory, bout_xyt_example_files, geometry):
def test_save_dtype(
self, tmp_path_factory, bout_xyt_example_files, save_dtype, separate_vars
):

# Create data
path = bout_xyt_example_files(
tmp_path_factory, nxpe=1, nype=1, nt=1, write_to_disk=True
Expand Down
2 changes: 0 additions & 2 deletions xbout/tests/test_load.py
Original file line number Diff line number Diff line change
Expand Up @@ -814,7 +814,6 @@ def create_bout_grid_ds(xsize=2, ysize=4, guards={}, topology="core", ny_inner=0

class TestStripMetadata:
def test_strip_metadata(self):

original = create_bout_ds()
assert original["NXPE"] == 1

Expand Down Expand Up @@ -1534,7 +1533,6 @@ def create_example_grid_file_fci(tmp_path_factory):

@pytest.fixture
def create_example_files_fci(tmp_path_factory):

return _bout_xyt_example_files(
tmp_path_factory,
lengths=fci_shape,
Expand Down

0 comments on commit 212b407

Please sign in to comment.