Skip to content

Commit

Permalink
Refactor interval scores and improve testing (#82)
Browse files Browse the repository at this point in the history
* exclude type checking blocks from test coverage report

* fix failing tests for weighted interval score

* refactoring and fixes to interval scores

* update docs for interval scores

* improve testing for interval scores

* run formatting
  • Loading branch information
frazane authored Oct 23, 2024
1 parent efe651e commit 6e39f5a
Show file tree
Hide file tree
Showing 18 changed files with 335 additions and 213 deletions.
29 changes: 0 additions & 29 deletions docs/api/interval.md
Original file line number Diff line number Diff line change
@@ -1,34 +1,5 @@
# Interval Score

## Interval or Winkler Score

For a prediction interval (PI), the interval or Winkler score is given by:

$$
\text{IS} = \begin{cases}
(u - l) + \frac{2}{\alpha}(l - y) & \text{for } y < l \\
(u - l) & \text{for } l \leq y \leq u \\
(u - l) + \frac{2}{\alpha}(y - u) & \text{for } y > u. \\
\end{cases}
$$

for an $(1 - \alpha)$PI of $[l, u]$ and the true value $y$ [@gneiting_strictly_2007, @bracher2021evaluating @winkler1972decision].

## Weighted Interval Score

The weighted interval score (WIS) is defined as

$$
\text{WIS}_{\alpha_{0:K}}(F, y) = \frac{1}{K+0.5}(w_0 \times |y - m| + \sum_{k=1}^K (w_k \times IS_{\alpha_k}(F, y)))
$$

where $m$ denotes the median prediction, $w_0$ denotes the weight of the median prediction, $IS_{\alpha_k}(F, y)$ denotes the interval score for the $1 - \alpha$ prediction interval and $w_k$ is the according weight. The WIS is calculated for a set of (central) PIs and the predictive median [@bracher2021evaluating]. The weights are an optional parameter and default weight is the canonical weight $w_k = \frac{2}{\alpha_k}$ and $w_0 = 0.5$. For these weights, it holds that:

$$
\text{WIS}_{\alpha_{0:K}}(F, y) \approx \text{CRPS}(F, y).
$$


::: scoringrules.interval_score

::: scoringrules.weighted_interval_score
4 changes: 4 additions & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -58,9 +58,13 @@ dev-dependencies = [
]

[tool.ruff]
line-length = 88

[tool.ruff.lint]
ignore = ["E741"]

[tool.coverage.run]
omit = ["**/_gufuncs.py", "**/_gufunc.py"]

[tool.coverage.report]
exclude_also = ["if tp.TYPE_CHECKING:"]
73 changes: 65 additions & 8 deletions scoringrules/_crps.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,7 +58,12 @@ def crps_ensemble(
if axis != -1:
forecasts = B.moveaxis(forecasts, axis, -1)

if not sorted_ensemble and estimator not in ["nrg", "akr", "akr_circperm", "fair"]:
if not sorted_ensemble and estimator not in [
"nrg",
"akr",
"akr_circperm",
"fair",
]:
forecasts = B.sort(forecasts, axis=-1)

if backend == "numba":
Expand Down Expand Up @@ -865,7 +870,14 @@ def crps_gtclogistic(
>>> sr.crps_gtclogistic(0.0, 0.1, 0.4, -1.0, 1.0, 0.1, 0.1)
"""
return crps.gtclogistic(
observation, location, scale, lower, upper, lmass, umass, backend=backend
observation,
location,
scale,
lower,
upper,
lmass,
umass,
backend=backend,
)


Expand Down Expand Up @@ -953,7 +965,14 @@ def crps_clogistic(
lmass = stats._logis_cdf((lower - location) / scale)
umass = 1 - stats._logis_cdf((upper - location) / scale)
return crps.gtclogistic(
observation, location, scale, lower, upper, lmass, umass, backend=backend
observation,
location,
scale,
lower,
upper,
lmass,
umass,
backend=backend,
)


Expand Down Expand Up @@ -990,7 +1009,14 @@ def crps_gtcnormal(
>>> sr.crps_gtcnormal(0.0, 0.1, 0.4, -1.0, 1.0, 0.1, 0.1)
"""
return crps.gtcnormal(
observation, location, scale, lower, upper, lmass, umass, backend=backend
observation,
location,
scale,
lower,
upper,
lmass,
umass,
backend=backend,
)


Expand Down Expand Up @@ -1078,7 +1104,14 @@ def crps_cnormal(
lmass = stats._norm_cdf((lower - location) / scale)
umass = 1 - stats._norm_cdf((upper - location) / scale)
return crps.gtcnormal(
observation, location, scale, lower, upper, lmass, umass, backend=backend
observation,
location,
scale,
lower,
upper,
lmass,
umass,
backend=backend,
)


Expand Down Expand Up @@ -1146,7 +1179,15 @@ def crps_gtct(
>>> sr.crps_gtct(0.0, 2.0, 0.1, 0.4, -1.0, 1.0, 0.1, 0.1)
"""
return crps.gtct(
observation, df, location, scale, lower, upper, lmass, umass, backend=backend
observation,
df,
location,
scale,
lower,
upper,
lmass,
umass,
backend=backend,
)


Expand Down Expand Up @@ -1192,7 +1233,15 @@ def crps_tt(
>>> sr.crps_tt(0.0, 2.0, 0.1, 0.4, -1.0, 1.0)
"""
return crps.gtct(
observation, df, location, scale, lower, upper, 0.0, 0.0, backend=backend
observation,
df,
location,
scale,
lower,
upper,
0.0,
0.0,
backend=backend,
)


Expand Down Expand Up @@ -1240,7 +1289,15 @@ def crps_ct(
lmass = stats._t_cdf((lower - location) / scale, df)
umass = 1 - stats._t_cdf((upper - location) / scale, df)
return crps.gtct(
observation, df, location, scale, lower, upper, lmass, umass, backend=backend
observation,
df,
location,
scale,
lower,
upper,
lmass,
umass,
backend=backend,
)


Expand Down
Loading

0 comments on commit 6e39f5a

Please sign in to comment.