From 7518f6fb36857f0986abc84d6147230edf7d05bc Mon Sep 17 00:00:00 2001 From: Mike Kryjak Date: Thu, 16 Feb 2023 12:23:33 +0000 Subject: [PATCH 1/4] Add colorbar label to bout.pcolormesh --- xbout/plotting/plotfuncs.py | 11 +++++++++-- 1 file changed, 9 insertions(+), 2 deletions(-) diff --git a/xbout/plotting/plotfuncs.py b/xbout/plotting/plotfuncs.py index 79b36207..7fe182c7 100644 --- a/xbout/plotting/plotfuncs.py +++ b/xbout/plotting/plotfuncs.py @@ -239,11 +239,18 @@ def plot2d_wrapper( kwargs["vmax"] = vmax # create colorbar - norm = _create_norm(logscale, norm, vmin, vmax) + norm = _create_norm(logscale, norm, vmin, vmax) sm = plt.cm.ScalarMappable(norm=norm, cmap=cmap) sm.set_array([]) cmap = sm.get_cmap() - fig.colorbar(sm, ax=ax, extend=extend) + cbar = fig.colorbar(sm, ax=ax, extend=extend) + if "long_name" in da.attrs: + cbar_label = da.long_name + else: + cbar_label = da.name + if "units" in da.attrs: + cbar_label += f" [{da.units}]" + cbar.ax.set_ylabel(cbar_label) if method is xr.plot.pcolormesh: if "infer_intervals" not in kwargs: From 276609a481f30544d295c1152dc7c219c353a2cf Mon Sep 17 00:00:00 2001 From: Mike Kryjak Date: Thu, 16 Feb 2023 12:24:36 +0000 Subject: [PATCH 2/4] bugfix: pcolormesh plot was missing norm --- xbout/plotting/plotfuncs.py | 1 + 1 file changed, 1 insertion(+) diff --git a/xbout/plotting/plotfuncs.py b/xbout/plotting/plotfuncs.py index 7fe182c7..5ac30c85 100644 --- a/xbout/plotting/plotfuncs.py +++ b/xbout/plotting/plotfuncs.py @@ -280,6 +280,7 @@ def plot2d_wrapper( add_colorbar=False, add_labels=add_label, cmap=cmap, + norm=norm, **kwargs, ) for region, add_label in zip(da_regions.values(), add_labels) From 503919c3c300210548effab280d409333960a510 Mon Sep 17 00:00:00 2001 From: John Omotani Date: Thu, 16 Mar 2023 09:05:01 +0000 Subject: [PATCH 3/4] black-23.1.0 updates --- xbout/boutdataset.py | 1 - xbout/fastoutput.py | 3 --- xbout/geometries.py | 2 -- xbout/load.py | 1 - xbout/plotting/plotfuncs.py | 3 +-- xbout/plotting/utils.py | 3 --- xbout/tests/test_against_collect.py | 4 ---- xbout/tests/test_animate.py | 22 ---------------------- xbout/tests/test_boutdataarray.py | 3 --- xbout/tests/test_boutdataset.py | 2 -- xbout/tests/test_load.py | 2 -- 11 files changed, 1 insertion(+), 45 deletions(-) diff --git a/xbout/boutdataset.py b/xbout/boutdataset.py index 94ed2537..6540fe78 100644 --- a/xbout/boutdataset.py +++ b/xbout/boutdataset.py @@ -1185,7 +1185,6 @@ def is_list(variable): extend, ) ): - ( v, ax, diff --git a/xbout/fastoutput.py b/xbout/fastoutput.py index 00606a42..251abcfb 100644 --- a/xbout/fastoutput.py +++ b/xbout/fastoutput.py @@ -15,7 +15,6 @@ def open_fastoutput(datapath="BOUT.fast.*.nc"): # Iterate over all files, extracting DataArrays ready for combining fo_data = [] for i, filepath in enumerate(filepaths): - fo = xr.open_dataset(filepath) if i == 0: @@ -27,9 +26,7 @@ def open_fastoutput(datapath="BOUT.fast.*.nc"): # There might be no virtual probe in this region if len(fo.data_vars) > 0: - for name, da in fo.items(): - # Save the physical position (in index units) da = da.expand_dims(x=1, y=1, z=1) da = da.assign_coords( diff --git a/xbout/geometries.py b/xbout/geometries.py index fa8a67b6..1cc6d727 100644 --- a/xbout/geometries.py +++ b/xbout/geometries.py @@ -349,7 +349,6 @@ def _add_vars_from_grid(ds, grid, variables, *, optional_variables=None): @register_geometry("toroidal") def add_toroidal_geometry_coords(ds, *, coordinates=None, grid=None): - coordinates = _set_default_toroidal_coordinates(coordinates, ds) if ds.attrs.get("geometry", None) == "toroidal": @@ -447,7 +446,6 @@ def add_toroidal_geometry_coords(ds, *, coordinates=None, grid=None): @register_geometry("s-alpha") def add_s_alpha_geometry_coords(ds, *, coordinates=None, grid=None): - coordinates = _set_default_toroidal_coordinates(coordinates, ds) if set(coordinates.values()).issubset(set(ds.coords).union(ds.dims)): diff --git a/xbout/load.py b/xbout/load.py index 7c08f9d8..66908434 100644 --- a/xbout/load.py +++ b/xbout/load.py @@ -508,7 +508,6 @@ def collect( # Convert indexing values to an isel suitable format for dim, ind in zip(dims, inds): - if isinstance(ind, int): indexer = [ind] elif isinstance(ind, list): diff --git a/xbout/plotting/plotfuncs.py b/xbout/plotting/plotfuncs.py index 5ac30c85..9cf371d2 100644 --- a/xbout/plotting/plotfuncs.py +++ b/xbout/plotting/plotfuncs.py @@ -239,7 +239,7 @@ def plot2d_wrapper( kwargs["vmax"] = vmax # create colorbar - norm = _create_norm(logscale, norm, vmin, vmax) + norm = _create_norm(logscale, norm, vmin, vmax) sm = plt.cm.ScalarMappable(norm=norm, cmap=cmap) sm.set_array([]) cmap = sm.get_cmap() @@ -613,7 +613,6 @@ def plot3d( return for region_name, da_region in _decompose_regions(da).items(): - npsi, ntheta, nzeta = da_region.shape if style == "surface": diff --git a/xbout/plotting/utils.py b/xbout/plotting/utils.py index e94e2f8c..a99ef257 100644 --- a/xbout/plotting/utils.py +++ b/xbout/plotting/utils.py @@ -42,7 +42,6 @@ def plot_separatrix(da, sep_pos, ax, radial_coord="x"): # 2D domain needs to intersect the separatrix plane to be able to plot it dims = da.dims if radial_coord not in dims: - warnings.warn( "Cannot plot separatrix as domain does not cross " "separatrix, as it does not have a radial dimension", @@ -63,7 +62,6 @@ def plot_separatrix(da, sep_pos, ax, radial_coord="x"): def _decompose_regions(da): - if da.geometry == "fci": return {region: da for region in da.bout._regions} return { @@ -73,7 +71,6 @@ def _decompose_regions(da): def _is_core_only(da): - nx = da.metadata["nx"] ix1 = da.metadata["ixseps1"] ix2 = da.metadata["ixseps2"] diff --git a/xbout/tests/test_against_collect.py b/xbout/tests/test_against_collect.py index 71853680..81d8b33c 100644 --- a/xbout/tests/test_against_collect.py +++ b/xbout/tests/test_against_collect.py @@ -11,7 +11,6 @@ class TestAccuracyAgainstOldCollect: def test_single_file(self, tmp_path_factory): - # Create temp directory for files test_dir = tmp_path_factory.mktemp("test_data") @@ -37,7 +36,6 @@ def test_single_file(self, tmp_path_factory): npt.assert_equal(actual, expected) def test_multiple_files_along_x(self, tmp_path_factory): - # Create temp directory for files test_dir = tmp_path_factory.mktemp("test_data") @@ -66,7 +64,6 @@ def test_multiple_files_along_x(self, tmp_path_factory): npt.assert_equal(actual, expected) def test_multiple_files_along_y(self, tmp_path_factory): - # Create temp directory for files test_dir = tmp_path_factory.mktemp("test_data") @@ -95,7 +92,6 @@ def test_multiple_files_along_y(self, tmp_path_factory): npt.assert_equal(actual, expected) def test_multiple_files_along_xy(self, tmp_path_factory): - # Create temp directory for files test_dir = tmp_path_factory.mktemp("test_data") diff --git a/xbout/tests/test_animate.py b/xbout/tests/test_animate.py index f66a4694..2842dafa 100644 --- a/xbout/tests/test_animate.py +++ b/xbout/tests/test_animate.py @@ -13,7 +13,6 @@ @pytest.fixture def create_test_file(tmp_path_factory): - # Create temp dir for output of animate1D/2D save_dir = tmp_path_factory.mktemp("test_data") @@ -37,7 +36,6 @@ class TestAnimate: """ def test_animate2D(self, create_test_file): - save_dir, ds = create_test_file animation = ds["n"].isel(x=1).bout.animate2D(save_as="%s/testyz" % save_dir) @@ -107,7 +105,6 @@ def test_animate2D_controls_arg(self, create_test_file, controls): plt.close() def test_animate1D(self, create_test_file): - save_dir, ds = create_test_file animation = ds["n"].isel(y=2, z=0).bout.animate1D(save_as="%s/test" % save_dir) @@ -152,7 +149,6 @@ def test_animate1D_controls_arg(self, create_test_file, controls): plt.close() def test_animate_list(self, create_test_file): - save_dir, ds = create_test_file animation = ds.isel(z=3).bout.animate_list( @@ -167,7 +163,6 @@ def test_animate_list(self, create_test_file): plt.close() def test_animate_list_1d_default(self, create_test_file): - save_dir, ds = create_test_file animation = ds.isel(y=2, z=3).bout.animate_list( @@ -182,7 +177,6 @@ def test_animate_list_1d_default(self, create_test_file): plt.close() def test_animate_list_1d_multiline(self, create_test_file): - save_dir, ds = create_test_file animation = ds.isel(y=2, z=3).bout.animate_list( @@ -210,7 +204,6 @@ def test_animate_list_1d_multiline(self, create_test_file): plt.close() def test_animate_list_animate_over(self, create_test_file): - save_dir, ds = create_test_file animation = ds.isel(z=3).bout.animate_list( @@ -225,7 +218,6 @@ def test_animate_list_animate_over(self, create_test_file): plt.close() def test_animate_list_save_as(self, create_test_file): - save_dir, ds = create_test_file animation = ds.isel(z=3).bout.animate_list( @@ -241,7 +233,6 @@ def test_animate_list_save_as(self, create_test_file): plt.close() def test_animate_list_fps(self, create_test_file): - save_dir, ds = create_test_file animation = ds.isel(z=3).bout.animate_list( @@ -257,7 +248,6 @@ def test_animate_list_fps(self, create_test_file): plt.close() def test_animate_list_nrows(self, create_test_file): - save_dir, ds = create_test_file animation = ds.isel(z=3).bout.animate_list( @@ -272,7 +262,6 @@ def test_animate_list_nrows(self, create_test_file): plt.close() def test_animate_list_ncols(self, create_test_file): - save_dir, ds = create_test_file animation = ds.isel(z=3).bout.animate_list( @@ -287,7 +276,6 @@ def test_animate_list_ncols(self, create_test_file): plt.close() def test_animate_list_not_enough_nrowsncols(self, create_test_file): - save_dir, ds = create_test_file with pytest.raises(ValueError): @@ -297,7 +285,6 @@ def test_animate_list_not_enough_nrowsncols(self, create_test_file): @pytest.mark.skip(reason="test data for plot_poloidal needs more work") def test_animate_list_poloidal_plot(self, create_test_file): - save_dir, ds = create_test_file metadata = ds.metadata @@ -340,7 +327,6 @@ def test_animate_list_poloidal_plot(self, create_test_file): plt.close() def test_animate_list_subplots_adjust(self, create_test_file): - save_dir, ds = create_test_file with pytest.warns(UserWarning): @@ -357,7 +343,6 @@ def test_animate_list_subplots_adjust(self, create_test_file): plt.close() def test_animate_list_vmin(self, create_test_file): - save_dir, ds = create_test_file animation = ds.isel(z=3).bout.animate_list( @@ -372,7 +357,6 @@ def test_animate_list_vmin(self, create_test_file): plt.close() def test_animate_list_vmin_list(self, create_test_file): - save_dir, ds = create_test_file animation = ds.isel(z=3).bout.animate_list( @@ -387,7 +371,6 @@ def test_animate_list_vmin_list(self, create_test_file): plt.close() def test_animate_list_vmax(self, create_test_file): - save_dir, ds = create_test_file animation = ds.isel(z=3).bout.animate_list( @@ -402,7 +385,6 @@ def test_animate_list_vmax(self, create_test_file): plt.close() def test_animate_list_vmax_list(self, create_test_file): - save_dir, ds = create_test_file animation = ds.isel(z=3).bout.animate_list( @@ -417,7 +399,6 @@ def test_animate_list_vmax_list(self, create_test_file): plt.close() def test_animate_list_logscale(self, create_test_file): - save_dir, ds = create_test_file animation = ds.isel(z=3).bout.animate_list( @@ -432,7 +413,6 @@ def test_animate_list_logscale(self, create_test_file): plt.close() def test_animate_list_logscale_float(self, create_test_file): - save_dir, ds = create_test_file animation = ds.isel(z=3).bout.animate_list( @@ -447,7 +427,6 @@ def test_animate_list_logscale_float(self, create_test_file): plt.close() def test_animate_list_logscale_list(self, create_test_file): - save_dir, ds = create_test_file animation = ds.isel(z=3).bout.animate_list( @@ -463,7 +442,6 @@ def test_animate_list_logscale_list(self, create_test_file): plt.close() def test_animate_list_titles_list(self, create_test_file): - save_dir, ds = create_test_file animation = ds.isel(z=3).bout.animate_list( diff --git a/xbout/tests/test_boutdataarray.py b/xbout/tests/test_boutdataarray.py index 01f6c138..280aa309 100644 --- a/xbout/tests/test_boutdataarray.py +++ b/xbout/tests/test_boutdataarray.py @@ -209,7 +209,6 @@ def test_to_field_aligned(self, bout_xyt_example_files, nz, permute_dims): "permute_dims", [False, pytest.param(True, marks=pytest.mark.long)] ) def test_to_field_aligned_dask(self, bout_xyt_example_files, permute_dims): - nz = 6 dataset_list = bout_xyt_example_files( @@ -1034,7 +1033,6 @@ def test_add_cartesian_coordinates(self, bout_xyt_example_files): ) def test_ddx(self, bout_xyt_example_files): - nx = 64 dataset_list = bout_xyt_example_files( @@ -1069,7 +1067,6 @@ def test_ddx(self, bout_xyt_example_files): ) def test_ddy(self, bout_xyt_example_files): - ny = 64 dataset_list, gridfilepath = bout_xyt_example_files( diff --git a/xbout/tests/test_boutdataset.py b/xbout/tests/test_boutdataset.py index 18e212fd..6ca5259a 100644 --- a/xbout/tests/test_boutdataset.py +++ b/xbout/tests/test_boutdataset.py @@ -261,7 +261,6 @@ def test_to_field_aligned(self, bout_xyt_example_files, nz): ) # noqa: E501 def test_to_field_aligned_dask(self, bout_xyt_example_files): - nz = 6 dataset_list = bout_xyt_example_files( @@ -2027,7 +2026,6 @@ def test_reload_all(self, tmp_path_factory, bout_xyt_example_files, geometry): def test_save_dtype( self, tmp_path_factory, bout_xyt_example_files, save_dtype, separate_vars ): - # Create data path = bout_xyt_example_files( tmp_path_factory, nxpe=1, nype=1, nt=1, write_to_disk=True diff --git a/xbout/tests/test_load.py b/xbout/tests/test_load.py index df7eef76..6621695c 100644 --- a/xbout/tests/test_load.py +++ b/xbout/tests/test_load.py @@ -814,7 +814,6 @@ def create_bout_grid_ds(xsize=2, ysize=4, guards={}, topology="core", ny_inner=0 class TestStripMetadata: def test_strip_metadata(self): - original = create_bout_ds() assert original["NXPE"] == 1 @@ -1534,7 +1533,6 @@ def create_example_grid_file_fci(tmp_path_factory): @pytest.fixture def create_example_files_fci(tmp_path_factory): - return _bout_xyt_example_files( tmp_path_factory, lengths=fci_shape, From 49ee207c122503212b2f6c9dff9853c066e3175f Mon Sep 17 00:00:00 2001 From: John Omotani Date: Thu, 16 Mar 2023 09:06:31 +0000 Subject: [PATCH 4/4] Fix test_animate_list_1d_multiline for matplotlib-3.7.x Matplotlib changed the types of the subplots, so the old check does not work. --- xbout/tests/test_animate.py | 11 +---------- 1 file changed, 1 insertion(+), 10 deletions(-) diff --git a/xbout/tests/test_animate.py b/xbout/tests/test_animate.py index 2842dafa..13a37bdc 100644 --- a/xbout/tests/test_animate.py +++ b/xbout/tests/test_animate.py @@ -190,16 +190,7 @@ def test_animate_list_1d_multiline(self, create_test_file): assert isinstance(animation.blocks[3], Line) # check there were actually 3 subplots - assert ( - len( - [ - x - for x in plt.gcf().get_axes() - if isinstance(x, matplotlib.axes.Subplot) - ] - ) - == 3 - ) + assert len([x for x in plt.gcf().get_axes() if x.get_xlabel() != ""]) == 3 plt.close()