Skip to content

Commit

Permalink
ENH: support RGB(A) in pcolormesh (wrapped case)
Browse files Browse the repository at this point in the history
  • Loading branch information
rcomer committed Jun 13, 2023
1 parent e170ecc commit 1597fa4
Show file tree
Hide file tree
Showing 3 changed files with 166 additions and 52 deletions.
48 changes: 25 additions & 23 deletions lib/cartopy/mpl/geoaxes.py
Original file line number Diff line number Diff line change
Expand Up @@ -1922,13 +1922,12 @@ def _wrap_quadmesh(self, collection, **kwargs):
"map it must be fully transparent.",
stacklevel=3)

# The original data mask (regardless of wrapped cells)
C_mask = getattr(C, 'mask', None)
# Get hold of masked versions of the array to be passed to set_array
# methods of QuadMesh and PolyQuadMesh
pcolormesh_data, pcolor_data, pcolor_mask = \
cartopy.mpl.geocollection._split_wrapped_mesh_data(C, mask)

# create the masked array to be used with this pcolormesh
full_mask = mask if C_mask is None else mask | C_mask
pcolormesh_data = np.ma.array(C, mask=full_mask)
collection.set_array(pcolormesh_data.ravel())
collection.set_array(pcolormesh_data)

# plot with slightly lower zorder to avoid odd issue
# where the main plot is obscured
Expand All @@ -1944,28 +1943,31 @@ def _wrap_quadmesh(self, collection, **kwargs):
# `pcolor` only draws polygons where the data is not
# masked, so this will only draw a limited subset of
# polygons that were actually wrapped.
# We will add the original data mask in later to
# make sure that set_array can work in future
# calls on the proper sized array inputs.
# NOTE: we don't use C.data here because C.data could
# contain nan's which would be masked in the
# pcolor routines, which we don't want. We will
# fill in the proper data later with set_array()
# calls.
pcolor_data = np.ma.array(np.zeros(C.shape),
mask=~mask)
pcolor_col = self.pcolor(coords[..., 0], coords[..., 1],
pcolor_data, zorder=zorder,
**kwargs)
# Now add back in the masked data if there was any
full_mask = ~mask if C_mask is None else ~mask | C_mask
pcolor_data = np.ma.array(C, mask=full_mask)

if _MPL_VERSION.release[:2] < (3, 8):
# We will add the original data mask in later to
# make sure that set_array can work in future
# calls on the proper sized array inputs.
# NOTE: we don't use C.data here because C.data could
# contain nan's which would be masked in the
# pcolor routines, which we don't want. We will
# fill in the proper data later with set_array()
# calls.
pcolor_zeros = np.ma.array(np.zeros(C.shape), mask=pcolor_mask)
pcolor_col = self.pcolor(coords[..., 0], coords[..., 1],
pcolor_zeros, zorder=zorder,
**kwargs)

# The pcolor_col is now possibly shorter than the
# actual collection, so grab the masked cells
pcolor_col.set_array(pcolor_data[mask].ravel())
else:
pcolor_col = self.pcolor(coords[..., 0], coords[..., 1],
pcolor_data, zorder=zorder,
**kwargs)
# Currently pcolor_col.get_array() will return a compressed array
# and warn unless we explicitly set the 2D array. This should be
# unnecessary with future matplotlib versions.
pcolor_col.set_array(pcolor_data)

pcolor_col.set_cmap(cmap)
Expand All @@ -1977,7 +1979,7 @@ def _wrap_quadmesh(self, collection, **kwargs):
# put the pcolor_col and mask on the pcolormesh
# collection so that users can do things post
# this method
collection._wrapped_mask = mask.ravel()
collection._wrapped_mask = mask
collection._wrapped_collection_fix = pcolor_col

return collection
Expand Down
89 changes: 70 additions & 19 deletions lib/cartopy/mpl/geocollection.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,12 +6,42 @@
import matplotlib as mpl
from matplotlib.collections import QuadMesh
import numpy as np
import numpy.ma as ma
import packaging


_MPL_VERSION = packaging.version.parse(mpl.__version__)


def _split_wrapped_mesh_data(C, mask):
"""
Helper function for splitting GeoQuadMesh array values between the
pcolormesh and pcolor objects when wrapping. Apply a mask to the grid
cells that should not be plotted with each method.
"""
# The original data mask (regardless of wrapped cells)
C_mask = getattr(C, 'mask', None)
if C.ndim == 3:
# RGB(A) array.
if _MPL_VERSION.release < (3, 8):
raise ValueError("GeoQuadMesh wrapping for RGB(A) requires "
"Matplotlib v3.8 or later")

# mask will need an extra trailing dimension
mask = np.broadcast_to(mask[..., np.newaxis], C.shape)

# create the masked array to be used with pcolormesh
full_mask = mask if C_mask is None else mask | C_mask
pcolormesh_data = ma.array(C, mask=full_mask)

# create the masked array to be used with pcolor
full_mask = ~mask if C_mask is None else ~mask | C_mask
pcolor_data = ma.array(C, mask=full_mask)

return pcolormesh_data, pcolor_data, ~mask


class GeoQuadMesh(QuadMesh):
"""
A QuadMesh designed to help handle the case when the mesh is wrapped.
Expand All @@ -26,34 +56,55 @@ def get_array(self):
A = super().get_array().copy()
# If the input array has a mask, retrieve the associated data
if hasattr(self, '_wrapped_mask'):
A[self._wrapped_mask] = np.ma.compressed(
self._wrapped_collection_fix.get_array())
pcolor_data = self._wrapped_collection_fix.get_array()
mask = self._wrapped_mask
if _MPL_VERSION.release[:2] < (3, 8):
A[mask] = pcolor_data
else:
# np.copyto is not implemented for masked arrays so handle the
# mask explicitly
np.copyto(A.mask, pcolor_data.mask, where=mask)
np.copyto(A, pcolor_data, where=mask)

return A

def set_array(self, A):
# raise right away if A is 2-dimensional.
if A.ndim > 1:
raise ValueError('Collections can only map rank 1 arrays. '
'You likely want to call with a flattened array '
'using collection.set_array(A.ravel()) instead.')
# Check the shape is appropriate up front.
if _MPL_VERSION.release[:2] < (3, 8):
# Need to figure out existing shape from the coordinates.
height, width = self._coordinates.shape[0:-1]
if self._shading == 'flat':
h, w = height - 1, width - 1
else:
h, w = height, width
else:
h, w = super().get_array().shape[:2]

ok_shapes = [(h, w, 3), (h, w, 4), (h, w), (h * w,)]
if A.shape not in ok_shapes:
ok_shape_str = ' or '.join(map(str, ok_shapes))
raise ValueError(
f"A should have shape {ok_shape_str}, not {A.shape}")

if A.ndim == 1:
# Always use array with at least two dimensions. This is
# inconsistent with QuadMesh which stores whatever you give it, but
# for the wrapped case we need to match the 2D mask. Storing the
# 2D array also allows us to calculate ok_shapes on subsequent
# calls without using the private QuadMesh._shading attribute.
A = A.reshape((h, w))

# Only use the mask attribute if it is there.
if hasattr(self, '_wrapped_mask'):

# Update the pcolor data with the wrapped masked data
if _MPL_VERSION.release < (3, 8):
self._wrapped_collection_fix.set_array(A[self._wrapped_mask])
A, pcolor_data, _ = _split_wrapped_mesh_data(A, self._wrapped_mask)

if _MPL_VERSION.release[:2] < (3, 8):
self._wrapped_collection_fix.set_array(
pcolor_data[self._wrapped_mask].ravel())
else:
A_mask = getattr(A, 'mask', None)
full_mask = ~self._wrapped_mask
if A_mask is not None:
full_mask |= A_mask
pcolor_data = np.ma.array(A, mask=full_mask)
self._wrapped_collection_fix.set_array(pcolor_data)
# If the input array was a masked array, keep that data masked
if hasattr(A, 'mask'):
A = np.ma.array(A, mask=self._wrapped_mask | A.mask)
else:
A = np.ma.array(A, mask=self._wrapped_mask)

# Now that we have prepared the collection data, call on
# through to the underlying implementation.
Expand Down
81 changes: 71 additions & 10 deletions lib/cartopy/tests/mpl/test_mpl_integration.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@

import re

import matplotlib.colors as mcolors
import matplotlib.pyplot as plt
import numpy as np
import pytest
Expand Down Expand Up @@ -240,10 +241,46 @@ def test_cursor_values():
r.encode('ascii', 'ignore'))


SKIP_PRE_MPL38 = pytest.mark.skipif(
MPL_VERSION.release[:2] < (3, 8), reason='mpl < 3.8')
PARAMETRIZE_PCOLORMESH_WRAP = pytest.mark.parametrize(
'mesh_data_kind',
[
'standard',
pytest.param('rgb', marks=SKIP_PRE_MPL38),
pytest.param('rgba', marks=SKIP_PRE_MPL38),
],
ids=['standard', 'rgb', 'rgba'],
)


def _to_rgb(data, mesh_data_kind):
"""
Helper function to convert array to RGB(A) where required
"""
if mesh_data_kind in ('rgb', 'rgba'):
cmap = plt.get_cmap()
norm = mcolors.Normalize()
new_data = cmap(norm(data))
if mesh_data_kind == 'rgb':
new_data = new_data[..., 0:3]
if np.ma.is_masked(data):
# Use data's mask as an alpha channel
mask = np.ma.getmaskarray(data)
mask = np.broadcast_to(
mask[..., np.newaxis], new_data.shape).copy()
new_data = np.ma.array(new_data, mask=mask)

return new_data

return data


@PARAMETRIZE_PCOLORMESH_WRAP
@pytest.mark.natural_earth
@pytest.mark.mpl_image_compare(filename='pcolormesh_global_wrap1.png',
tolerance=1.27)
def test_pcolormesh_global_with_wrap1():
def test_pcolormesh_global_with_wrap1(mesh_data_kind):
# make up some realistic data with bounds (such as data from the UM)
nx, ny = 36, 18
xbnds = np.linspace(0, 360, nx, endpoint=True)
Expand All @@ -254,6 +291,8 @@ def test_pcolormesh_global_with_wrap1():
data = data[:-1, :-1]
fig = plt.figure()

data = _to_rgb(data, mesh_data_kind)

ax = fig.add_subplot(2, 1, 1, projection=ccrs.PlateCarree())
ax.pcolormesh(xbnds, ybnds, data, transform=ccrs.PlateCarree(), snap=False)
ax.coastlines()
Expand Down Expand Up @@ -285,8 +324,9 @@ def test_pcolormesh_get_array_with_mask():

result = c.get_array()
assert not np.ma.is_masked(result)
assert np.array_equal(data.ravel(), result), \
'Data supplied does not match data retrieved in wrapped case'
np.testing.assert_array_equal(
data, result,
err_msg='Data supplied does not match data retrieved in wrapped case')

ax.coastlines()
ax.set_global() # make sure everything is visible
Expand Down Expand Up @@ -319,10 +359,11 @@ def test_pcolormesh_get_array_with_mask():
'Data supplied does not match data retrieved in unwrapped case'


@PARAMETRIZE_PCOLORMESH_WRAP
@pytest.mark.natural_earth
@pytest.mark.mpl_image_compare(filename='pcolormesh_global_wrap2.png',
tolerance=1.87)
def test_pcolormesh_global_with_wrap2():
def test_pcolormesh_global_with_wrap2(mesh_data_kind):
# make up some realistic data with bounds (such as data from the UM)
nx, ny = 36, 18
xbnds, xstep = np.linspace(0, 360, nx - 1, retstep=True, endpoint=True)
Expand All @@ -337,6 +378,8 @@ def test_pcolormesh_global_with_wrap2():
data = data[:-1, :-1]
fig = plt.figure()

data = _to_rgb(data, mesh_data_kind)

ax = fig.add_subplot(2, 1, 1, projection=ccrs.PlateCarree())
ax.pcolormesh(xbnds, ybnds, data, transform=ccrs.PlateCarree(), snap=False)
ax.coastlines()
Expand All @@ -350,10 +393,11 @@ def test_pcolormesh_global_with_wrap2():
return fig


@PARAMETRIZE_PCOLORMESH_WRAP
@pytest.mark.natural_earth
@pytest.mark.mpl_image_compare(filename='pcolormesh_global_wrap3.png',
tolerance=1.42)
def test_pcolormesh_global_with_wrap3():
def test_pcolormesh_global_with_wrap3(mesh_data_kind):
nx, ny = 33, 17
xbnds = np.linspace(-1.875, 358.125, nx, endpoint=True)
ybnds = np.linspace(91.25, -91.25, ny, endpoint=True)
Expand All @@ -371,6 +415,8 @@ def test_pcolormesh_global_with_wrap3():
data = np.ma.masked_greater(data, 2.6)
fig = plt.figure()

data = _to_rgb(data, mesh_data_kind)

ax = fig.add_subplot(3, 1, 1, projection=ccrs.PlateCarree(-45))
c = ax.pcolormesh(xbnds, ybnds, data, transform=ccrs.PlateCarree(),
snap=False)
Expand All @@ -393,10 +439,11 @@ def test_pcolormesh_global_with_wrap3():
return fig


@PARAMETRIZE_PCOLORMESH_WRAP
@pytest.mark.natural_earth
@pytest.mark.mpl_image_compare(filename='pcolormesh_global_wrap3.png',
tolerance=1.42)
def test_pcolormesh_set_array_with_mask():
def test_pcolormesh_set_array_with_mask(mesh_data_kind):
"""Testing that set_array works with masked arrays properly."""
nx, ny = 33, 17
xbnds = np.linspace(-1.875, 358.125, nx, endpoint=True)
Expand All @@ -419,10 +466,15 @@ def test_pcolormesh_set_array_with_mask():
bad_data_mask = np.ma.array(bad_data, mask=~data.mask)
fig = plt.figure()

data = _to_rgb(data, mesh_data_kind)
bad_data = _to_rgb(bad_data, mesh_data_kind)
bad_data_mask = _to_rgb(bad_data_mask, mesh_data_kind)

ax = fig.add_subplot(3, 1, 1, projection=ccrs.PlateCarree(-45))
c = ax.pcolormesh(xbnds, ybnds, bad_data,
norm=norm, transform=ccrs.PlateCarree(), snap=False)
c.set_array(data.ravel())

c.set_array(data)
assert c._wrapped_collection_fix is not None, \
'No pcolormesh wrapping was done when it should have been.'

Expand All @@ -432,7 +484,10 @@ def test_pcolormesh_set_array_with_mask():
ax = fig.add_subplot(3, 1, 2, projection=ccrs.PlateCarree(-1.87499952))
c2 = ax.pcolormesh(xbnds, ybnds, bad_data_mask,
norm=norm, transform=ccrs.PlateCarree(), snap=False)
c2.set_array(data.ravel())
if mesh_data_kind == 'standard':
c2.set_array(data.ravel())
else:
c2.set_array(data)
ax.coastlines()
ax.set_global() # make sure everything is visible

Expand Down Expand Up @@ -578,6 +633,8 @@ def test_pcolormesh_diagonal_wrap():
assert hasattr(mesh, "_wrapped_collection_fix")


@pytest.mark.skipif(MPL_VERSION.release[:2] >= (3, 8),
reason='redundant from mpl v3.8')
def test_pcolormesh_nan_wrap():
# Check that data with nan's as input still creates
# the proper number of pcolor cells and those aren't
Expand Down Expand Up @@ -622,22 +679,26 @@ def test_pcolormesh_mercator_wrap():
return ax.figure


@PARAMETRIZE_PCOLORMESH_WRAP
@pytest.mark.natural_earth
@pytest.mark.mpl_image_compare(filename='pcolormesh_mercator_wrap.png')
def test_pcolormesh_wrap_set_array():
def test_pcolormesh_wrap_set_array(mesh_data_kind):
x = np.linspace(0, 360, 73)
y = np.linspace(-87.5, 87.5, 36)
X, Y = np.meshgrid(*[np.deg2rad(c) for c in (x, y)])
Z = np.cos(Y) + 0.375 * np.sin(2. * X)
Z = Z[:-1, :-1]

Z = _to_rgb(Z, mesh_data_kind)

ax = plt.axes(projection=ccrs.Mercator())
norm = plt.Normalize(np.min(Z), np.max(Z))
ax.coastlines()
# Start off with bad data
coll = ax.pcolormesh(x, y, np.ones(Z.shape), norm=norm,
transform=ccrs.PlateCarree(), snap=False)
# Now update the plot with the set_array method
coll.set_array(Z.ravel())
coll.set_array(Z)
return ax.figure


Expand Down

0 comments on commit 1597fa4

Please sign in to comment.