Skip to content

Commit

Permalink
MNT: Update matplotlib.axes calls to super()
Browse files Browse the repository at this point in the history
Clean up some of the code to use super() instead of matpotlib.axes.xxx

The quiver tests also needed to be updated to drop the "self" argument
from the first location.
  • Loading branch information
greglucas committed Oct 2, 2021
1 parent b4368b2 commit 914c87e
Show file tree
Hide file tree
Showing 3 changed files with 25 additions and 29 deletions.
42 changes: 19 additions & 23 deletions lib/cartopy/mpl/geoaxes.py
Original file line number Diff line number Diff line change
Expand Up @@ -527,8 +527,7 @@ def get_tightbbox(self, renderer, *args, **kwargs):
# Shared processing steps
self._draw_preprocess(renderer)

return matplotlib.axes.Axes.get_tightbbox(
self, renderer, *args, **kwargs)
return super().get_tightbbox(renderer, *args, **kwargs)

@matplotlib.artist.allow_rasterization
def draw(self, renderer=None, **kwargs):
Expand All @@ -555,10 +554,10 @@ def draw(self, renderer=None, **kwargs):
**factory_kwargs)
self._done_img_factory = True

return matplotlib.axes.Axes.draw(self, renderer=renderer, **kwargs)
return super().draw(renderer=renderer, **kwargs)

def _update_title_position(self, renderer):
matplotlib.axes.Axes._update_title_position(self, renderer)
super()._update_title_position(renderer)
if not self._gridliners:
return

Expand Down Expand Up @@ -597,7 +596,7 @@ def __str__(self):

def cla(self):
"""Clear the current axes and adds boundary lines."""
result = matplotlib.axes.Axes.cla(self)
result = super().cla()
self.xaxis.set_visible(False)
self.yaxis.set_visible(False)
# Enable tight autoscaling.
Expand Down Expand Up @@ -940,8 +939,7 @@ def autoscale_view(self, tight=None, scalex=True, scaley=True):
See :meth:`~matplotlib.axes.Axes.imshow()` for more details.
"""
matplotlib.axes.Axes.autoscale_view(self, tight=tight,
scalex=scalex, scaley=scaley)
super().autoscale_view(tight=tight, scalex=scalex, scaley=scaley)
# Limit the resulting bounds to valid area.
if scalex and self._autoscaleXon:
bounds = self.get_xbound()
Expand Down Expand Up @@ -1350,7 +1348,7 @@ def imshow(self, img, *args, **kwargs):

if (transform is None or transform == self.transData or
same_projection and inside_bounds):
result = matplotlib.axes.Axes.imshow(self, img, *args, **kwargs)
result = super().imshow(img, *args, **kwargs)
else:
extent = kwargs.pop('extent', None)
img = np.asanyarray(img)
Expand Down Expand Up @@ -1412,8 +1410,7 @@ def imshow(self, img, *args, **kwargs):
if img.dtype.kind == 'u':
img[:, :, 3] *= 255

result = matplotlib.axes.Axes.imshow(self, img, *args,
extent=extent, **kwargs)
result = super().imshow(img, *args, extent=extent, **kwargs)

return result

Expand Down Expand Up @@ -1578,10 +1575,9 @@ def _gen_axes_patch(self):
def _gen_axes_spines(self, locations=None, offset=0.0, units='inches'):
# generate some axes spines, as some Axes super class machinery
# requires them. Just make them invisible
spines = matplotlib.axes.Axes._gen_axes_spines(self,
locations=locations,
offset=offset,
units=units)
spines = super()._gen_axes_spines(locations=locations,
offset=offset,
units=units)
for spine in spines.values():
spine.set_visible(False)

Expand Down Expand Up @@ -1663,7 +1659,7 @@ def contour(self, *args, **kwargs):
The default is False, to compute the contours in data-space.
"""
result = matplotlib.axes.Axes.contour(self, *args, **kwargs)
result = super().contour(*args, **kwargs)

# We need to compute the dataLim correctly for contours.
bboxes = [col.get_datalim(self.transData)
Expand Down Expand Up @@ -1711,7 +1707,7 @@ def contourf(self, *args, **kwargs):
if not hasattr(sub_trans, 'force_path_ccw'):
sub_trans.force_path_ccw = True

result = matplotlib.axes.Axes.contourf(self, *args, **kwargs)
result = super().contourf(*args, **kwargs)

# We need to compute the dataLim correctly for contours.
bboxes = [col.get_datalim(self.transData)
Expand Down Expand Up @@ -1749,7 +1745,7 @@ def scatter(self, *args, **kwargs):
'geodetic, consider using the cyllindrical form '
'(PlateCarree or RotatedPole).')

result = matplotlib.axes.Axes.scatter(self, *args, **kwargs)
result = super().scatter(*args, **kwargs)
self.autoscale_view()
return result

Expand All @@ -1776,7 +1772,7 @@ def hexbin(self, x, y, *args, **kwargs):
x = pairs[:, 0]
y = pairs[:, 1]

result = matplotlib.axes.Axes.hexbin(self, x, y, *args, **kwargs)
result = super().hexbin(x, y, *args, **kwargs)
self.autoscale_view()
return result

Expand All @@ -1794,7 +1790,7 @@ def pcolormesh(self, *args, **kwargs):
# Add in an argument checker to handle Matplotlib's potential
# interpolation when coordinate wraps are involved
args = self._wrap_args(*args, **kwargs)
result = matplotlib.axes.Axes.pcolormesh(self, *args, **kwargs)
result = super().pcolormesh(*args, **kwargs)
# Wrap the quadrilaterals if necessary
result = self._wrap_quadmesh(result, **kwargs)
# Re-cast the QuadMesh as a GeoQuadMesh to enable future wrapping
Expand Down Expand Up @@ -1979,7 +1975,7 @@ def pcolor(self, *args, **kwargs):
# Add in an argument checker to handle Matplotlib's potential
# interpolation when coordinate wraps are involved
args = self._wrap_args(*args, **kwargs)
result = matplotlib.axes.Axes.pcolor(self, *args, **kwargs)
result = super().pcolor(*args, **kwargs)

# Update the datalim for this pcolor.
limits = result.get_datalim(self.transData)
Expand Down Expand Up @@ -2061,7 +2057,7 @@ def quiver(self, x, y, u, v, *args, **kwargs):
if (x.ndim == 1 and y.ndim == 1) and (x.shape != u.shape):
x, y = np.meshgrid(x, y)
u, v = self.projection.transform_vectors(t, x, y, u, v)
return matplotlib.axes.Axes.quiver(self, x, y, u, v, *args, **kwargs)
return super().quiver(x, y, u, v, *args, **kwargs)

@_add_transform
def barbs(self, x, y, u, v, *args, **kwargs):
Expand Down Expand Up @@ -2136,7 +2132,7 @@ def barbs(self, x, y, u, v, *args, **kwargs):
if (x.ndim == 1 and y.ndim == 1) and (x.shape != u.shape):
x, y = np.meshgrid(x, y)
u, v = self.projection.transform_vectors(t, x, y, u, v)
return matplotlib.axes.Axes.barbs(self, x, y, u, v, *args, **kwargs)
return super().barbs(x, y, u, v, *args, **kwargs)

@_add_transform
def streamplot(self, x, y, u, v, **kwargs):
Expand Down Expand Up @@ -2213,7 +2209,7 @@ def streamplot(self, x, y, u, v, **kwargs):
message = 'Warning: converting a masked element to nan.'
warnings.filterwarnings('ignore', message=message,
category=UserWarning)
sp = matplotlib.axes.Axes.streamplot(self, x, y, u, v, **kwargs)
sp = super().streamplot(x, y, u, v, **kwargs)
return sp

def add_wmts(self, wmts, layer_name, wmts_kwargs=None, **kwargs):
Expand Down
4 changes: 2 additions & 2 deletions lib/cartopy/mpl/geocollection.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@ class GeoQuadMesh(QuadMesh):

def get_array(self):
# Retrieve the array - use copy to avoid any chance of overwrite
A = super(QuadMesh, self).get_array().copy()
A = super().get_array().copy()
# If the input array has a mask, retrieve the associated data
if hasattr(self, '_wrapped_mask'):
A[self._wrapped_mask] = self._wrapped_collection_fix.get_array()
Expand All @@ -43,7 +43,7 @@ def set_array(self, A):

# Now that we have prepared the collection data, call on
# through to the underlying implementation.
super(QuadMesh, self).set_array(A)
super().set_array(A)

def set_clim(self, vmin=None, vmax=None):
# Update _wrapped_collection_fix color limits if it is there.
Expand Down
8 changes: 4 additions & 4 deletions lib/cartopy/tests/mpl/test_quiver.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,9 +34,9 @@ def test_quiver_transform_xyuv_1d(self):
self.ax.quiver(self.x2d.ravel(), self.y2d.ravel(),
self.u.ravel(), self.v.ravel(), transform=self.rp)
args, kwargs = patch.call_args
assert len(args) == 5
assert len(args) == 4
assert sorted(kwargs.keys()) == ['transform']
shapes = [arg.shape for arg in args[1:]]
shapes = [arg.shape for arg in args]
# Assert that all the shapes have been broadcast.
assert shapes == [(70, )] * 4

Expand All @@ -45,9 +45,9 @@ def test_quiver_transform_xy_1d_uv_2d(self):
with mock.patch('matplotlib.axes.Axes.quiver') as patch:
self.ax.quiver(self.x, self.y, self.u, self.v, transform=self.rp)
args, kwargs = patch.call_args
assert len(args) == 5
assert len(args) == 4
assert sorted(kwargs.keys()) == ['transform']
shapes = [arg.shape for arg in args[1:]]
shapes = [arg.shape for arg in args]
# Assert that all the shapes have been broadcast.
assert shapes == [(7, 10)] * 4

Expand Down

0 comments on commit 914c87e

Please sign in to comment.