Skip to content

Commit

Permalink
more power_transformer performance tweaks (#538)
Browse files Browse the repository at this point in the history
* more power_transformer performance tweaks

* clean up

* update comments & var names

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

---------

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
  • Loading branch information
mathause and pre-commit-ci[bot] authored Oct 3, 2024
1 parent 72d2fd9 commit 2b5da4a
Showing 1 changed file with 21 additions and 28 deletions.
49 changes: 21 additions & 28 deletions mesmer/stats/_power_transformer.py
Original file line number Diff line number Diff line change
Expand Up @@ -70,31 +70,26 @@ def _yeo_johnson_transform_optimized(data):
data_log1p = np.log1p(np.abs(data))

pos = data >= 0
neg = ~pos

def _inner(lambdas):

# NOTE: this code is copied from sklearn's PowerTransformer, see
# NOTE: this code is adapted from sklearn's PowerTransformer, see
# https://github.com/scikit-learn/scikit-learn/blob/8721245511de2f225ff5f9aa5f5fadce663cd4a3/sklearn/preprocessing/_data.py#L3396
# we acknowledge there is an inconsistency in the comparison of lambdas

lambda_eq_0 = np.abs(lambdas) < eps
lambda_eq_2 = np.abs(lambdas - 2) <= eps
# align lambdas for pos and neg data - so we only have two cases
# NOTE: cannot do this inplace; `where` is faster than copying & subsetting
lambdas = np.where(pos, lambdas, 2.0 - lambdas)

sel_a = pos & lambda_eq_0
sel_b = pos & ~lambda_eq_0
sel_c = neg & ~lambda_eq_2
sel_d = neg & lambda_eq_2
# NOTE: abs(2 - a) == abs(a - 2)
lmbds_eq_0_or_2 = np.abs(lambdas) <= eps
lmbds_ne_0_or_2 = ~lmbds_eq_0_or_2

transf[sel_a] = data_log1p[sel_a]
transf[lmbds_eq_0_or_2] = data_log1p[lmbds_eq_0_or_2]

lmbds = lambdas[sel_b]
transf[sel_b] = np.expm1(data_log1p[sel_b] * lmbds) / lmbds
lmbds = lambdas[lmbds_ne_0_or_2]
transf[lmbds_ne_0_or_2] = np.expm1(data_log1p[lmbds_ne_0_or_2] * lmbds) / lmbds

lmbds = 2 - lambdas[sel_c]
transf[sel_c] = -np.expm1(data_log1p[sel_c] * lmbds) / lmbds

transf[sel_d] = -data_log1p[sel_d]
np.copysign(transf, data, out=transf)

return transf

Expand Down Expand Up @@ -157,33 +152,34 @@ def _yeo_johnson_optimize_lambda_np(monthly_residuals, yearly_pred):
_yeo_johnson_transform = _yeo_johnson_transform_optimized(monthly_residuals)

data_log1p = np.sign(monthly_residuals) * np.log1p(np.abs(monthly_residuals))
data_log1p_sum = data_log1p.sum()

def _neg_log_likelihood(coeffs):
"""Return the negative log likelihood of the observed local monthly residuals
as a function of lambda.
"""
lambdas = lambda_function(coeffs, yearly_pred)

# version with own power transform
transformed_resids = _yeo_johnson_transform(lambdas)

n_samples = monthly_residuals.shape[0]
loglikelihood = -n_samples / 2 * np.log(transformed_resids.var())
loglikelihood += ((lambdas - 1) * data_log1p).sum()

loglikelihood += (lambdas * data_log1p).sum() - data_log1p_sum

return -loglikelihood

bounds = np.array([[0, np.inf], [-0.1, 0.1]])
first_guess = np.array([1, 0])
first_guess = np.array([1.0, 0.0])

coeffs = minimize(
res = minimize(
_neg_log_likelihood,
x0=first_guess,
bounds=bounds,
method="Nelder-Mead",
).x
)

return coeffs
return res.x


def get_lambdas_from_covariates(coeffs, yearly_pred):
Expand Down Expand Up @@ -261,20 +257,17 @@ def fit_yeo_johnson_transform(monthly_residuals, yearly_pred, time_dim="time"):
if not isinstance(yearly_pred, xr.DataArray):
raise TypeError(f"Expected a `xr.DataArray`, got {type(yearly_pred)}")

monthly_resids_grouped = monthly_residuals.groupby(time_dim + ".month")

coeffs = []
for month in range(1, 13):
for month in range(12):

# align time dimension
monthly_data = monthly_resids_grouped[month]
monthly_data[time_dim] = yearly_pred[time_dim]
monthly_data = monthly_residuals.isel({time_dim: slice(month, None, 12)})

res = xr.apply_ufunc(
_yeo_johnson_optimize_lambda_np,
monthly_data,
yearly_pred,
input_core_dims=[[time_dim], [time_dim]],
exclude_dims={time_dim},
output_core_dims=[["coeff"]],
output_dtypes=[float],
vectorize=True,
Expand Down

0 comments on commit 2b5da4a

Please sign in to comment.