Skip to content

Commit

Permalink
refactor weighted operations (#287)
Browse files Browse the repository at this point in the history
* refactor weighted operations

* update changelog
  • Loading branch information
mathause authored Sep 7, 2023
1 parent 68bbc86 commit c7e2064
Show file tree
Hide file tree
Showing 6 changed files with 237 additions and 124 deletions.
11 changes: 6 additions & 5 deletions CHANGELOG.rst
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@ v0.9.0 - unreleased
New Features
^^^^^^^^^^^^

- Refactored statistical functionality for linear regression:
- Extracted statistical functionality for linear regression:
- Create :py:class:`mesmer.stats.linear_regression.LinearRegression` which encapsulates
``fit``, ``predict``, etc. methods around linear regression
(`#134 <https://github.com/MESMER-group/mesmer/pull/134>`_).
Expand All @@ -22,7 +22,7 @@ New Features
(`#221 <https://github.com/MESMER-group/mesmer/pull/221>`_).
By `Mathias Hauser <https://github.com/mathause>`_.

- Refactored statistical functionality for auto regression:
- Extracted statistical functionality for auto regression:
- Add ``mesmer.stats.auto_regression._fit_auto_regression_xr``: xarray wrapper to fit an
auto regression model (`#139 <https://github.com/MESMER-group/mesmer/pull/139>`_).
By `Mathias Hauser <https://github.com/mathause>`_.
Expand All @@ -33,7 +33,7 @@ New Features
(`#176 <https://github.com/MESMER-group/mesmer/pull/176>`_).
By `Mathias Hauser <https://github.com/mathause>`_.

- Refactored functions dealing with the spatial covariance and its localization:
- Extracted functions dealing with the spatial covariance and its localization:
- Add xarray wrappers :py:func:`mesmer.stats.localized_covariance.adjust_covariance_ar1`
and :py:func:`mesmer.stats.localized_covariance.find_localized_empirical_covariance`
(`#191 <https://github.com/MESMER-group/mesmer/pull/191>`__).
Expand All @@ -60,8 +60,9 @@ New Features
- Added functions to mask the ocean and Antarctica (`#219
<https://github.com/MESMER-group/mesmer/pull/219>`_). By `Mathias Hauser
<https://github.com/mathause>`_.
- Added functions to calculate the weighted global mean (`#220
<https://github.com/MESMER-group/mesmer/pull/220>`_). By `Mathias Hauser
- Added functions to calculate the weighted global mean (
`#220 <https://github.com/MESMER-group/mesmer/pull/220>`_
`#287 <https://github.com/MESMER-group/mesmer/pull/287>`_). By `Mathias Hauser
<https://github.com/mathause>`_.
- Added functions to wrap arrays to [-180, 180) and [0, 360), respectively (`#270
<https://github.com/MESMER-group/mesmer/pull/270>`_ and `#273
Expand Down
16 changes: 14 additions & 2 deletions docs/source/api.rst
Original file line number Diff line number Diff line change
Expand Up @@ -61,6 +61,9 @@ Geo-spatial
Data handling
=============

Grid manipulation
-----------------

.. autosummary::
:toctree: generated/

Expand All @@ -70,13 +73,22 @@ Data handling
~core.grid.unstack_lat_lon_and_align
~core.grid.unstack_lat_lon
~core.grid.align_to_coords

Masking regions
---------------

~core.mask.mask_ocean_fraction
~core.mask.mask_ocean
~core.mask.mask_antarctica
~core.globmean.lat_weights
~core.globmean.weighted_mean
~core.regionmaskcompat.mask_3D_frac_approx

Weighted operarions: calculate global mean
------------------------------------------

~core.weighted.global_mean
~core.weighted.lat_weights
~core.weighted.weighted_mean

Legacy functions
================

Expand Down
4 changes: 2 additions & 2 deletions mesmer/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
"""

from . import calibrate_mesmer, core, create_emulations, io, utils
from .core import globmean, grid, mask
from .core import grid, mask, weighted

__all__ = [
"calibrate_mesmer",
Expand All @@ -18,7 +18,7 @@
"io",
"mask",
"utils",
"globmean",
"weighted",
]

try:
Expand Down
46 changes: 38 additions & 8 deletions mesmer/core/globmean.py → mesmer/core/weighted.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@ def _weighted_if_dim(obj, weights, dims):
# https://github.com/pydata/xarray/issues/7027

def _weighted_mean(da):
if all(dim in da.dims for dim in dims):
if dims is None or all(dim in da.dims for dim in dims):
return da.weighted(weights).mean(dims, keep_attrs=True)
return da

Expand All @@ -34,8 +34,8 @@ def lat_weights(lat_coords):
return weights


def weighted_mean(data, weights, x_dim="lon", y_dim="lat"):
"""Calculate the area-weighted global mean
def weighted_mean(data, weights, dims=None):
"""weighted mean - convinience function which ignores data_vars missing dims
Parameters
----------
Expand All @@ -44,10 +44,8 @@ def weighted_mean(data, weights, x_dim="lon", y_dim="lat"):
weights : xr.DataArray
DataArray containing the area of each grid cell (or a measure proportional to
the grid cell area).
x_dim : str, default: "lon"
Name of the x-dimension.
y_dim : str, default: "lat"
Name of the y-dimension.
dims : Hashable or Iterable of Hashable, optional
Dimension(s) over which to apply the weighted ``mean``.
Returns
-------
Expand All @@ -56,10 +54,42 @@ def weighted_mean(data, weights, x_dim="lon", y_dim="lat"):
"""

if isinstance(dims, str):
dims = [dims]

# ensure grids are equal
try:
xr.align(data, weights, join="exact")
except ValueError:
raise ValueError("`data` and `weights` don't exactly align.")

return _weighted_if_dim(data, weights, [x_dim, y_dim])
return _weighted_if_dim(data, weights, dims)


def global_mean(data, weights=None, x_dim="lon", y_dim="lat"):
"""calculate global weighted mean
Parameters
----------
data : xr.Dataset | xr.DataArray
Array reduce to the global mean.
weights : xr.DataArray, optional
DataArray containing the area of each grid cell (or a measure proportional to
the grid cell area). If not given will compute it from the cosine of the
latitudes.
x_dim : str, default: "lon"
Name of the x-dimension.
y_dim : str, default: "lat"
Name of the y-dimension.
Returns
-------
obj : xr.Dataset | xr.DataArray
Array converted to an unstructured grid.
"""

if weights is None:
weights = lat_weights(data[y_dim])

return weighted_mean(data, weights, [x_dim, y_dim])
107 changes: 0 additions & 107 deletions tests/unit/test_globmean.py

This file was deleted.

Loading

0 comments on commit c7e2064

Please sign in to comment.