Skip to content

Commit

Permalink
MNT: changes required for greglucas/pcolor-2dmesh
Browse files Browse the repository at this point in the history
  • Loading branch information
rcomer committed Jun 5, 2023
1 parent 40f04d5 commit e170ecc
Show file tree
Hide file tree
Showing 2 changed files with 25 additions and 6 deletions.
13 changes: 9 additions & 4 deletions lib/cartopy/mpl/geoaxes.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,8 @@
from cartopy.mpl.slippy_image_artist import SlippyImageArtist


assert packaging.version.parse(mpl.__version__).release[:2] >= (3, 4), \
_MPL_VERSION = packaging.version.parse(mpl.__version__)
assert _MPL_VERSION.release >= (3, 4), \
'Cartopy is only supported with Matplotlib 3.4 or greater.'

# A nested mapping from path, source CRS, and target projection to the
Expand Down Expand Up @@ -1959,9 +1960,13 @@ def _wrap_quadmesh(self, collection, **kwargs):
# Now add back in the masked data if there was any
full_mask = ~mask if C_mask is None else ~mask | C_mask
pcolor_data = np.ma.array(C, mask=full_mask)
# The pcolor_col is now possibly shorter than the
# actual collection, so grab the masked cells
pcolor_col.set_array(pcolor_data[mask].ravel())

if _MPL_VERSION.release[:2] < (3, 8):
# The pcolor_col is now possibly shorter than the
# actual collection, so grab the masked cells
pcolor_col.set_array(pcolor_data[mask].ravel())
else:
pcolor_col.set_array(pcolor_data)

pcolor_col.set_cmap(cmap)
pcolor_col.set_norm(norm)
Expand Down
18 changes: 16 additions & 2 deletions lib/cartopy/mpl/geocollection.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,8 +3,13 @@
# This file is part of Cartopy and is released under the LGPL license.
# See COPYING and COPYING.LESSER in the root of the repository for full
# licensing details.
import matplotlib as mpl
from matplotlib.collections import QuadMesh
import numpy as np
import packaging


_MPL_VERSION = packaging.version.parse(mpl.__version__)


class GeoQuadMesh(QuadMesh):
Expand All @@ -21,7 +26,8 @@ def get_array(self):
A = super().get_array().copy()
# If the input array has a mask, retrieve the associated data
if hasattr(self, '_wrapped_mask'):
A[self._wrapped_mask] = self._wrapped_collection_fix.get_array()
A[self._wrapped_mask] = np.ma.compressed(
self._wrapped_collection_fix.get_array())
return A

def set_array(self, A):
Expand All @@ -34,7 +40,15 @@ def set_array(self, A):
# Only use the mask attribute if it is there.
if hasattr(self, '_wrapped_mask'):
# Update the pcolor data with the wrapped masked data
self._wrapped_collection_fix.set_array(A[self._wrapped_mask])
if _MPL_VERSION.release < (3, 8):
self._wrapped_collection_fix.set_array(A[self._wrapped_mask])
else:
A_mask = getattr(A, 'mask', None)
full_mask = ~self._wrapped_mask
if A_mask is not None:
full_mask |= A_mask
pcolor_data = np.ma.array(A, mask=full_mask)
self._wrapped_collection_fix.set_array(pcolor_data)
# If the input array was a masked array, keep that data masked
if hasattr(A, 'mask'):
A = np.ma.array(A, mask=self._wrapped_mask | A.mask)
Expand Down

0 comments on commit e170ecc

Please sign in to comment.