From d756b28da7d562b22f91ae88381fa2467a1a4f9f Mon Sep 17 00:00:00 2001 From: Ruth Comer Date: Mon, 1 May 2023 19:56:52 +0100 Subject: [PATCH] ENH: support RGB(A) in pcolormesh (wrapped case) Co-authored-by: Kyle Hofmann --- lib/cartopy/mpl/geoaxes.py | 48 +++++----- lib/cartopy/mpl/geocollection.py | 89 +++++++++++++++---- lib/cartopy/tests/mpl/test_mpl_integration.py | 81 ++++++++++++++--- 3 files changed, 166 insertions(+), 52 deletions(-) diff --git a/lib/cartopy/mpl/geoaxes.py b/lib/cartopy/mpl/geoaxes.py index e2c193c6f..73deeaf2c 100644 --- a/lib/cartopy/mpl/geoaxes.py +++ b/lib/cartopy/mpl/geoaxes.py @@ -1922,13 +1922,12 @@ def _wrap_quadmesh(self, collection, **kwargs): "map it must be fully transparent.", stacklevel=3) - # The original data mask (regardless of wrapped cells) - C_mask = getattr(C, 'mask', None) + # Get hold of masked versions of the array to be passed to set_array + # methods of QuadMesh and PolyQuadMesh + pcolormesh_data, pcolor_data, pcolor_mask = \ + cartopy.mpl.geocollection._split_wrapped_mesh_data(C, mask) - # create the masked array to be used with this pcolormesh - full_mask = mask if C_mask is None else mask | C_mask - pcolormesh_data = np.ma.array(C, mask=full_mask) - collection.set_array(pcolormesh_data.ravel()) + collection.set_array(pcolormesh_data) # plot with slightly lower zorder to avoid odd issue # where the main plot is obscured @@ -1944,28 +1943,31 @@ def _wrap_quadmesh(self, collection, **kwargs): # `pcolor` only draws polygons where the data is not # masked, so this will only draw a limited subset of # polygons that were actually wrapped. - # We will add the original data mask in later to - # make sure that set_array can work in future - # calls on the proper sized array inputs. - # NOTE: we don't use C.data here because C.data could - # contain nan's which would be masked in the - # pcolor routines, which we don't want. We will - # fill in the proper data later with set_array() - # calls. - pcolor_data = np.ma.array(np.zeros(C.shape), - mask=~mask) - pcolor_col = self.pcolor(coords[..., 0], coords[..., 1], - pcolor_data, zorder=zorder, - **kwargs) - # Now add back in the masked data if there was any - full_mask = ~mask if C_mask is None else ~mask | C_mask - pcolor_data = np.ma.array(C, mask=full_mask) if _MPL_VERSION.release[:2] < (3, 8): + # We will add the original data mask in later to + # make sure that set_array can work in future + # calls on the proper sized array inputs. + # NOTE: we don't use C.data here because C.data could + # contain nan's which would be masked in the + # pcolor routines, which we don't want. We will + # fill in the proper data later with set_array() + # calls. + pcolor_zeros = np.ma.array(np.zeros(C.shape), mask=pcolor_mask) + pcolor_col = self.pcolor(coords[..., 0], coords[..., 1], + pcolor_zeros, zorder=zorder, + **kwargs) + # The pcolor_col is now possibly shorter than the # actual collection, so grab the masked cells pcolor_col.set_array(pcolor_data[mask].ravel()) else: + pcolor_col = self.pcolor(coords[..., 0], coords[..., 1], + pcolor_data, zorder=zorder, + **kwargs) + # Currently pcolor_col.get_array() will return a compressed array + # and warn unless we explicitly set the 2D array. This should be + # unnecessary with future matplotlib versions. pcolor_col.set_array(pcolor_data) pcolor_col.set_cmap(cmap) @@ -1977,7 +1979,7 @@ def _wrap_quadmesh(self, collection, **kwargs): # put the pcolor_col and mask on the pcolormesh # collection so that users can do things post # this method - collection._wrapped_mask = mask.ravel() + collection._wrapped_mask = mask collection._wrapped_collection_fix = pcolor_col return collection diff --git a/lib/cartopy/mpl/geocollection.py b/lib/cartopy/mpl/geocollection.py index cc313d8e5..66fd1cbfe 100644 --- a/lib/cartopy/mpl/geocollection.py +++ b/lib/cartopy/mpl/geocollection.py @@ -6,12 +6,42 @@ import matplotlib as mpl from matplotlib.collections import QuadMesh import numpy as np +import numpy.ma as ma import packaging _MPL_VERSION = packaging.version.parse(mpl.__version__) +def _split_wrapped_mesh_data(C, mask): + """ + Helper function for splitting GeoQuadMesh array values between the + pcolormesh and pcolor objects when wrapping. Apply a mask to the grid + cells that should not be plotted with each method. + + """ + # The original data mask (regardless of wrapped cells) + C_mask = getattr(C, 'mask', None) + if C.ndim == 3: + # RGB(A) array. + if _MPL_VERSION.release < (3, 8): + raise ValueError("GeoQuadMesh wrapping for RGB(A) requires " + "Matplotlib v3.8 or later") + + # mask will need an extra trailing dimension + mask = np.broadcast_to(mask[..., np.newaxis], C.shape) + + # create the masked array to be used with pcolormesh + full_mask = mask if C_mask is None else mask | C_mask + pcolormesh_data = ma.array(C, mask=full_mask) + + # create the masked array to be used with pcolor + full_mask = ~mask if C_mask is None else ~mask | C_mask + pcolor_data = ma.array(C, mask=full_mask) + + return pcolormesh_data, pcolor_data, ~mask + + class GeoQuadMesh(QuadMesh): """ A QuadMesh designed to help handle the case when the mesh is wrapped. @@ -26,34 +56,55 @@ def get_array(self): A = super().get_array().copy() # If the input array has a mask, retrieve the associated data if hasattr(self, '_wrapped_mask'): - A[self._wrapped_mask] = np.ma.compressed( - self._wrapped_collection_fix.get_array()) + pcolor_data = self._wrapped_collection_fix.get_array() + mask = self._wrapped_mask + if _MPL_VERSION.release[:2] < (3, 8): + A[mask] = pcolor_data + else: + # np.copyto is not implemented for masked arrays so handle the + # mask explicitly + np.copyto(A.mask, pcolor_data.mask, where=mask) + np.copyto(A, pcolor_data, where=mask) + return A def set_array(self, A): - # raise right away if A is 2-dimensional. - if A.ndim > 1: - raise ValueError('Collections can only map rank 1 arrays. ' - 'You likely want to call with a flattened array ' - 'using collection.set_array(A.ravel()) instead.') + # Check the shape is appropriate up front. + if _MPL_VERSION.release[:2] < (3, 8): + # Need to figure out existing shape from the coordinates. + height, width = self._coordinates.shape[0:-1] + if self._shading == 'flat': + h, w = height - 1, width - 1 + else: + h, w = height, width + else: + h, w = super().get_array().shape[:2] + + ok_shapes = [(h, w, 3), (h, w, 4), (h, w), (h * w,)] + if A.shape not in ok_shapes: + ok_shape_str = ' or '.join(map(str, ok_shapes)) + raise ValueError( + f"A should have shape {ok_shape_str}, not {A.shape}") + + if A.ndim == 1: + # Always use array with at least two dimensions. This is + # inconsistent with QuadMesh which stores whatever you give it, but + # for the wrapped case we need to match the 2D mask. Storing the + # 2D array also allows us to calculate ok_shapes on subsequent + # calls without using the private QuadMesh._shading attribute. + A = A.reshape((h, w)) # Only use the mask attribute if it is there. if hasattr(self, '_wrapped_mask'): + # Update the pcolor data with the wrapped masked data - if _MPL_VERSION.release < (3, 8): - self._wrapped_collection_fix.set_array(A[self._wrapped_mask]) + A, pcolor_data, _ = _split_wrapped_mesh_data(A, self._wrapped_mask) + + if _MPL_VERSION.release[:2] < (3, 8): + self._wrapped_collection_fix.set_array( + pcolor_data[self._wrapped_mask].ravel()) else: - A_mask = getattr(A, 'mask', None) - full_mask = ~self._wrapped_mask - if A_mask is not None: - full_mask |= A_mask - pcolor_data = np.ma.array(A, mask=full_mask) self._wrapped_collection_fix.set_array(pcolor_data) - # If the input array was a masked array, keep that data masked - if hasattr(A, 'mask'): - A = np.ma.array(A, mask=self._wrapped_mask | A.mask) - else: - A = np.ma.array(A, mask=self._wrapped_mask) # Now that we have prepared the collection data, call on # through to the underlying implementation. diff --git a/lib/cartopy/tests/mpl/test_mpl_integration.py b/lib/cartopy/tests/mpl/test_mpl_integration.py index b1e059990..66e8da48b 100644 --- a/lib/cartopy/tests/mpl/test_mpl_integration.py +++ b/lib/cartopy/tests/mpl/test_mpl_integration.py @@ -6,6 +6,7 @@ import re +import matplotlib.colors as mcolors import matplotlib.pyplot as plt import numpy as np import pytest @@ -240,10 +241,46 @@ def test_cursor_values(): r.encode('ascii', 'ignore')) +SKIP_PRE_MPL38 = pytest.mark.skipif( + MPL_VERSION.release[:2] < (3, 8), reason='mpl < 3.8') +PARAMETRIZE_PCOLORMESH_WRAP = pytest.mark.parametrize( + 'mesh_data_kind', + [ + 'standard', + pytest.param('rgb', marks=SKIP_PRE_MPL38), + pytest.param('rgba', marks=SKIP_PRE_MPL38), + ], + ids=['standard', 'rgb', 'rgba'], +) + + +def _to_rgb(data, mesh_data_kind): + """ + Helper function to convert array to RGB(A) where required + """ + if mesh_data_kind in ('rgb', 'rgba'): + cmap = plt.get_cmap() + norm = mcolors.Normalize() + new_data = cmap(norm(data)) + if mesh_data_kind == 'rgb': + new_data = new_data[..., 0:3] + if np.ma.is_masked(data): + # Use data's mask as an alpha channel + mask = np.ma.getmaskarray(data) + mask = np.broadcast_to( + mask[..., np.newaxis], new_data.shape).copy() + new_data = np.ma.array(new_data, mask=mask) + + return new_data + + return data + + +@PARAMETRIZE_PCOLORMESH_WRAP @pytest.mark.natural_earth @pytest.mark.mpl_image_compare(filename='pcolormesh_global_wrap1.png', tolerance=1.27) -def test_pcolormesh_global_with_wrap1(): +def test_pcolormesh_global_with_wrap1(mesh_data_kind): # make up some realistic data with bounds (such as data from the UM) nx, ny = 36, 18 xbnds = np.linspace(0, 360, nx, endpoint=True) @@ -254,6 +291,8 @@ def test_pcolormesh_global_with_wrap1(): data = data[:-1, :-1] fig = plt.figure() + data = _to_rgb(data, mesh_data_kind) + ax = fig.add_subplot(2, 1, 1, projection=ccrs.PlateCarree()) ax.pcolormesh(xbnds, ybnds, data, transform=ccrs.PlateCarree(), snap=False) ax.coastlines() @@ -285,8 +324,9 @@ def test_pcolormesh_get_array_with_mask(): result = c.get_array() assert not np.ma.is_masked(result) - assert np.array_equal(data.ravel(), result), \ - 'Data supplied does not match data retrieved in wrapped case' + np.testing.assert_array_equal( + data, result, + err_msg='Data supplied does not match data retrieved in wrapped case') ax.coastlines() ax.set_global() # make sure everything is visible @@ -319,10 +359,11 @@ def test_pcolormesh_get_array_with_mask(): 'Data supplied does not match data retrieved in unwrapped case' +@PARAMETRIZE_PCOLORMESH_WRAP @pytest.mark.natural_earth @pytest.mark.mpl_image_compare(filename='pcolormesh_global_wrap2.png', tolerance=1.87) -def test_pcolormesh_global_with_wrap2(): +def test_pcolormesh_global_with_wrap2(mesh_data_kind): # make up some realistic data with bounds (such as data from the UM) nx, ny = 36, 18 xbnds, xstep = np.linspace(0, 360, nx - 1, retstep=True, endpoint=True) @@ -337,6 +378,8 @@ def test_pcolormesh_global_with_wrap2(): data = data[:-1, :-1] fig = plt.figure() + data = _to_rgb(data, mesh_data_kind) + ax = fig.add_subplot(2, 1, 1, projection=ccrs.PlateCarree()) ax.pcolormesh(xbnds, ybnds, data, transform=ccrs.PlateCarree(), snap=False) ax.coastlines() @@ -350,10 +393,11 @@ def test_pcolormesh_global_with_wrap2(): return fig +@PARAMETRIZE_PCOLORMESH_WRAP @pytest.mark.natural_earth @pytest.mark.mpl_image_compare(filename='pcolormesh_global_wrap3.png', tolerance=1.42) -def test_pcolormesh_global_with_wrap3(): +def test_pcolormesh_global_with_wrap3(mesh_data_kind): nx, ny = 33, 17 xbnds = np.linspace(-1.875, 358.125, nx, endpoint=True) ybnds = np.linspace(91.25, -91.25, ny, endpoint=True) @@ -371,6 +415,8 @@ def test_pcolormesh_global_with_wrap3(): data = np.ma.masked_greater(data, 2.6) fig = plt.figure() + data = _to_rgb(data, mesh_data_kind) + ax = fig.add_subplot(3, 1, 1, projection=ccrs.PlateCarree(-45)) c = ax.pcolormesh(xbnds, ybnds, data, transform=ccrs.PlateCarree(), snap=False) @@ -393,10 +439,11 @@ def test_pcolormesh_global_with_wrap3(): return fig +@PARAMETRIZE_PCOLORMESH_WRAP @pytest.mark.natural_earth @pytest.mark.mpl_image_compare(filename='pcolormesh_global_wrap3.png', tolerance=1.42) -def test_pcolormesh_set_array_with_mask(): +def test_pcolormesh_set_array_with_mask(mesh_data_kind): """Testing that set_array works with masked arrays properly.""" nx, ny = 33, 17 xbnds = np.linspace(-1.875, 358.125, nx, endpoint=True) @@ -419,10 +466,15 @@ def test_pcolormesh_set_array_with_mask(): bad_data_mask = np.ma.array(bad_data, mask=~data.mask) fig = plt.figure() + data = _to_rgb(data, mesh_data_kind) + bad_data = _to_rgb(bad_data, mesh_data_kind) + bad_data_mask = _to_rgb(bad_data_mask, mesh_data_kind) + ax = fig.add_subplot(3, 1, 1, projection=ccrs.PlateCarree(-45)) c = ax.pcolormesh(xbnds, ybnds, bad_data, norm=norm, transform=ccrs.PlateCarree(), snap=False) - c.set_array(data.ravel()) + + c.set_array(data) assert c._wrapped_collection_fix is not None, \ 'No pcolormesh wrapping was done when it should have been.' @@ -432,7 +484,10 @@ def test_pcolormesh_set_array_with_mask(): ax = fig.add_subplot(3, 1, 2, projection=ccrs.PlateCarree(-1.87499952)) c2 = ax.pcolormesh(xbnds, ybnds, bad_data_mask, norm=norm, transform=ccrs.PlateCarree(), snap=False) - c2.set_array(data.ravel()) + if mesh_data_kind == 'standard': + c2.set_array(data.ravel()) + else: + c2.set_array(data) ax.coastlines() ax.set_global() # make sure everything is visible @@ -578,6 +633,8 @@ def test_pcolormesh_diagonal_wrap(): assert hasattr(mesh, "_wrapped_collection_fix") +@pytest.mark.skipif(MPL_VERSION.release[:2] >= (3, 8), + reason='redundant from mpl v3.8') def test_pcolormesh_nan_wrap(): # Check that data with nan's as input still creates # the proper number of pcolor cells and those aren't @@ -622,14 +679,18 @@ def test_pcolormesh_mercator_wrap(): return ax.figure +@PARAMETRIZE_PCOLORMESH_WRAP @pytest.mark.natural_earth @pytest.mark.mpl_image_compare(filename='pcolormesh_mercator_wrap.png') -def test_pcolormesh_wrap_set_array(): +def test_pcolormesh_wrap_set_array(mesh_data_kind): x = np.linspace(0, 360, 73) y = np.linspace(-87.5, 87.5, 36) X, Y = np.meshgrid(*[np.deg2rad(c) for c in (x, y)]) Z = np.cos(Y) + 0.375 * np.sin(2. * X) Z = Z[:-1, :-1] + + Z = _to_rgb(Z, mesh_data_kind) + ax = plt.axes(projection=ccrs.Mercator()) norm = plt.Normalize(np.min(Z), np.max(Z)) ax.coastlines() @@ -637,7 +698,7 @@ def test_pcolormesh_wrap_set_array(): coll = ax.pcolormesh(x, y, np.ones(Z.shape), norm=norm, transform=ccrs.PlateCarree(), snap=False) # Now update the plot with the set_array method - coll.set_array(Z.ravel()) + coll.set_array(Z) return ax.figure