diff --git a/dodola/core.py b/dodola/core.py index 6dbffb4..2dbeb94 100644 --- a/dodola/core.py +++ b/dodola/core.py @@ -3,8 +3,10 @@ Math stuff and business logic goes here. This is the "business logic". """ +from typing import Union import warnings import logging +from numba import float32, float64, jit, types import numpy as np import xarray as xr from xclim import sdba, set_options @@ -19,6 +21,220 @@ # Assume data input here is generally clean and valid. +@jit( + [ + float32[:, :](float32[:, :], float32[:, :], float32[:], types.unicode_type), + float64[:, :](float64[:, :], float64[:, :], float64[:], types.unicode_type), + float64[:, :](float64[:, :], float32[:, :], float64[:], types.unicode_type), + float32[:](float32[:], float32[:], float32[:], types.unicode_type), + float64[:](float64[:], float64[:], float64[:], types.unicode_type), + ], + nopython=True, +) +def _argsort(arr_coarse, arr_fine, q, arr_sort="coarse"): + if arr_coarse.ndim == 1: + inds = np.argsort(arr_coarse) + if arr_sort == "coarse": + out = arr_coarse[inds] + elif arr_sort == "fine": + out = arr_fine[inds] + else: + out = np.empty((arr_coarse.shape[0], q.size), dtype=arr_coarse.dtype) + for index in range(out.shape[0]): + inds = np.argsort(arr_coarse[index]) + if arr_sort == "coarse": + out[index] = arr_coarse[index][inds] + elif arr_sort == "fine": + out[index] = arr_fine[index][inds] + return out + + +def argsort(da_ref_coarse, da_ref_fine, q, dim, arr_sort="coarse", axis=0): + """Sort ref_coarse or ref_fine (specified by the arr_sort input arg) + with the indices used to quantile ref_coarse""" + # We have two cases : + # - When all dims are processed : we stack them and use _argsort1d + # - When the quantiles are vectorized over some dims, these are also stacked and then _argsort2D is used. + # All this stacking is so we can cover all ND+1D cases with one numba function. + + # check if there are any nulls in input arrays + # if there are nulls, remove them + da_ref_coarse = da_ref_coarse.dropna(dim="time") + da_ref_fine = da_ref_fine.dropna(dim="time") + + # Stack the dims and send to the last position + # This is in case there are more than one + dims = [dim] if isinstance(dim, str) else dim + tem = xr.core.utils.get_temp_dimname(da_ref_coarse.dims, "temporal") + da_ref_coarse = da_ref_coarse.stack({tem: dims}) + da_ref_fine = da_ref_fine.stack({tem: dims}) + + # So we cut in half the definitions to declare in numba + if not hasattr(q, "dtype") or q.dtype != da_ref_coarse.dtype: + q = np.array(q, dtype=da_ref_coarse.dtype) + + if len(da_ref_coarse.dims) > 1: + # There are some extra dims + extra = xr.core.utils.get_temp_dimname(da_ref_coarse.dims, "extra") + da_ref_coarse = da_ref_coarse.stack({extra: set(da_ref_coarse.dims) - {tem}}) + da_ref_fine = da_ref_fine.stack({extra: set(da_ref_fine.dims) - {tem}}) + da_ref_coarse = da_ref_coarse.transpose(..., tem) + da_ref_fine = da_ref_fine.transpose(..., tem) + + if da_ref_coarse.values.shape != da_ref_fine.values.shape: + raise ValueError("shape of coarse values does not match fine values") + + if da_ref_coarse.values.shape[1] != q.shape[0]: + raise ValueError( + "shape of q is {} and shape of ref coarse/fine is {}".format( + q.shape, da_ref_coarse.values.shape + ) + ) + + out = _argsort(da_ref_coarse.values, da_ref_fine.values, q, arr_sort) + + res = xr.DataArray( + out, + dims=(extra, "quantiles"), + coords={extra: da_ref_coarse[extra], "quantiles": q}, + attrs=da_ref_coarse.attrs, + ).unstack(extra) + + else: + # All dims are processed + res = xr.DataArray( + _argsort(da_ref_coarse.values, da_ref_fine.values, q, arr_sort), + dims=("quantiles"), + coords={"quantiles": q}, + attrs=da_ref_coarse.attrs, + ) + + return res + + +@sdba.base.map_groups( + af=[sdba.Grouper.PROP, "quantiles"], + ref_coarse_q=[sdba.Grouper.PROP, "quantiles"], +) +def _qplad_train(ds, *, dim, kind, quantiles): + """QPLAD: Train step on one group. + + Dataset variables: + ref_coarse : training target, coarse resolution + ref_fine : training target, fine resolution + """ + # compute indices of days corresponding to each quantile for ref coarse + # sort ref coarse with those indices (corresponding to # of quantiles) + ref_coarse_q = argsort( + ds.ref_coarse, ds.ref_fine, quantiles, dim, arr_sort="coarse", axis=0 + ) + + # sort ref fine with the same indices + ref_fine_q = argsort( + ds.ref_coarse, ds.ref_fine, quantiles, dim, arr_sort="fine", axis=0 + ) + + # compute adjustment factors as difference bw course and fine for those days + af = sdba.utils.get_correction(ref_coarse_q, ref_fine_q, kind) + + return xr.Dataset(data_vars=dict(af=af, ref_coarse_q=ref_coarse_q)) + + +@sdba.base.map_blocks(reduces=[sdba.Grouper.PROP, "quantiles"], scen=[], sim_q=[]) +def _qplad_adjust(ds, *, group, interp, extrapolation, kind): + """QPLAD: Adjust process on one block. + + Dataset variables: + af : Adjustment factors + hist_q : Quantiles over the training data + sim : Data to adjust. + """ + af, _ = sdba.utils.extrapolate_qm(ds.af, ds.ref_coarse_q, method=extrapolation) + + sel = {dim: ds.sim_q[dim] for dim in set(af.dims).intersection(set(ds.sim_q.dims))} + sel["quantiles"] = ds.sim_q + af = sdba.utils.broadcast(af, ds.sim, group=group, interp=interp, sel=sel) + + scen = sdba.utils.apply_correction(ds.sim, af, kind) + return xr.Dataset(dict(scen=scen, sim_q=ds.sim_q)) + + +class QuantilePreservingAnalogDownscaling(sdba.adjustment.TrainAdjust): + r"""Quantile-Preserving Localized Analogs Downscaling. + + Adjustment factors are computed between the corresponding days of `ref_coarse` and `ref_fine`. + Quantiles of `sim` are matched to the corresponding quantiles of `AFs` and corrected accordingly. + + Parameters + ---------- + Train step: + + nquantiles : int + The number of quantiles to use. Two endpoints at 1e-6 and 1 - 1e-6 will not be added. + kind : {'+', '*'} + The adjustment kind, either additive or multiplicative. + group : Union[str, Grouper] + The grouping information. See :py:class:`xclim.sdba.base.Grouper` for details. + + + """ + + _allow_diff_calendars = False + + @classmethod + def _train( + cls, + ref: xr.DataArray, + hist: xr.DataArray, + *, + nquantiles: int = 20, + kind: str = sdba.utils.ADDITIVE, + group: Union[str, sdba.Grouper] = "time", + ): + + quantiles = equally_spaced_nodes(nquantiles, eps=None).astype(ref.dtype) + + ds = _qplad_train( + xr.Dataset({"ref_coarse": ref, "ref_fine": hist}), + group=group, + quantiles=quantiles, + kind=kind, + ) + + ds.af.attrs.update( + standard_name="Adjustment factors", + long_name="Quantile Preserving Localized Analogs Downscaling Adjustment Factors", + ) + ds.ref_coarse_q.attrs.update( + standard_name="Empirical quantiles", + long_name="Empirical quantiles of coarse reference data", + ) + + return ds, {"group": group, "kind": kind} + + def _adjust(self, sim): + + # match quantiles from sim to corresponding AFs for that DOY + ds = xr.Dataset( + { + "sim": sim.drop("sim_q"), + "af": self.ds.af, + "sim_q": sim.sim_q, + "ref_coarse_q": self.ds.ref_coarse_q, + } + ) + + out = _qplad_adjust( + ds, + group=self.group, + interp="linear", + extrapolation="constant", + kind=self.kind, + ) + + return out.scen + + def train_quantiledeltamapping( reference, historical, variable, kind, quantiles_n=100, window_n=31 ): @@ -202,7 +418,7 @@ def train_analogdownscaling( Returns ------- - xclim.sdba.adjustment.QuantilePreservingAnalogDownscaling + QuantilePreservingAnalogDownscaling """ # QPLAD method requires that the number of quantiles equals @@ -222,7 +438,7 @@ def train_analogdownscaling( ) ) - qplad = sdba.adjustment.QuantilePreservingAnalogDownscaling.train( + qplad = QuantilePreservingAnalogDownscaling.train( ref=coarse_reference[variable], hist=fine_reference[variable], kind=str(kind), @@ -256,7 +472,7 @@ def adjust_analogdownscaling(simulation, qplad, variable): variable = str(variable) if isinstance(qplad, xr.Dataset): - qplad = sdba.adjustment.QuantilePreservingAnalogDownscaling.from_dataset(qplad) + qplad = QuantilePreservingAnalogDownscaling.from_dataset(qplad) out = qplad.adjust(simulation[variable]).to_dataset(name=variable) diff --git a/environment.yaml b/environment.yaml index 4c7aec5..770326d 100644 --- a/environment.yaml +++ b/environment.yaml @@ -19,9 +19,8 @@ dependencies: - pytest-cov - python=3.9 - s3fs=2022.1.0 +- xclim=0.31.0 - xarray=0.21.1 - xesmf=0.6.2 - bottleneck=1.3.2 - zarr=2.11.0 -- pip: - - git+https://github.com/ClimateImpactLab/xclim@63023d27f89a457c752568ffcec2e9ce9ad7a81a