Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

ENH: enable passing RGB(A) to polormesh #2166

Merged
merged 5 commits into from
Jul 14, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
70 changes: 39 additions & 31 deletions lib/cartopy/mpl/geoaxes.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,8 @@
from cartopy.mpl.slippy_image_artist import SlippyImageArtist


assert packaging.version.parse(mpl.__version__).release[:2] >= (3, 4), \
_MPL_VERSION = packaging.version.parse(mpl.__version__)
assert _MPL_VERSION.release >= (3, 4), \
'Cartopy is only supported with Matplotlib 3.4 or greater.'

# A nested mapping from path, source CRS, and target projection to the
Expand Down Expand Up @@ -1796,7 +1797,7 @@ def _wrap_args(self, *args, **kwargs):
kwargs['shading'] = 'flat'
X = np.asanyarray(args[0])
Y = np.asanyarray(args[1])
nrows, ncols = np.asanyarray(args[2]).shape
nrows, ncols = np.asanyarray(args[2]).shape[:2]
Nx = X.shape[-1]
Ny = Y.shape[0]
if X.ndim != 2 or X.shape[0] == 1:
Expand Down Expand Up @@ -1843,12 +1844,13 @@ def _wrap_quadmesh(self, collection, **kwargs):
Ny, Nx, _ = coords.shape
if kwargs.get('shading') == 'gouraud':
# Gouraud shading has the same shape for coords and data
data_shape = Ny, Nx
data_shape = Ny, Nx, -1
else:
data_shape = Ny - 1, Nx - 1
data_shape = Ny - 1, Nx - 1, -1
# data array
C = collection.get_array().reshape(data_shape)

if C.shape[-1] == 1:
C = C.squeeze(axis=-1)
transformed_pts = self.projection.transform_points(
t, coords[..., 0], coords[..., 1])

Expand Down Expand Up @@ -1921,13 +1923,12 @@ def _wrap_quadmesh(self, collection, **kwargs):
"map it must be fully transparent.",
stacklevel=3)

# The original data mask (regardless of wrapped cells)
C_mask = getattr(C, 'mask', None)
# Get hold of masked versions of the array to be passed to set_array
# methods of QuadMesh and PolyQuadMesh
pcolormesh_data, pcolor_data, pcolor_mask = \
cartopy.mpl.geocollection._split_wrapped_mesh_data(C, mask)

# create the masked array to be used with this pcolormesh
full_mask = mask if C_mask is None else mask | C_mask
pcolormesh_data = np.ma.array(C, mask=full_mask)
collection.set_array(pcolormesh_data.ravel())
collection.set_array(pcolormesh_data)

# plot with slightly lower zorder to avoid odd issue
# where the main plot is obscured
Expand All @@ -1943,25 +1944,32 @@ def _wrap_quadmesh(self, collection, **kwargs):
# `pcolor` only draws polygons where the data is not
# masked, so this will only draw a limited subset of
# polygons that were actually wrapped.
# We will add the original data mask in later to
# make sure that set_array can work in future
# calls on the proper sized array inputs.
# NOTE: we don't use C.data here because C.data could
# contain nan's which would be masked in the
# pcolor routines, which we don't want. We will
# fill in the proper data later with set_array()
# calls.
pcolor_data = np.ma.array(np.zeros(C.shape),
mask=~mask)
pcolor_col = self.pcolor(coords[..., 0], coords[..., 1],
pcolor_data, zorder=zorder,
**kwargs)
# Now add back in the masked data if there was any
full_mask = ~mask if C_mask is None else ~mask | C_mask
pcolor_data = np.ma.array(C, mask=full_mask)
# The pcolor_col is now possibly shorter than the
# actual collection, so grab the masked cells
pcolor_col.set_array(pcolor_data[mask].ravel())

if _MPL_VERSION.release[:2] < (3, 8):
# We will add the original data mask in later to
# make sure that set_array can work in future
# calls on the proper sized array inputs.
# NOTE: we don't use C.data here because C.data could
# contain nan's which would be masked in the
# pcolor routines, which we don't want. We will
# fill in the proper data later with set_array()
# calls.
pcolor_zeros = np.ma.array(np.zeros(C.shape), mask=pcolor_mask)
pcolor_col = self.pcolor(coords[..., 0], coords[..., 1],
pcolor_zeros, zorder=zorder,
**kwargs)

# The pcolor_col is now possibly shorter than the
# actual collection, so grab the masked cells
pcolor_col.set_array(pcolor_data[mask].ravel())
else:
pcolor_col = self.pcolor(coords[..., 0], coords[..., 1],
pcolor_data, zorder=zorder,
**kwargs)
# Currently pcolor_col.get_array() will return a compressed array
# and warn unless we explicitly set the 2D array. This should be
# unnecessary with future matplotlib versions.
pcolor_col.set_array(pcolor_data)

pcolor_col.set_cmap(cmap)
pcolor_col.set_norm(norm)
Expand All @@ -1972,7 +1980,7 @@ def _wrap_quadmesh(self, collection, **kwargs):
# put the pcolor_col and mask on the pcolormesh
# collection so that users can do things post
# this method
collection._wrapped_mask = mask.ravel()
collection._wrapped_mask = mask
collection._wrapped_collection_fix = pcolor_col

return collection
Expand Down
89 changes: 78 additions & 11 deletions lib/cartopy/mpl/geocollection.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,8 +3,43 @@
# This file is part of Cartopy and is released under the LGPL license.
# See COPYING and COPYING.LESSER in the root of the repository for full
# licensing details.
import matplotlib as mpl
from matplotlib.collections import QuadMesh
import numpy as np
import numpy.ma as ma
import packaging


_MPL_VERSION = packaging.version.parse(mpl.__version__)


def _split_wrapped_mesh_data(C, mask):
"""
Helper function for splitting GeoQuadMesh array values between the
pcolormesh and pcolor objects when wrapping. Apply a mask to the grid
cells that should not be plotted with each method.

"""
# The original data mask (regardless of wrapped cells)
C_mask = getattr(C, 'mask', None)
if C.ndim == 3:
# RGB(A) array.
if _MPL_VERSION.release < (3, 8):
raise ValueError("GeoQuadMesh wrapping for RGB(A) requires "
"Matplotlib v3.8 or later")

# mask will need an extra trailing dimension
mask = np.broadcast_to(mask[..., np.newaxis], C.shape)

# create the masked array to be used with pcolormesh
full_mask = mask if C_mask is None else mask | C_mask
pcolormesh_data = ma.array(C, mask=full_mask)

# create the masked array to be used with pcolor
full_mask = ~mask if C_mask is None else ~mask | C_mask
pcolor_data = ma.array(C, mask=full_mask)

return pcolormesh_data, pcolor_data, ~mask


class GeoQuadMesh(QuadMesh):
Expand All @@ -21,25 +56,57 @@ def get_array(self):
A = super().get_array().copy()
# If the input array has a mask, retrieve the associated data
if hasattr(self, '_wrapped_mask'):
A[self._wrapped_mask] = self._wrapped_collection_fix.get_array()
pcolor_data = self._wrapped_collection_fix.get_array()
mask = self._wrapped_mask
if _MPL_VERSION.release[:2] < (3, 8):
A[mask] = pcolor_data
else:
if A.ndim == 3: # RGB(A) data. Need to broadcast mask.
mask = mask[:, :, np.newaxis]
# np.copyto is not implemented for masked arrays so handle the
# mask explicitly
np.copyto(A.mask, pcolor_data.mask, where=mask)
np.copyto(A, pcolor_data, where=mask)

return A

def set_array(self, A):
# raise right away if A is 2-dimensional.
if A.ndim > 1:
raise ValueError('Collections can only map rank 1 arrays. '
'You likely want to call with a flattened array '
'using collection.set_array(A.ravel()) instead.')
# Check the shape is appropriate up front.
if _MPL_VERSION.release[:2] < (3, 8):
# Need to figure out existing shape from the coordinates.
height, width = self._coordinates.shape[0:-1]
if self._shading == 'flat':
h, w = height - 1, width - 1
else:
h, w = height, width
else:
h, w = super().get_array().shape[:2]

ok_shapes = [(h, w, 3), (h, w, 4), (h, w), (h * w,)]
if A.shape not in ok_shapes:
ok_shape_str = ' or '.join(map(str, ok_shapes))
raise ValueError(
f"A should have shape {ok_shape_str}, not {A.shape}")

if A.ndim == 1:
# Always use array with at least two dimensions. This is
# inconsistent with QuadMesh which stores whatever you give it, but
# for the wrapped case we need to match the 2D mask. Storing the
# 2D array also allows us to calculate ok_shapes on subsequent
# calls without using the private QuadMesh._shading attribute.
A = A.reshape((h, w))

# Only use the mask attribute if it is there.
if hasattr(self, '_wrapped_mask'):

# Update the pcolor data with the wrapped masked data
self._wrapped_collection_fix.set_array(A[self._wrapped_mask])
# If the input array was a masked array, keep that data masked
if hasattr(A, 'mask'):
A = np.ma.array(A, mask=self._wrapped_mask | A.mask)
A, pcolor_data, _ = _split_wrapped_mesh_data(A, self._wrapped_mask)

if _MPL_VERSION.release[:2] < (3, 8):
self._wrapped_collection_fix.set_array(
pcolor_data[self._wrapped_mask].ravel())
else:
A = np.ma.array(A, mask=self._wrapped_mask)
self._wrapped_collection_fix.set_array(pcolor_data)

# Now that we have prepared the collection data, call on
# through to the underlying implementation.
Expand Down
Loading
Loading