Skip to content

Commit

Permalink
mesmer_x: add tests for Expression
Browse files Browse the repository at this point in the history
  • Loading branch information
mathause committed Sep 18, 2024
1 parent 6c800e5 commit 3a6f75a
Showing 1 changed file with 122 additions and 0 deletions.
122 changes: 122 additions & 0 deletions tests/unit/test_mesmer_x_expression.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,122 @@
import pytest
import scipy as sp

from mesmer.mesmer_x import Expression

inf = float("inf")


def test_expression_genextreme():

expression_str = (
"genextreme(loc=c1 + c2 * __pred1__, scale=c3 + c4 * __pred2__**2, c=c5)"
)

expr = Expression(expression_str, expr_name="name")

assert expr.expression == expression_str
assert expr.expression_name == "name"

assert expr.distrib == sp.stats.genextreme
assert not expr.is_distrib_discrete

assert expr.parameters_list == ["c", "loc", "scale"]

bounds = {"c": [-inf, inf], "loc": [-inf, inf], "scale": [0, inf]}
assert expr.boundaries_parameters == bounds

param_expr = {"loc": "c1+c2*pred1", "scale": "c3+c4*pred2**2", "c": "c5"}
assert expr.parameters_expressions == param_expr

coeffs = ["c1", "c2", "c3", "c4", "c5"]
assert expr.coefficients_list == coeffs

coeffs_per_param = {"loc": ["c1", "c2"], "scale": ["c3", "c4"], "c": ["c5"]}
assert expr.coefficients_dict == coeffs_per_param


def test_expression_norm():

expression_str = "norm(loc=c1 + (c2 - c1) / ( 1 + np.exp(c3 * __GMT_t__ + c4 * __GMT_tm1__ - c5) ), scale=c6)"

expr = Expression(expression_str, expr_name="name")

assert expr.expression == expression_str
assert expr.expression_name == "name"

assert expr.distrib == sp.stats.norm
assert not expr.is_distrib_discrete

assert expr.parameters_list == ["loc", "scale"]

bounds = {"loc": [-inf, inf], "scale": [0, inf]}
assert expr.boundaries_parameters == bounds

param_expr = {"loc": "c1+(c2-c1)/(1+np.exp(c3*GMT_t+c4*GMT_tm1-c5))", "scale": "c6"}
assert expr.parameters_expressions == param_expr

coeffs = ["c1", "c2", "c3", "c4", "c5", "c6"]
assert expr.coefficients_list == coeffs

coeffs_per_param = {"loc": ["c1", "c2", "c3", "c4", "c5"], "scale": ["c6"]}
assert expr.coefficients_dict == coeffs_per_param


def test_expression_binom():
# a discrete distribution

expression_str = "binom(loc=c1, n=5, p=7)"

expr = Expression(expression_str, expr_name="name")

assert expr.expression == expression_str
assert expr.expression_name == "name"

assert expr.distrib == sp.stats.binom
assert not expr.is_distrib_discrete

assert expr.parameters_list == ["n", "p", "loc"]

bounds = {"n": [-inf, inf], "loc": [-inf, inf], "p": [-inf, inf]}
assert expr.boundaries_parameters == bounds

param_expr = {"loc": "c1", "n": "5", "p": "7"}
assert expr.parameters_expressions == param_expr

coeffs = ["c1"]
assert expr.coefficients_list == coeffs

coeffs_per_param = {"loc": ["c1"], "n": [], "p": []}
assert expr.coefficients_dict == coeffs_per_param


@pytest.mark.xfail(reason="https://github.com/MESMER-group/mesmer/issues/525")
def test_expression_exponpow():

expression_str = "exponpow(loc=c1, scale=c2+np.min([np.max(np.mean([__GMT_tm1__,__GMT_tp1__],axis=0)), math.gamma(__XYZ__)]), b=c3)"

expr = Expression(expression_str, expr_name="name")

assert expr.expression == expression_str
assert expr.expression_name == "name"

assert expr.distrib == sp.stats.exponpow
assert not expr.is_distrib_discrete

assert expr.parameters_list == ["loc", "scale"]

bounds = {"b": [inf, inf], "loc": [-inf, inf], "scale": [0, inf]}
assert expr.boundaries_parameters == bounds

param_expr = {
"loc": "c1",
"scale": "c2+np.min([np.max(np.mean([GMT_tm1,GMT_tp1],axis=0))]",
"b": "c3",
}
assert expr.parameters_expressions == param_expr

coeffs = ["c1", "c2", "c3", "c4", "c5", "c6"]
assert expr.coefficients_list == coeffs

coeffs_per_param = {"loc": ["c1"], "scale": ["c2"], "b": ["c3"]}
assert expr.coefficients_dict == coeffs_per_param

0 comments on commit 3a6f75a

Please sign in to comment.