From 914c87e1eb316bb6bfd2b3387578c960d6a24b5c Mon Sep 17 00:00:00 2001 From: Greg Lucas Date: Sat, 2 Oct 2021 09:21:43 -0600 Subject: [PATCH] MNT: Update matplotlib.axes calls to super() Clean up some of the code to use super() instead of matpotlib.axes.xxx The quiver tests also needed to be updated to drop the "self" argument from the first location. --- lib/cartopy/mpl/geoaxes.py | 42 +++++++++++++--------------- lib/cartopy/mpl/geocollection.py | 4 +-- lib/cartopy/tests/mpl/test_quiver.py | 8 +++--- 3 files changed, 25 insertions(+), 29 deletions(-) diff --git a/lib/cartopy/mpl/geoaxes.py b/lib/cartopy/mpl/geoaxes.py index fea3c412a..4fc11b005 100644 --- a/lib/cartopy/mpl/geoaxes.py +++ b/lib/cartopy/mpl/geoaxes.py @@ -527,8 +527,7 @@ def get_tightbbox(self, renderer, *args, **kwargs): # Shared processing steps self._draw_preprocess(renderer) - return matplotlib.axes.Axes.get_tightbbox( - self, renderer, *args, **kwargs) + return super().get_tightbbox(renderer, *args, **kwargs) @matplotlib.artist.allow_rasterization def draw(self, renderer=None, **kwargs): @@ -555,10 +554,10 @@ def draw(self, renderer=None, **kwargs): **factory_kwargs) self._done_img_factory = True - return matplotlib.axes.Axes.draw(self, renderer=renderer, **kwargs) + return super().draw(renderer=renderer, **kwargs) def _update_title_position(self, renderer): - matplotlib.axes.Axes._update_title_position(self, renderer) + super()._update_title_position(renderer) if not self._gridliners: return @@ -597,7 +596,7 @@ def __str__(self): def cla(self): """Clear the current axes and adds boundary lines.""" - result = matplotlib.axes.Axes.cla(self) + result = super().cla() self.xaxis.set_visible(False) self.yaxis.set_visible(False) # Enable tight autoscaling. @@ -940,8 +939,7 @@ def autoscale_view(self, tight=None, scalex=True, scaley=True): See :meth:`~matplotlib.axes.Axes.imshow()` for more details. """ - matplotlib.axes.Axes.autoscale_view(self, tight=tight, - scalex=scalex, scaley=scaley) + super().autoscale_view(tight=tight, scalex=scalex, scaley=scaley) # Limit the resulting bounds to valid area. if scalex and self._autoscaleXon: bounds = self.get_xbound() @@ -1350,7 +1348,7 @@ def imshow(self, img, *args, **kwargs): if (transform is None or transform == self.transData or same_projection and inside_bounds): - result = matplotlib.axes.Axes.imshow(self, img, *args, **kwargs) + result = super().imshow(img, *args, **kwargs) else: extent = kwargs.pop('extent', None) img = np.asanyarray(img) @@ -1412,8 +1410,7 @@ def imshow(self, img, *args, **kwargs): if img.dtype.kind == 'u': img[:, :, 3] *= 255 - result = matplotlib.axes.Axes.imshow(self, img, *args, - extent=extent, **kwargs) + result = super().imshow(img, *args, extent=extent, **kwargs) return result @@ -1578,10 +1575,9 @@ def _gen_axes_patch(self): def _gen_axes_spines(self, locations=None, offset=0.0, units='inches'): # generate some axes spines, as some Axes super class machinery # requires them. Just make them invisible - spines = matplotlib.axes.Axes._gen_axes_spines(self, - locations=locations, - offset=offset, - units=units) + spines = super()._gen_axes_spines(locations=locations, + offset=offset, + units=units) for spine in spines.values(): spine.set_visible(False) @@ -1663,7 +1659,7 @@ def contour(self, *args, **kwargs): The default is False, to compute the contours in data-space. """ - result = matplotlib.axes.Axes.contour(self, *args, **kwargs) + result = super().contour(*args, **kwargs) # We need to compute the dataLim correctly for contours. bboxes = [col.get_datalim(self.transData) @@ -1711,7 +1707,7 @@ def contourf(self, *args, **kwargs): if not hasattr(sub_trans, 'force_path_ccw'): sub_trans.force_path_ccw = True - result = matplotlib.axes.Axes.contourf(self, *args, **kwargs) + result = super().contourf(*args, **kwargs) # We need to compute the dataLim correctly for contours. bboxes = [col.get_datalim(self.transData) @@ -1749,7 +1745,7 @@ def scatter(self, *args, **kwargs): 'geodetic, consider using the cyllindrical form ' '(PlateCarree or RotatedPole).') - result = matplotlib.axes.Axes.scatter(self, *args, **kwargs) + result = super().scatter(*args, **kwargs) self.autoscale_view() return result @@ -1776,7 +1772,7 @@ def hexbin(self, x, y, *args, **kwargs): x = pairs[:, 0] y = pairs[:, 1] - result = matplotlib.axes.Axes.hexbin(self, x, y, *args, **kwargs) + result = super().hexbin(x, y, *args, **kwargs) self.autoscale_view() return result @@ -1794,7 +1790,7 @@ def pcolormesh(self, *args, **kwargs): # Add in an argument checker to handle Matplotlib's potential # interpolation when coordinate wraps are involved args = self._wrap_args(*args, **kwargs) - result = matplotlib.axes.Axes.pcolormesh(self, *args, **kwargs) + result = super().pcolormesh(*args, **kwargs) # Wrap the quadrilaterals if necessary result = self._wrap_quadmesh(result, **kwargs) # Re-cast the QuadMesh as a GeoQuadMesh to enable future wrapping @@ -1979,7 +1975,7 @@ def pcolor(self, *args, **kwargs): # Add in an argument checker to handle Matplotlib's potential # interpolation when coordinate wraps are involved args = self._wrap_args(*args, **kwargs) - result = matplotlib.axes.Axes.pcolor(self, *args, **kwargs) + result = super().pcolor(*args, **kwargs) # Update the datalim for this pcolor. limits = result.get_datalim(self.transData) @@ -2061,7 +2057,7 @@ def quiver(self, x, y, u, v, *args, **kwargs): if (x.ndim == 1 and y.ndim == 1) and (x.shape != u.shape): x, y = np.meshgrid(x, y) u, v = self.projection.transform_vectors(t, x, y, u, v) - return matplotlib.axes.Axes.quiver(self, x, y, u, v, *args, **kwargs) + return super().quiver(x, y, u, v, *args, **kwargs) @_add_transform def barbs(self, x, y, u, v, *args, **kwargs): @@ -2136,7 +2132,7 @@ def barbs(self, x, y, u, v, *args, **kwargs): if (x.ndim == 1 and y.ndim == 1) and (x.shape != u.shape): x, y = np.meshgrid(x, y) u, v = self.projection.transform_vectors(t, x, y, u, v) - return matplotlib.axes.Axes.barbs(self, x, y, u, v, *args, **kwargs) + return super().barbs(x, y, u, v, *args, **kwargs) @_add_transform def streamplot(self, x, y, u, v, **kwargs): @@ -2213,7 +2209,7 @@ def streamplot(self, x, y, u, v, **kwargs): message = 'Warning: converting a masked element to nan.' warnings.filterwarnings('ignore', message=message, category=UserWarning) - sp = matplotlib.axes.Axes.streamplot(self, x, y, u, v, **kwargs) + sp = super().streamplot(x, y, u, v, **kwargs) return sp def add_wmts(self, wmts, layer_name, wmts_kwargs=None, **kwargs): diff --git a/lib/cartopy/mpl/geocollection.py b/lib/cartopy/mpl/geocollection.py index e66bc7d89..bcda9fca4 100644 --- a/lib/cartopy/mpl/geocollection.py +++ b/lib/cartopy/mpl/geocollection.py @@ -18,7 +18,7 @@ class GeoQuadMesh(QuadMesh): def get_array(self): # Retrieve the array - use copy to avoid any chance of overwrite - A = super(QuadMesh, self).get_array().copy() + A = super().get_array().copy() # If the input array has a mask, retrieve the associated data if hasattr(self, '_wrapped_mask'): A[self._wrapped_mask] = self._wrapped_collection_fix.get_array() @@ -43,7 +43,7 @@ def set_array(self, A): # Now that we have prepared the collection data, call on # through to the underlying implementation. - super(QuadMesh, self).set_array(A) + super().set_array(A) def set_clim(self, vmin=None, vmax=None): # Update _wrapped_collection_fix color limits if it is there. diff --git a/lib/cartopy/tests/mpl/test_quiver.py b/lib/cartopy/tests/mpl/test_quiver.py index 63ccce8f5..7234b32cd 100644 --- a/lib/cartopy/tests/mpl/test_quiver.py +++ b/lib/cartopy/tests/mpl/test_quiver.py @@ -34,9 +34,9 @@ def test_quiver_transform_xyuv_1d(self): self.ax.quiver(self.x2d.ravel(), self.y2d.ravel(), self.u.ravel(), self.v.ravel(), transform=self.rp) args, kwargs = patch.call_args - assert len(args) == 5 + assert len(args) == 4 assert sorted(kwargs.keys()) == ['transform'] - shapes = [arg.shape for arg in args[1:]] + shapes = [arg.shape for arg in args] # Assert that all the shapes have been broadcast. assert shapes == [(70, )] * 4 @@ -45,9 +45,9 @@ def test_quiver_transform_xy_1d_uv_2d(self): with mock.patch('matplotlib.axes.Axes.quiver') as patch: self.ax.quiver(self.x, self.y, self.u, self.v, transform=self.rp) args, kwargs = patch.call_args - assert len(args) == 5 + assert len(args) == 4 assert sorted(kwargs.keys()) == ['transform'] - shapes = [arg.shape for arg in args[1:]] + shapes = [arg.shape for arg in args] # Assert that all the shapes have been broadcast. assert shapes == [(7, 10)] * 4