Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Port Diana Gergel's QPLAD from CIL xclim fork #197

Draft
wants to merge 3 commits into
base: main
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
222 changes: 219 additions & 3 deletions dodola/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,8 +3,10 @@
Math stuff and business logic goes here. This is the "business logic".
"""

from typing import Union
import warnings
import logging
from numba import float32, float64, jit, types
import numpy as np
import xarray as xr
from xclim import sdba, set_options
Expand All @@ -19,6 +21,220 @@
# Assume data input here is generally clean and valid.


@jit(
    [
        float32[:, :](float32[:, :], float32[:, :], float32[:], types.unicode_type),
        float64[:, :](float64[:, :], float64[:, :], float64[:], types.unicode_type),
        float64[:, :](float64[:, :], float32[:, :], float64[:], types.unicode_type),
        float32[:](float32[:], float32[:], float32[:], types.unicode_type),
        float64[:](float64[:], float64[:], float64[:], types.unicode_type),
    ],
    nopython=True,
)
def _argsort(arr_coarse, arr_fine, q, arr_sort="coarse"):
    """Numba kernel: return ``arr_coarse`` or ``arr_fine`` reordered by the
    sort order of ``arr_coarse``.

    Parameters
    ----------
    arr_coarse : 1-D or 2-D ndarray
        Array whose argsort indices drive the reordering. For the 2-D case,
        sorting is done independently along the last axis of each row.
    arr_fine : ndarray
        Array with the same shape as ``arr_coarse``; reordered with the
        coarse indices when ``arr_sort == "fine"``.
    q : 1-D ndarray
        Quantile values; only its size is used, to shape the 2-D output.
    arr_sort : str
        Which input to reorder: ``"coarse"`` or ``"fine"``.

    Returns
    -------
    ndarray
        The selected array sorted by ``arr_coarse``'s order.
    """
    if arr_coarse.ndim == 1:
        # 1-D case: a single argsort of the whole series.
        inds = np.argsort(arr_coarse)
        # NOTE(review): if arr_sort is neither "coarse" nor "fine", `out` is
        # never assigned on this path — callers must pass a valid arr_sort.
        if arr_sort == "coarse":
            out = arr_coarse[inds]
        elif arr_sort == "fine":
            out = arr_fine[inds]
    else:
        # 2-D case: argsort each row independently; one row per "extra"
        # (stacked non-temporal) coordinate, q.size sorted values per row.
        out = np.empty((arr_coarse.shape[0], q.size), dtype=arr_coarse.dtype)
        for index in range(out.shape[0]):
            inds = np.argsort(arr_coarse[index])
            if arr_sort == "coarse":
                out[index] = arr_coarse[index][inds]
            elif arr_sort == "fine":
                out[index] = arr_fine[index][inds]
    return out


def argsort(da_ref_coarse, da_ref_fine, q, dim, arr_sort="coarse", axis=0):
    """Sort ref_coarse or ref_fine (specified by the arr_sort input arg)
    with the indices used to quantile ref_coarse.

    Parameters
    ----------
    da_ref_coarse : xr.DataArray
        Coarse-resolution reference; its sort order drives the reordering.
    da_ref_fine : xr.DataArray
        Fine-resolution reference; must match the coarse array's shape after
        null removal and stacking.
    q : array-like
        Quantile values; becomes the "quantiles" coordinate of the result.
    dim : str or sequence of str
        Dimension(s) to treat as temporal and sort along.
    arr_sort : {"coarse", "fine"}
        Which array's values appear in the output.
    axis : int
        Unused; retained for backward compatibility of the call signature.

    Returns
    -------
    xr.DataArray
        Sorted values with a "quantiles" dimension (plus any extra dims).
    """
    # We have two cases :
    # - When all dims are processed : we stack them and use the 1-D numba path
    # - When the quantiles are vectorized over some dims, these are also
    #   stacked and then the 2-D numba path is used.
    # All this stacking is so we can cover all ND+1D cases with one numba function.

    # Check if there are any nulls in input arrays; if there are, remove them.
    # NOTE(review): drops along the literal "time" dim even though `dim` is a
    # parameter — assumes the temporal dim is named "time"; confirm callers.
    da_ref_coarse = da_ref_coarse.dropna(dim="time")
    da_ref_fine = da_ref_fine.dropna(dim="time")

    # Stack the dims and send to the last position.
    # This is in case there are more than one.
    dims = [dim] if isinstance(dim, str) else dim
    tem = xr.core.utils.get_temp_dimname(da_ref_coarse.dims, "temporal")
    da_ref_coarse = da_ref_coarse.stack({tem: dims})
    da_ref_fine = da_ref_fine.stack({tem: dims})

    # Coerce q to the coarse array's dtype so we cut in half the signature
    # definitions to declare in numba.
    if not hasattr(q, "dtype") or q.dtype != da_ref_coarse.dtype:
        q = np.array(q, dtype=da_ref_coarse.dtype)

    if len(da_ref_coarse.dims) > 1:
        # There are some extra dims: stack them into one "extra" dim and put
        # the temporal dim last so the numba kernel sees (extra, temporal).
        extra = xr.core.utils.get_temp_dimname(da_ref_coarse.dims, "extra")
        da_ref_coarse = da_ref_coarse.stack({extra: set(da_ref_coarse.dims) - {tem}})
        da_ref_fine = da_ref_fine.stack({extra: set(da_ref_fine.dims) - {tem}})
        da_ref_coarse = da_ref_coarse.transpose(..., tem)
        da_ref_fine = da_ref_fine.transpose(..., tem)

        if da_ref_coarse.values.shape != da_ref_fine.values.shape:
            raise ValueError("shape of coarse values does not match fine values")

        # The temporal length must equal the number of quantiles so each
        # sorted day maps to exactly one quantile.
        if da_ref_coarse.values.shape[1] != q.shape[0]:
            raise ValueError(
                "shape of q is {} and shape of ref coarse/fine is {}".format(
                    q.shape, da_ref_coarse.values.shape
                )
            )

        out = _argsort(da_ref_coarse.values, da_ref_fine.values, q, arr_sort)

        res = xr.DataArray(
            out,
            dims=(extra, "quantiles"),
            coords={extra: da_ref_coarse[extra], "quantiles": q},
            attrs=da_ref_coarse.attrs,
        ).unstack(extra)

    else:
        # All dims are processed: 1-D sort, result indexed by quantile only.
        res = xr.DataArray(
            _argsort(da_ref_coarse.values, da_ref_fine.values, q, arr_sort),
            dims=("quantiles"),
            coords={"quantiles": q},
            attrs=da_ref_coarse.attrs,
        )

    return res


@sdba.base.map_groups(
    af=[sdba.Grouper.PROP, "quantiles"],
    ref_coarse_q=[sdba.Grouper.PROP, "quantiles"],
)
def _qplad_train(ds, *, dim, kind, quantiles):
    """QPLAD: Train step on one group.

    Dataset variables:
      ref_coarse : training target, coarse resolution
      ref_fine : training target, fine resolution
    """
    # Quantile the coarse reference: the day indices that sort ref_coarse
    # define which day corresponds to each quantile.
    coarse_by_quantile = argsort(
        ds.ref_coarse, ds.ref_fine, quantiles, dim, arr_sort="coarse", axis=0
    )

    # Reorder the fine reference with those same day indices, pairing each
    # coarse quantile with its fine-resolution analog day.
    fine_by_quantile = argsort(
        ds.ref_coarse, ds.ref_fine, quantiles, dim, arr_sort="fine", axis=0
    )

    # Adjustment factors: per-quantile correction from coarse to fine.
    adjustment_factors = sdba.utils.get_correction(
        coarse_by_quantile, fine_by_quantile, kind
    )

    return xr.Dataset(
        data_vars={"af": adjustment_factors, "ref_coarse_q": coarse_by_quantile}
    )


@sdba.base.map_blocks(reduces=[sdba.Grouper.PROP, "quantiles"], scen=[], sim_q=[])
def _qplad_adjust(ds, *, group, interp, extrapolation, kind):
    """QPLAD: Adjust process on one block.

    Dataset variables:
      af : Adjustment factors
      ref_coarse_q : Empirical quantiles of the coarse training reference
      sim : Data to adjust
      sim_q : Quantiles of the data to adjust
    """
    # Extend the adjustment factors beyond the trained quantile range so
    # out-of-range sim quantiles still receive a correction.
    af, _ = sdba.utils.extrapolate_qm(ds.af, ds.ref_coarse_q, method=extrapolation)

    # Select, for each simulated value, the AF at its quantile (and matching
    # group property), interpolating between tabulated quantiles.
    sel = {dim: ds.sim_q[dim] for dim in set(af.dims).intersection(set(ds.sim_q.dims))}
    sel["quantiles"] = ds.sim_q
    af = sdba.utils.broadcast(af, ds.sim, group=group, interp=interp, sel=sel)

    scen = sdba.utils.apply_correction(ds.sim, af, kind)
    return xr.Dataset(dict(scen=scen, sim_q=ds.sim_q))


class QuantilePreservingAnalogDownscaling(sdba.adjustment.TrainAdjust):
    r"""Quantile-Preserving Localized Analogs Downscaling.

    Adjustment factors are computed between the corresponding days of `ref_coarse` and `ref_fine`.
    Quantiles of `sim` are matched to the corresponding quantiles of `AFs` and corrected accordingly.

    Parameters
    ----------
    Train step:

    nquantiles : int
        The number of quantiles to use. Two endpoints at 1e-6 and 1 - 1e-6 will not be added.
    kind : {'+', '*'}
        The adjustment kind, either additive or multiplicative.
    group : Union[str, Grouper]
        The grouping information. See :py:class:`xclim.sdba.base.Grouper` for details.


    """

    # Require ref/hist and the adjusted sim to share a calendar (enforced by
    # the TrainAdjust base class).
    _allow_diff_calendars = False

    @classmethod
    def _train(
        cls,
        ref: xr.DataArray,
        hist: xr.DataArray,
        *,
        nquantiles: int = 20,
        kind: str = sdba.utils.ADDITIVE,
        group: Union[str, sdba.Grouper] = "time",
    ):
        # Equally spaced quantile nodes; eps=None means no extra endpoint
        # nodes are added (matches the docstring above).
        quantiles = equally_spaced_nodes(nquantiles, eps=None).astype(ref.dtype)

        # NOTE(review): `ref` is used as the coarse-resolution target and
        # `hist` as the fine-resolution target — the names come from the
        # TrainAdjust interface, not from the QPLAD semantics.
        ds = _qplad_train(
            xr.Dataset({"ref_coarse": ref, "ref_fine": hist}),
            group=group,
            quantiles=quantiles,
            kind=kind,
        )

        ds.af.attrs.update(
            standard_name="Adjustment factors",
            long_name="Quantile Preserving Localized Analogs Downscaling Adjustment Factors",
        )
        ds.ref_coarse_q.attrs.update(
            standard_name="Empirical quantiles",
            long_name="Empirical quantiles of coarse reference data",
        )

        # TrainAdjust contract: return the trained dataset plus the extra
        # parameters to remember for the adjust step.
        return ds, {"group": group, "kind": kind}

    def _adjust(self, sim):
        # match quantiles from sim to corresponding AFs for that DOY
        # `sim` is expected to carry a `sim_q` variable (quantiles of the
        # simulation); it is split out so `sim` holds only the data to adjust.
        ds = xr.Dataset(
            {
                "sim": sim.drop("sim_q"),
                "af": self.ds.af,
                "sim_q": sim.sim_q,
                "ref_coarse_q": self.ds.ref_coarse_q,
            }
        )

        out = _qplad_adjust(
            ds,
            group=self.group,
            interp="linear",
            extrapolation="constant",
            kind=self.kind,
        )

        return out.scen


def train_quantiledeltamapping(
reference, historical, variable, kind, quantiles_n=100, window_n=31
):
Expand Down Expand Up @@ -202,7 +418,7 @@

Returns
-------
xclim.sdba.adjustment.QuantilePreservingAnalogDownscaling
QuantilePreservingAnalogDownscaling
"""

# QPLAD method requires that the number of quantiles equals
Expand All @@ -222,7 +438,7 @@
)
)

qplad = sdba.adjustment.QuantilePreservingAnalogDownscaling.train(
qplad = QuantilePreservingAnalogDownscaling.train(
ref=coarse_reference[variable],
hist=fine_reference[variable],
kind=str(kind),
Expand Down Expand Up @@ -256,7 +472,7 @@
variable = str(variable)

if isinstance(qplad, xr.Dataset):
qplad = sdba.adjustment.QuantilePreservingAnalogDownscaling.from_dataset(qplad)
qplad = QuantilePreservingAnalogDownscaling.from_dataset(qplad)

out = qplad.adjust(simulation[variable]).to_dataset(name=variable)

Expand Down
3 changes: 1 addition & 2 deletions environment.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -19,9 +19,8 @@ dependencies:
- pytest-cov
- python=3.9
- s3fs=2022.1.0
- xclim=0.31.0
- xarray=0.21.1
- xesmf=0.6.2
- bottleneck=1.3.2
- zarr=2.11.0
- pip:
- git+https://github.com/ClimateImpactLab/xclim@63023d27f89a457c752568ffcec2e9ce9ad7a81a