Skip to content

Commit

Permalink
TST: Limit test_trapping_sr3_quadratic_library to small data
Browse files Browse the repository at this point in the history
Also test just the optimizer, not the whole SINDy pipeline
Cuts this test duration, previously the longest, by 95%
  • Loading branch information
Jacob-Stevens-Haas committed Aug 9, 2023
1 parent fafe3e0 commit c1bbdd7
Show file tree
Hide file tree
Showing 2 changed files with 22 additions and 23 deletions.
43 changes: 21 additions & 22 deletions test/test_optimizers.py
Original file line number Diff line number Diff line change
Expand Up @@ -478,50 +478,49 @@ def test_stable_linear_sr3_linear_library(params):
dict(thresholder="l1", threshold=1e-5),
dict(
thresholder="weighted_l1",
thresholds=np.zeros((3, 9)),
thresholds=np.zeros((1, 2)),
eta=1e5,
alpha_m=1e4,
alpha_A=1e5,
),
dict(thresholder="weighted_l1", thresholds=1e-5 * np.ones((3, 9))),
dict(thresholder="weighted_l1", thresholds=1e-5 * np.ones((1, 2))),
dict(thresholder="l2", threshold=0),
dict(thresholder="l2", threshold=1e-5),
dict(thresholder="weighted_l2", thresholds=np.zeros((3, 9))),
dict(thresholder="weighted_l2", thresholds=1e-5 * np.ones((3, 9))),
dict(thresholder="weighted_l2", thresholds=np.zeros((1, 2))),
dict(thresholder="weighted_l2", thresholds=1e-5 * np.ones((1, 2))),
],
)
def test_trapping_sr3_quadratic_library(params, trapping_sr3_params, quadratic_library):
np.random.seed(100)
x = np.random.standard_normal((100, 3))
t = np.arange(0, 1, 0.1)
x = np.exp(-t).reshape((-1, 1))
x_dot = -x
features = np.hstack([x, x**2])

params.update(trapping_sr3_params)

opt = TrappingSR3(**params)
model = SINDy(optimizer=opt, feature_library=quadratic_library)
model.fit(x)
assert opt.PL_unsym_.shape == (3, 3, 3, 9)
assert opt.PL_.shape == (3, 3, 3, 9)
assert opt.PQ_.shape == (3, 3, 3, 3, 9)
check_is_fitted(model)
opt.fit(features, x_dot)
assert opt.PL_unsym_.shape == (1, 1, 1, 2)
assert opt.PL_.shape == (1, 1, 1, 2)
assert opt.PQ_.shape == (1, 1, 1, 1, 2)
check_is_fitted(opt)

# Rerun with identity constraints
r = 3
N = 9
r = x.shape[1]
N = 2
p = r + r * (r - 1) + int(r * (r - 1) * (r - 2) / 6.0)
params["constraint_rhs"] = np.zeros(p)
params["constraint_lhs"] = np.eye(p, r * N)

opt = TrappingSR3(**params)
model = SINDy(optimizer=opt, feature_library=quadratic_library)
model.fit(x)
assert opt.PL_unsym_.shape == (3, 3, 3, 9)
assert opt.PL_.shape == (3, 3, 3, 9)
assert opt.PQ_.shape == (3, 3, 3, 3, 9)
check_is_fitted(model)
opt.fit(features, x_dot)
assert opt.PL_unsym_.shape == (1, 1, 1, 2)
assert opt.PL_.shape == (1, 1, 1, 2)
assert opt.PQ_.shape == (1, 1, 1, 1, 2)
check_is_fitted(opt)
# check is solve was infeasible first
if not np.allclose(opt.m_history_[-1], opt.m_history_[0]):
zero_inds = [0, 1, 3, 6, 9, 12, 15, 18, 21, 24]
assert np.allclose((model.coefficients().flatten())[zero_inds], 0.0, atol=1e-5)
assert np.allclose((opt.coef_.flatten())[0], 0.0, atol=1e-5)


def test_trapping_cubic_library():
Expand Down
2 changes: 1 addition & 1 deletion test/test_optimizers_complexity.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,7 @@ def test_complexity_parameter(

optimizers = [
WrappedOptimizer(opt_cls(**{reg_name: reg_value}), unbias=True)
for reg_value in [3, 1, 0.3, 0.1, 0.01]
for reg_value in [10, 1, 0.1, 0.01]
]

for opt in optimizers:
Expand Down

0 comments on commit c1bbdd7

Please sign in to comment.