Symmetric Metrics & Metrics improvements #44

Merged · 11 commits · Oct 14, 2024
12 changes: 12 additions & 0 deletions docs/api/utilities/metrics/spatial.md
@@ -24,6 +24,18 @@

---

::: exponax.metrics.sMAE

---

::: exponax.metrics.sMSE

---

::: exponax.metrics.sRMSE

---

::: exponax.metrics.spatial_norm

---
5 changes: 4 additions & 1 deletion docs/examples/on_metrics_simple.ipynb
@@ -20,10 +20,13 @@
"3. Rooted metrics (i.e., related to the RMSE)\n",
"\n",
"Then for each of the three, there is both the absolute version and a\n",
"relative/normalized version\n",
"relative/normalized version. The spatial-based metrics MAE, MSE, and RMSE\n",
"additionally come with a symmetric version.\n",
"\n",
"All metric computations work on single state arrays, i.e., arrays with a leading channel axis and one, two, or three subsequent spatial axes. **The arrays shall not have leading batch axes.** To work with batched arrays, use `jax.vmap` and then reduce, e.g., by `jnp.mean`. Alternatively, use the convenience wrapper [`exponax.metrics.mean_metric`][].\n",
"\n",
"All metrics **sum over the channel axis**.\n",
"\n",
" ⚠️ ⚠️ ⚠️ ⚠️ ⚠️ This notebook is a WIP; it will be completed in a future release of Exponax ⚠️ ⚠️ ⚠️ ⚠️ ⚠️"
]
},
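The batching convention described in the notebook can be sketched as follows (a minimal numpy stand-in; `mse` here is a hypothetical single-sample metric, and in practice `jax.vmap` plus `jnp.mean`, or `exponax.metrics.mean_metric`, would be used):

```python
import numpy as np

def mse(u_pred, u_ref):
    # hypothetical single-sample metric: mean squared error over the spatial
    # axes, then summed over the leading channel axis
    per_channel = np.mean((u_pred - u_ref) ** 2, axis=tuple(range(1, u_pred.ndim)))
    return per_channel.sum()

# batched arrays: loop over the leading batch axis and average,
# mirroring jax.vmap followed by jnp.mean
batch_pred = np.full((4, 1, 64), 2.0)  # (batch, channel, space)
batch_ref = np.zeros((4, 1, 64))
batched_mse = np.mean([mse(p, r) for p, r in zip(batch_pred, batch_ref)])
print(batched_mse)  # 4.0
```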
6 changes: 6 additions & 0 deletions exponax/metrics/__init__.py
@@ -17,8 +17,11 @@
nMAE,
nMSE,
nRMSE,
sMAE,
sMSE,
spatial_aggregator,
spatial_norm,
sRMSE,
)
from ._utils import mean_metric

@@ -31,6 +34,9 @@
"nMAE",
"nMSE",
"nRMSE",
"sMAE",
"sMSE",
"sRMSE",
"fourier_aggregator",
"fourier_norm",
"fourier_MAE",
2 changes: 1 addition & 1 deletion exponax/metrics/_fourier.py
@@ -36,7 +36,7 @@ def fourier_aggregator(
!!! info
The result of this function (under default settings) is (up to rounding
errors) identical to [`exponax.metrics.spatial_aggregator`][] for
`inner_exponent=1.0`. As such, it can be a consistent counterpart for
`inner_exponent=2.0`. As such, it can be a consistent counterpart for
metrics based on the `L²(Ω)` functional norm.

!!! tip
207 changes: 195 additions & 12 deletions exponax/metrics/_spatial.py
@@ -25,13 +25,16 @@ def spatial_aggregator(
and the right is not, there is the following relation between a continuous
function `u(x)` and its discretely sampled counterpart `uₕ`

‖ u(x) ‖ᵖ_Lᵖ(Ω) = (∫_Ω |u(x)|ᵖ dx)^(1/p) = ( (L/N)ᴰ ∑ᵢ|uᵢ|ᵖ )^(1/p)
‖ u(x) ‖_Lᵖ(Ω) = (∫_Ω |u(x)|ᵖ dx)^(1/p) ≈ ( (L/N)ᴰ ∑ᵢ|uᵢ|ᵖ )^(1/p)

where the summation `∑ᵢ` must be understood as a sum over all `Nᴰ` points
across all spatial dimensions. The `inner_exponent` corresponds to `p` in
the above formula. This function allows setting the outer exponent `q`
manually. If it is not specified, it is set to `1/q = 1/p` to get a valid
norm.
the above formula. This function also allows setting the outer exponent `q`,
which enters the aggregation as

( (L/N)ᴰ ∑ᵢ|uᵢ|ᵖ )^q

If it is not specified, it is set to `q = 1/p` to get a valid norm.
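A minimal numpy sketch of this discrete aggregation (the function name and defaults here are illustrative, not the exponax API):

```python
import numpy as np

def discrete_lp_norm(u, domain_extent=1.0, p=2.0, q=None):
    # ((L/N)^D * sum_i |u_i|^p)^q with q defaulting to 1/p (a valid norm);
    # illustrative helper, not the exponax API
    if q is None:
        q = 1.0 / p
    num_spatial_dims = u.ndim
    n = u.shape[-1]
    cell_volume = (domain_extent / n) ** num_spatial_dims
    return (cell_volume * np.sum(np.abs(u) ** p)) ** q

# constant function u(x) = 3 on Ω = (0, 1): ‖u‖_L²(Ω) = 3
u = np.full((64,), 3.0)
print(discrete_lp_norm(u))  # 3.0
```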

!!! tip
To apply this function to a state tensor with a leading channel axis,
@@ -40,7 +43,7 @@
**Arguments:**

- `state_no_channel`: The state tensor **without a leading channel
dimension**.
axis**.
- `num_spatial_dims`: The number of spatial dimensions. If not specified,
it is inferred from the number of axes in `state_no_channel`.
- `domain_extent`: The extent `L` of the domain `Ω = (0, L)ᴰ`.
@@ -84,7 +87,7 @@ def spatial_norm(
state: Float[Array, "C ... N"],
state_ref: Optional[Float[Array, "C ... N"]] = None,
*,
mode: Literal["absolute", "normalized"] = "absolute",
mode: Literal["absolute", "normalized", "symmetric"] = "absolute",
domain_extent: float = 1.0,
inner_exponent: float = 2.0,
outer_exponent: Optional[float] = None,
@@ -97,13 +100,18 @@
control, consider using [`exponax.metrics.spatial_aggregator`][] directly.

This function allows providing a second state (`state_ref`) to compute
either the absolute or normalized difference. The `"absolute"` mode computes
either the absolute, normalized, or symmetric difference. The `"absolute"`
mode computes

(‖|uₕ − uₕʳ|ᵖ‖_L²(Ω))^q
(‖uₕ - uₕʳ‖_L^p(Ω))^(q*p)

while the `"normalized"` mode computes

(‖|uₕ − uₕʳ|ᵖ‖_L²(Ω))^q / (‖|uₕʳ|ᵖ‖_L²(Ω))^q
(‖uₕ - uₕʳ‖_L^p(Ω))^(q*p) / ((‖uₕʳ‖_L^p(Ω))^(q*p))

and the `"symmetric"` mode computes

2 * (‖uₕ - uₕʳ‖_L^p(Ω))^(q*p) / ((‖uₕ‖_L^p(Ω))^(q*p) + (‖uₕʳ‖_L^p(Ω))^(q*p))
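The three modes can be compared side by side with a rough standalone sketch (single channel, `domain_extent = 1`, combined exponent `q*p = 1`; not the exponax implementation):

```python
import numpy as np

u = np.full((32,), 4.0)      # predicted state (single channel, 1D)
u_ref = np.full((32,), 2.0)  # reference state

# ((L/N)^D Σ|v|²)^q with L = 1 and q = 1 reduces to a plain mean of squares
agg = lambda v: np.mean(np.abs(v) ** 2)

absolute = agg(u - u_ref)
normalized = agg(u - u_ref) / agg(u_ref)
symmetric = 2 * agg(u - u_ref) / (agg(u) + agg(u_ref))
print(absolute, normalized, symmetric)  # 4.0 1.0 0.4
```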

In either way, the channels are summed **after** the aggregation. The
`inner_exponent` corresponds to `p` in the above formulas. The
@@ -124,7 +132,8 @@
- `state_ref`: The reference state tensor. Must have the same shape as
`state`. If not specified, only the absolute norm of `state` is
computed.
- `mode`: The mode of the norm. Either `"absolute"` or `"normalized"`.
- `mode`: The mode of the norm. Either `"absolute"`, `"normalized"`, or
`"symmetric"`.
- `domain_extent`: The extent `L` of the domain `Ω = (0, L)ᴰ`.
- `inner_exponent`: The exponent `p` in the L^p norm.
- `outer_exponent`: The exponent `q` the result after aggregation is raised
@@ -133,6 +142,8 @@
if state_ref is None:
if mode == "normalized":
raise ValueError("mode 'normalized' requires state_ref")
if mode == "symmetric":
raise ValueError("mode 'symmetric' requires state_ref")
diff = state
else:
diff = state - state_ref
@@ -157,6 +168,27 @@
)(state_ref)
normalized_diff_per_channel = diff_norm_per_channel / ref_norm_per_channel
norm_per_channel = normalized_diff_per_channel
elif mode == "symmetric":
state_norm_per_channel = jax.vmap(
lambda s: spatial_aggregator(
s,
domain_extent=domain_extent,
inner_exponent=inner_exponent,
outer_exponent=outer_exponent,
),
)(state)
ref_norm_per_channel = jax.vmap(
lambda r: spatial_aggregator(
r,
domain_extent=domain_extent,
inner_exponent=inner_exponent,
outer_exponent=outer_exponent,
),
)(state_ref)
symmetric_diff_per_channel = (
2 * diff_norm_per_channel / (state_norm_per_channel + ref_norm_per_channel)
)
norm_per_channel = symmetric_diff_per_channel
else:
norm_per_channel = diff_norm_per_channel

@@ -255,6 +287,55 @@ def nMAE(
)


def sMAE(
u_pred: Float[Array, "C ... N"],
u_ref: Float[Array, "C ... N"],
*,
domain_extent: float = 1.0,
) -> float:
"""
Compute the symmetric mean absolute error (sMAE) between two states.

∑_(channels) [2 ∑_(space) (L/N)ᴰ |uₕ - uₕʳ| / (∑_(space) (L/N)ᴰ |uₕ| + ∑_(space) (L/N)ᴰ |uₕʳ|)]

Given the correct `domain_extent`, this is consistent with the following
functional norm:

2 ∫_Ω |u(x) - uʳ(x)| dx / (∫_Ω |u(x)| dx + ∫_Ω |uʳ(x)| dx)

The channel axis is summed **after** the aggregation.

!!! tip
To apply this function to a state tensor with a leading batch axis, use
`jax.vmap`. Then the batch axis can be reduced, e.g., by `jnp.mean`. As
a helper for this, [`exponax.metrics.mean_metric`][] is provided.

!!! info
This symmetric metric is bounded between 0 and 2C, with C being the
number of channels.


**Arguments:**

- `u_pred`: The state array, must follow the `Exponax` convention with a
leading channel axis, and either one, two, or three subsequent spatial
axes.
- `u_ref`: The reference state array. Must have the same shape as `u_pred`.
- `domain_extent`: The extent `L` of the domain `Ω = (0, L)ᴰ`. Must be
provided to get the correct, consistent norm. If this metric is used as an
optimization objective, it can often be ignored since it only
contributes a multiplicative factor.
"""
return spatial_norm(
u_pred,
u_ref,
mode="symmetric",
domain_extent=domain_extent,
inner_exponent=1.0,
outer_exponent=1.0,
)
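A standalone numpy sketch of the sMAE formula for a single channel in 1D (illustrative only, not the exponax implementation):

```python
import numpy as np

def smae(u_pred, u_ref, domain_extent=1.0):
    # illustrative single-channel 1D version of the sMAE formula,
    # not the exponax implementation
    cell = domain_extent / u_pred.shape[-1]
    num = cell * np.sum(np.abs(u_pred - u_ref))
    den = cell * np.sum(np.abs(u_pred)) + cell * np.sum(np.abs(u_ref))
    return 2.0 * num / den

u0 = np.full((64,), 2.0)
u1 = np.full((64,), 4.0)
print(smae(u1, u0))   # 2*|4-2| / (|4| + |2|) = 2/3
print(smae(u1, -u1))  # 2.0, the per-channel supremum (reached for u_ref = -u_pred)
```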


def MSE(
u_pred: Float[Array, "C ... N"],
u_ref: Optional[Float[Array, "C ... N"]] = None,
@@ -347,6 +428,55 @@ def nMSE(
)


def sMSE(
u_pred: Float[Array, "C ... N"],
u_ref: Float[Array, "C ... N"],
*,
domain_extent: float = 1.0,
) -> float:
"""
Compute the symmetric mean squared error (sMSE) between two states.

∑_(channels) [2 ∑_(space) (L/N)ᴰ |uₕ - uₕʳ|² / (∑_(space) (L/N)ᴰ |uₕ|² + ∑_(space) (L/N)ᴰ |uₕʳ|²)]

Given the correct `domain_extent`, this is consistent with the following
functional norm:

2 ∫_Ω |u(x) - uʳ(x)|² dx / (∫_Ω |u(x)|² dx + ∫_Ω |uʳ(x)|² dx)

The channel axis is summed **after** the aggregation.

!!! tip
To apply this function to a state tensor with a leading batch axis, use
`jax.vmap`. Then the batch axis can be reduced, e.g., by `jnp.mean`. As
a helper for this, [`exponax.metrics.mean_metric`][] is provided.

!!! info
This symmetric metric is bounded between 0 and 4C, with C being the
number of channels.


**Arguments:**

- `u_pred`: The state array, must follow the `Exponax` convention with a
leading channel axis, and either one, two, or three subsequent spatial
axes.
- `u_ref`: The reference state array. Must have the same shape as `u_pred`.
- `domain_extent`: The extent `L` of the domain `Ω = (0, L)ᴰ`. Must be
provided to get the correct, consistent norm. If this metric is used as an
optimization objective, it can often be ignored since it only
contributes a multiplicative factor.
"""
return spatial_norm(
u_pred,
u_ref,
mode="symmetric",
domain_extent=domain_extent,
inner_exponent=2.0,
outer_exponent=1.0,
)
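A standalone numpy sketch of the sMSE formula for a single channel (illustrative only, not the exponax implementation), highlighting the symmetry under swapping prediction and reference:

```python
import numpy as np

def smse(u_pred, u_ref, domain_extent=1.0):
    # illustrative single-channel version of the sMSE formula,
    # not the exponax implementation
    cell = (domain_extent / u_pred.shape[-1]) ** u_pred.ndim
    agg = lambda v: cell * np.sum(np.abs(v) ** 2)
    return 2.0 * agg(u_pred - u_ref) / (agg(u_pred) + agg(u_ref))

u0 = np.full((64,), 2.0)
u1 = np.full((64,), 4.0)
# unlike the normalized nMSE, swapping prediction and reference changes nothing
print(smse(u1, u0) == smse(u0, u1))  # True
print(smse(u1, u0))  # 2*(4-2)² / (2² + 4²) = 0.4
```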


def RMSE(
u_pred: Float[Array, "C ... N"],
u_ref: Optional[Float[Array, "C ... N"]] = None,
@@ -361,7 +491,7 @@ def RMSE(
Given the correct `domain_extent`, this is consistent with the following
functional norm:

(‖ u - uʳ ‖_L²(Ω)) = (∫_Ω |u(x) - uʳ(x)|² dx)
‖ u − uʳ ‖_L²(Ω) = (∫_Ω |u(x) − uʳ(x)|² dx)^(1/2)

The channel axis is summed **after** the aggregation. Hence, it is also
summed **after** the square root. If you need the RMSE per channel, consider
@@ -411,7 +541,7 @@ def nRMSE(
Given the correct `domain_extent`, this is consistent with the following
functional norm:

(‖ u - uʳ ‖_L²(Ω) / ‖ uʳ ‖_L²(Ω)) = (∫_Ω |u(x) - uʳ(x)|² dx / ∫_Ω
‖ u − uʳ ‖_L²(Ω) / ‖ uʳ ‖_L²(Ω) = (∫_Ω |u(x) − uʳ(x)|² dx / ∫_Ω
|uʳ(x)|² dx)^(1/2)

The channel axis is summed **after** the aggregation. Hence, it is also
@@ -444,3 +574,56 @@ def nRMSE(
inner_exponent=2.0,
outer_exponent=0.5,
)


def sRMSE(
u_pred: Float[Array, "C ... N"],
u_ref: Float[Array, "C ... N"],
*,
domain_extent: float = 1.0,
) -> float:
"""
Compute the symmetric root mean squared error (sRMSE) between two states.

∑_(channels) [2 √(∑_(space) (L/N)ᴰ |uₕ - uₕʳ|²) / (√(∑_(space) (L/N)ᴰ
|uₕ|²) + √(∑_(space) (L/N)ᴰ |uₕʳ|²))]

Given the correct `domain_extent`, this is consistent with the following
functional norm:

2 √(∫_Ω |u(x) - uʳ(x)|² dx) / (√(∫_Ω |u(x)|² dx) + √(∫_Ω |uʳ(x)|² dx))

The channel axis is summed **after** the aggregation. Hence, it is also
summed **after** the square root and after normalization. If you need more
fine-grained control, consider using
[`exponax.metrics.spatial_aggregator`][] directly.

!!! tip
To apply this function to a state tensor with a leading batch axis, use
`jax.vmap`. Then the batch axis can be reduced, e.g., by `jnp.mean`. As
a helper for this, [`exponax.metrics.mean_metric`][] is provided.

!!! info
This symmetric metric is bounded between 0 and 2C, with C being the
number of channels.


**Arguments:**

- `u_pred`: The state array, must follow the `Exponax` convention with a
leading channel axis, and either one, two, or three subsequent spatial
axes.
- `u_ref`: The reference state array. Must have the same shape as `u_pred`.
- `domain_extent`: The extent `L` of the domain `Ω = (0, L)ᴰ`. Must be
provided to get the correct, consistent norm. If this metric is used as an
optimization objective, it can often be ignored since it only contributes
a multiplicative factor.
"""
return spatial_norm(
u_pred,
u_ref,
mode="symmetric",
domain_extent=domain_extent,
inner_exponent=2.0,
outer_exponent=0.5,
)
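A standalone numpy sketch of the sRMSE formula for a single channel (illustrative only, not the exponax implementation):

```python
import numpy as np

def srmse(u_pred, u_ref, domain_extent=1.0):
    # illustrative single-channel version of the sRMSE formula,
    # not the exponax implementation
    cell = (domain_extent / u_pred.shape[-1]) ** u_pred.ndim
    rms = lambda v: np.sqrt(cell * np.sum(np.abs(v) ** 2))
    return 2.0 * rms(u_pred - u_ref) / (rms(u_pred) + rms(u_ref))

u0 = np.full((64,), 2.0)
u1 = np.full((64,), 4.0)
print(srmse(u1, u0))  # 2*2 / (4 + 2) = 2/3
```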
3 changes: 2 additions & 1 deletion exponax/metrics/_utils.py
@@ -11,4 +11,5 @@ def mean_metric(
'meanifies' a metric function to operate on arrays with a leading batch axis
"""
wrapped_fn = lambda *a: metric_fn(*a, **kwargs)
return jnp.mean(jax.vmap(wrapped_fn)(*args))
metric_per_sample = jax.vmap(wrapped_fn, in_axes=0)(*args)
return jnp.mean(metric_per_sample, axis=0)
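The wrapper's behavior can be emulated without jax as follows (a hypothetical numpy sketch; the real implementation uses `jax.vmap` as shown above):

```python
import numpy as np

def mean_metric(metric_fn, *batched_args, **kwargs):
    # apply the metric per sample along the leading batch axis, then average:
    # the same semantics as jax.vmap followed by a mean over axis 0
    per_sample = [metric_fn(*args, **kwargs) for args in zip(*batched_args)]
    return np.mean(per_sample, axis=0)

mse = lambda a, b: np.mean((a - b) ** 2)  # hypothetical per-sample metric
batch_a = np.stack([np.full((8,), c) for c in (1.0, 2.0, 3.0)])
batch_b = np.zeros((3, 8))
print(mean_metric(mse, batch_a, batch_b))  # (1 + 4 + 9) / 3 ≈ 4.667
```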
12 changes: 12 additions & 0 deletions tests/test_metrics.py
@@ -35,6 +35,15 @@ def test_constant_offset(num_spatial_dims: int):
assert ex.metrics.nMSE(u_0, u_1) == pytest.approx((2.0 - 4.0) ** 2 / (4.0) ** 2)
assert ex.metrics.nMSE(u_0, u_1) == pytest.approx(1 / 4)

# == approx(0.4)
assert ex.metrics.sMSE(u_1, u_0) == pytest.approx(
2.0 * (4.0 - 2.0) ** 2 / ((2.0) ** 2 + (4.0) ** 2)
)
assert ex.metrics.sMSE(u_1, u_0) == pytest.approx(0.4)

# Symmetric metric must be symmetric
assert ex.metrics.sMSE(u_0, u_1) == ex.metrics.sMSE(u_1, u_0)

assert ex.metrics.RMSE(u_1, u_0, domain_extent=1.0) == pytest.approx(2.0)
assert ex.metrics.RMSE(u_1, u_0, domain_extent=DOMAIN_EXTENT) == pytest.approx(
jnp.sqrt(DOMAIN_EXTENT**num_spatial_dims * 4.0)
@@ -60,6 +69,9 @@ def test_constant_offset(num_spatial_dims: int):
)
assert ex.metrics.nRMSE(u_0, u_1) == pytest.approx(0.5)

# == approx(2/3)
assert ex.metrics.sRMSE(u_1, u_0) == pytest.approx(2 / 3)

# The Fourier nRMSE should be identical to the spatial nRMSE
# assert ex.metrics.fourier_nRMSE(u_1, u_0) == ex.metrics.nRMSE(u_1, u_0)
# assert ex.metrics.fourier_nRMSE(u_0, u_1) == ex.metrics.nRMSE(u_0, u_1)