Skip to content

Commit

Permalink
Merge pull request #1896 from greglucas/use-super
Browse files Browse the repository at this point in the history
MNT: Update matplotlib.axes calls to super()
  • Loading branch information
QuLogic authored Oct 2, 2021
2 parents b4368b2 + 914c87e commit f187916
Show file tree
Hide file tree
Showing 3 changed files with 25 additions and 29 deletions.
42 changes: 19 additions & 23 deletions lib/cartopy/mpl/geoaxes.py
Original file line number Diff line number Diff line change
Expand Up @@ -527,8 +527,7 @@ def get_tightbbox(self, renderer, *args, **kwargs):
# Shared processing steps
self._draw_preprocess(renderer)

return matplotlib.axes.Axes.get_tightbbox(
self, renderer, *args, **kwargs)
return super().get_tightbbox(renderer, *args, **kwargs)

@matplotlib.artist.allow_rasterization
def draw(self, renderer=None, **kwargs):
Expand All @@ -555,10 +554,10 @@ def draw(self, renderer=None, **kwargs):
**factory_kwargs)
self._done_img_factory = True

return matplotlib.axes.Axes.draw(self, renderer=renderer, **kwargs)
return super().draw(renderer=renderer, **kwargs)

def _update_title_position(self, renderer):
matplotlib.axes.Axes._update_title_position(self, renderer)
super()._update_title_position(renderer)
if not self._gridliners:
return

Expand Down Expand Up @@ -597,7 +596,7 @@ def __str__(self):

def cla(self):
"""Clear the current axes and adds boundary lines."""
result = matplotlib.axes.Axes.cla(self)
result = super().cla()
self.xaxis.set_visible(False)
self.yaxis.set_visible(False)
# Enable tight autoscaling.
Expand Down Expand Up @@ -940,8 +939,7 @@ def autoscale_view(self, tight=None, scalex=True, scaley=True):
See :meth:`~matplotlib.axes.Axes.imshow()` for more details.
"""
matplotlib.axes.Axes.autoscale_view(self, tight=tight,
scalex=scalex, scaley=scaley)
super().autoscale_view(tight=tight, scalex=scalex, scaley=scaley)
# Limit the resulting bounds to valid area.
if scalex and self._autoscaleXon:
bounds = self.get_xbound()
Expand Down Expand Up @@ -1350,7 +1348,7 @@ def imshow(self, img, *args, **kwargs):

if (transform is None or transform == self.transData or
same_projection and inside_bounds):
result = matplotlib.axes.Axes.imshow(self, img, *args, **kwargs)
result = super().imshow(img, *args, **kwargs)
else:
extent = kwargs.pop('extent', None)
img = np.asanyarray(img)
Expand Down Expand Up @@ -1412,8 +1410,7 @@ def imshow(self, img, *args, **kwargs):
if img.dtype.kind == 'u':
img[:, :, 3] *= 255

result = matplotlib.axes.Axes.imshow(self, img, *args,
extent=extent, **kwargs)
result = super().imshow(img, *args, extent=extent, **kwargs)

return result

Expand Down Expand Up @@ -1578,10 +1575,9 @@ def _gen_axes_patch(self):
def _gen_axes_spines(self, locations=None, offset=0.0, units='inches'):
# generate some axes spines, as some Axes super class machinery
# requires them. Just make them invisible
spines = matplotlib.axes.Axes._gen_axes_spines(self,
locations=locations,
offset=offset,
units=units)
spines = super()._gen_axes_spines(locations=locations,
offset=offset,
units=units)
for spine in spines.values():
spine.set_visible(False)

Expand Down Expand Up @@ -1663,7 +1659,7 @@ def contour(self, *args, **kwargs):
The default is False, to compute the contours in data-space.
"""
result = matplotlib.axes.Axes.contour(self, *args, **kwargs)
result = super().contour(*args, **kwargs)

# We need to compute the dataLim correctly for contours.
bboxes = [col.get_datalim(self.transData)
Expand Down Expand Up @@ -1711,7 +1707,7 @@ def contourf(self, *args, **kwargs):
if not hasattr(sub_trans, 'force_path_ccw'):
sub_trans.force_path_ccw = True

result = matplotlib.axes.Axes.contourf(self, *args, **kwargs)
result = super().contourf(*args, **kwargs)

# We need to compute the dataLim correctly for contours.
bboxes = [col.get_datalim(self.transData)
Expand Down Expand Up @@ -1749,7 +1745,7 @@ def scatter(self, *args, **kwargs):
'geodetic, consider using the cyllindrical form '
'(PlateCarree or RotatedPole).')

result = matplotlib.axes.Axes.scatter(self, *args, **kwargs)
result = super().scatter(*args, **kwargs)
self.autoscale_view()
return result

Expand All @@ -1776,7 +1772,7 @@ def hexbin(self, x, y, *args, **kwargs):
x = pairs[:, 0]
y = pairs[:, 1]

result = matplotlib.axes.Axes.hexbin(self, x, y, *args, **kwargs)
result = super().hexbin(x, y, *args, **kwargs)
self.autoscale_view()
return result

Expand All @@ -1794,7 +1790,7 @@ def pcolormesh(self, *args, **kwargs):
# Add in an argument checker to handle Matplotlib's potential
# interpolation when coordinate wraps are involved
args = self._wrap_args(*args, **kwargs)
result = matplotlib.axes.Axes.pcolormesh(self, *args, **kwargs)
result = super().pcolormesh(*args, **kwargs)
# Wrap the quadrilaterals if necessary
result = self._wrap_quadmesh(result, **kwargs)
# Re-cast the QuadMesh as a GeoQuadMesh to enable future wrapping
Expand Down Expand Up @@ -1979,7 +1975,7 @@ def pcolor(self, *args, **kwargs):
# Add in an argument checker to handle Matplotlib's potential
# interpolation when coordinate wraps are involved
args = self._wrap_args(*args, **kwargs)
result = matplotlib.axes.Axes.pcolor(self, *args, **kwargs)
result = super().pcolor(*args, **kwargs)

# Update the datalim for this pcolor.
limits = result.get_datalim(self.transData)
Expand Down Expand Up @@ -2061,7 +2057,7 @@ def quiver(self, x, y, u, v, *args, **kwargs):
if (x.ndim == 1 and y.ndim == 1) and (x.shape != u.shape):
x, y = np.meshgrid(x, y)
u, v = self.projection.transform_vectors(t, x, y, u, v)
return matplotlib.axes.Axes.quiver(self, x, y, u, v, *args, **kwargs)
return super().quiver(x, y, u, v, *args, **kwargs)

@_add_transform
def barbs(self, x, y, u, v, *args, **kwargs):
Expand Down Expand Up @@ -2136,7 +2132,7 @@ def barbs(self, x, y, u, v, *args, **kwargs):
if (x.ndim == 1 and y.ndim == 1) and (x.shape != u.shape):
x, y = np.meshgrid(x, y)
u, v = self.projection.transform_vectors(t, x, y, u, v)
return matplotlib.axes.Axes.barbs(self, x, y, u, v, *args, **kwargs)
return super().barbs(x, y, u, v, *args, **kwargs)

@_add_transform
def streamplot(self, x, y, u, v, **kwargs):
Expand Down Expand Up @@ -2213,7 +2209,7 @@ def streamplot(self, x, y, u, v, **kwargs):
message = 'Warning: converting a masked element to nan.'
warnings.filterwarnings('ignore', message=message,
category=UserWarning)
sp = matplotlib.axes.Axes.streamplot(self, x, y, u, v, **kwargs)
sp = super().streamplot(x, y, u, v, **kwargs)
return sp

def add_wmts(self, wmts, layer_name, wmts_kwargs=None, **kwargs):
Expand Down
4 changes: 2 additions & 2 deletions lib/cartopy/mpl/geocollection.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@ class GeoQuadMesh(QuadMesh):

def get_array(self):
# Retrieve the array - use copy to avoid any chance of overwrite
A = super(QuadMesh, self).get_array().copy()
A = super().get_array().copy()
# If the input array has a mask, retrieve the associated data
if hasattr(self, '_wrapped_mask'):
A[self._wrapped_mask] = self._wrapped_collection_fix.get_array()
Expand All @@ -43,7 +43,7 @@ def set_array(self, A):

# Now that we have prepared the collection data, call on
# through to the underlying implementation.
super(QuadMesh, self).set_array(A)
super().set_array(A)

def set_clim(self, vmin=None, vmax=None):
# Update _wrapped_collection_fix color limits if it is there.
Expand Down
8 changes: 4 additions & 4 deletions lib/cartopy/tests/mpl/test_quiver.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,9 +34,9 @@ def test_quiver_transform_xyuv_1d(self):
self.ax.quiver(self.x2d.ravel(), self.y2d.ravel(),
self.u.ravel(), self.v.ravel(), transform=self.rp)
args, kwargs = patch.call_args
assert len(args) == 5
assert len(args) == 4
assert sorted(kwargs.keys()) == ['transform']
shapes = [arg.shape for arg in args[1:]]
shapes = [arg.shape for arg in args]
# Assert that all the shapes have been broadcast.
assert shapes == [(70, )] * 4

Expand All @@ -45,9 +45,9 @@ def test_quiver_transform_xy_1d_uv_2d(self):
with mock.patch('matplotlib.axes.Axes.quiver') as patch:
self.ax.quiver(self.x, self.y, self.u, self.v, transform=self.rp)
args, kwargs = patch.call_args
assert len(args) == 5
assert len(args) == 4
assert sorted(kwargs.keys()) == ['transform']
shapes = [arg.shape for arg in args[1:]]
shapes = [arg.shape for arg in args]
# Assert that all the shapes have been broadcast.
assert shapes == [(7, 10)] * 4

Expand Down

0 comments on commit f187916

Please sign in to comment.