Skip to content

Commit

Permalink
CLN: Extract the common parts of CVX cost function in Trapping SR3
Browse files Browse the repository at this point in the history
Also extracted relevant parts of test_trapping_inequality_constraints
  • Loading branch information
Jacob-Stevens-Haas committed Aug 15, 2023
1 parent a640b5c commit 62d4cb4
Show file tree
Hide file tree
Showing 2 changed files with 37 additions and 54 deletions.
23 changes: 10 additions & 13 deletions pysindy/optimizers/trapping_sr3.py
Original file line number Diff line number Diff line change
Expand Up @@ -502,9 +502,10 @@ def _objective(self, x, y, coef_sparse, A, PW, q):
)
return 0.5 * np.sum(R2) + 0.5 * np.sum(A2) / self.eta + L1

def _solve_sparse_relax_and_split(self, r, N, x_expanded, y, Pmatrix, A, coef_prev):
"""Solve coefficient update with CVXPY if threshold != 0"""
xi = cp.Variable(N * r)
def _create_var_and_part_cost(
self, var_len: int, x_expanded: np.ndarray, y: np.ndarray
) -> tuple[cp.Variable, cp.Expression]:
xi = cp.Variable(var_len)
cost = cp.sum_squares(x_expanded @ xi - y.flatten())
if self.thresholder.lower() == "l1":
cost = cost + self.threshold * cp.norm1(xi)
Expand All @@ -514,6 +515,11 @@ def _solve_sparse_relax_and_split(self, r, N, x_expanded, y, Pmatrix, A, coef_pr
cost = cost + self.threshold * cp.norm2(xi) ** 2
elif self.thresholder.lower() == "weighted_l2":
cost = cost + cp.norm2(np.ravel(self.thresholds) @ xi) ** 2
return xi, cost

def _solve_sparse_relax_and_split(self, r, N, x_expanded, y, Pmatrix, A, coef_prev):
"""Solve coefficient update with CVXPY if threshold != 0"""
xi, cost = self._create_var_and_part_cost(N * r, x_expanded, y)
cost = cost + cp.sum_squares(Pmatrix @ xi - A.flatten()) / self.eta
if self.use_constraints:
if self.inequality_constraints:
Expand Down Expand Up @@ -612,16 +618,7 @@ def _solve_direct_cvxpy(self, r, N, x_expanded, y, Pmatrix, coef_prev):
problem, solved in CVXPY, so convergence/quality guarantees are
not available here!
"""
xi = cp.Variable(N * r)
cost = cp.sum_squares(x_expanded @ xi - y.flatten())
if self.thresholder.lower() == "l1":
cost = cost + self.threshold * cp.norm1(xi)
elif self.thresholder.lower() == "weighted_l1":
cost = cost + cp.norm1(np.ravel(self.thresholds) @ xi)
elif self.thresholder.lower() == "l2":
cost = cost + self.threshold * cp.norm2(xi) ** 2
elif self.thresholder.lower() == "weighted_l2":
cost = cost + cp.norm2(np.ravel(self.thresholds) @ xi) ** 2
xi, cost = self._create_var_and_part_cost(N * r, x_expanded, y)
cost = cost + cp.lambda_max(cp.reshape(Pmatrix @ xi, (r, r))) / self.eta
if self.use_constraints:
if self.inequality_constraints:
Expand Down
68 changes: 27 additions & 41 deletions test/test_optimizers.py
Original file line number Diff line number Diff line change
Expand Up @@ -893,19 +893,31 @@ def test_constrained_inequality_constraints(data_lorenz, params):
@pytest.mark.parametrize(
"params",
[
dict(thresholder="l1", threshold=0.0005),
dict(thresholder="weighted_l1", thresholds=0.0005 * np.ones((3, 9))),
dict(thresholder="l2", threshold=0.0005),
dict(thresholder="weighted_l2", thresholds=0.0005 * np.ones((3, 9))),
dict(thresholder="l1", threshold=2, expected=2.5),
dict(thresholder="weighted_l1", thresholds=0.5 * np.ones((1, 2)), expected=1.0),
dict(thresholder="l2", threshold=2, expected=1.5),
dict(
thresholder="weighted_l2", thresholds=0.5 * np.ones((1, 2)), expected=0.75
),
],
)
def test_trapping_inequality_constraints(data_lorenz, params):
x, t = data_lorenz
constraint_rhs = np.array([-10.0, 28.0])
constraint_matrix = np.zeros((2, 27))
constraint_matrix[0, 0] = 1.0
constraint_matrix[1, 10] = 1.0
feature_names = ["x", "y", "z"]
def test_trapping_cost_function(params):
expected = params.pop("expected")
opt = TrappingSR3(inequality_constraints=True, relax_optim=True, **params)
x = np.eye(2)
y = np.ones(2)
xi, cost = opt._create_var_and_part_cost(2, x, y)
xi.value = np.array([0.5, 0.5])
np.testing.assert_allclose(cost.value, expected)


def test_trapping_inequality_constraints():
t = np.arange(0, 1, 0.1)
x = np.stack((t, t**2)).T
y = x[:, 0] + 0.1 * x[:, 1]
constraint_rhs = np.array([0.1])
constraint_matrix = np.zeros((1, 2))
constraint_matrix[0, 1] = 0.1

# Run Trapping SR3 without CVXPY for the m solve
opt = TrappingSR3(
Expand All @@ -914,45 +926,19 @@ def test_trapping_inequality_constraints(data_lorenz, params):
constraint_order="feature",
inequality_constraints=True,
relax_optim=True,
**params,
)
poly_lib = PolynomialLibrary(degree=2, include_bias=False)
model = SINDy(
optimizer=opt,
feature_library=poly_lib,
feature_names=feature_names,
)
model.fit(x, t=t[1] - t[0])
# This sometimes fails with L2 norm so just check the model is fitted
check_is_fitted(model)

opt.fit(x, y)
assert np.all(np.dot(constraint_matrix, (opt.coef_).flatten()) <= constraint_rhs)
# Run Trapping SR3 with CVXPY for the m solve
opt = TrappingSR3(
constraint_lhs=constraint_matrix,
constraint_rhs=constraint_rhs,
constraint_order="feature",
inequality_constraints=True,
relax_optim=False,
**params,
)
model = SINDy(
optimizer=opt,
feature_library=poly_lib,
differentiation_method=FiniteDifference(drop_endpoints=True),
feature_names=feature_names,
)
model.fit(x, t=t[1] - t[0])

# This sometimes fails with L2 norm or different versions of CVXPY
# so just check the model is fitted
check_is_fitted(model)
# assert np.all(
# np.dot(constraint_matrix, (model.coefficients()).flatten()) <= constraint_rhs
# ) or np.allclose(
# np.dot(constraint_matrix, (model.coefficients()).flatten()),
# constraint_rhs,
# atol=1e-3,
# )
opt.fit(x, y)
assert np.all(np.dot(constraint_matrix, (opt.coef_).flatten()) <= constraint_rhs)


@pytest.mark.parametrize(
Expand Down

0 comments on commit 62d4cb4

Please sign in to comment.