Skip to content

Commit

Permalink
NBK: convert callers to 3-arg check_local_stability()
Browse files Browse the repository at this point in the history
  • Loading branch information
Jacob-Stevens-Haas committed Aug 30, 2024
1 parent c8cd72d commit c1bcc28
Show file tree
Hide file tree
Showing 5 changed files with 19 additions and 19 deletions.
14 changes: 7 additions & 7 deletions examples/8_trapping_sindy_examples/example.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -240,7 +240,7 @@
"print(\"Frobenius Error = \", E_pred)\n",
"mean_val = np.mean(x_test_pred, axis=0)\n",
"mean_val = np.sqrt(np.sum(mean_val**2))\n",
"check_local_stability(r, Xi, sindy_opt, mean_val)\n",
"check_local_stability(Xi, sindy_opt, mean_val)\n",
"\n",
"# compute relative Frobenius error in the model coefficients\n",
"terms = sindy_library.get_feature_names()\n",
Expand Down Expand Up @@ -433,7 +433,7 @@
"print(\"Frobenius error = \", E_pred)\n",
"mean_val = np.mean(x_test_pred, axis=0)\n",
"mean_val = np.sqrt(np.sum(mean_val**2))\n",
"check_local_stability(r, Xi, sindy_opt, mean_val)\n",
"check_local_stability(Xi, sindy_opt, mean_val)\n",
"\n",
"# compute relative Frobenius error in the model coefficients\n",
"terms = sindy_library.get_feature_names()\n",
Expand Down Expand Up @@ -694,7 +694,7 @@
"\n",
"mean_val = np.mean(x_test_pred, axis=0)\n",
"mean_val = np.sqrt(np.sum(mean_val**2))\n",
"check_local_stability(r, Xi, sindy_opt, mean_val)\n",
"check_local_stability(Xi, sindy_opt, mean_val)\n",
"E_pred = np.linalg.norm(x_test - x_test_pred) / np.linalg.norm(x_test)\n",
"print(\"Frobenius error = \", E_pred)\n",
"\n",
Expand Down Expand Up @@ -924,7 +924,7 @@
"make_lissajou(r, x_train, x_test, x_train_pred, x_test_pred, \"mhd\")\n",
"mean_val = np.mean(x_test_pred, axis=0)\n",
"mean_val = np.sqrt(np.sum(mean_val**2))\n",
"check_local_stability(r, Xi, sindy_opt, mean_val)\n",
"check_local_stability(Xi, sindy_opt, mean_val)\n",
"E_pred = np.linalg.norm(x_test - x_test_pred) / np.linalg.norm(x_test)\n",
"print(E_pred)\n",
"\n",
Expand Down Expand Up @@ -1355,7 +1355,7 @@
"Q = np.tensordot(PQ_tensor, Xi, axes=([4, 3], [0, 1]))\n",
"Q_sum = np.max(np.abs((Q + np.transpose(Q, [1, 2, 0]) + np.transpose(Q, [2, 0, 1]))))\n",
"print(\"Max deviation from the constraints = \", Q_sum)\n",
"if check_local_stability(r, Xi, sindy_opt, 1):\n",
"if check_local_stability(Xi, sindy_opt, 1):\n",
" x_train_pred = model.simulate(x_train[0, :], t, integrator_kws=integrator_keywords)\n",
" x_test_pred = model.simulate(a0, t, integrator_kws=integrator_keywords)\n",
" make_progress_plots(r, sindy_opt)\n",
Expand All @@ -1365,7 +1365,7 @@
" make_lissajou(r, x_train, x_test, x_train_pred, x_test_pred, \"VonKarman\")\n",
" mean_val = np.mean(x_test_pred, axis=0)\n",
" mean_val = np.sqrt(np.sum(mean_val**2))\n",
" check_local_stability(r, Xi, sindy_opt, mean_val)\n",
" check_local_stability(Xi, sindy_opt, mean_val)\n",
" A_guess = sindy_opt.A_history_[-1]\n",
" m_guess = sindy_opt.m_history_[-1]\n",
" E_pred = np.linalg.norm(x_test - x_test_pred) / np.linalg.norm(x_test)\n",
Expand Down Expand Up @@ -1462,7 +1462,7 @@
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.11.9"
"version": "3.10.12"
}
},
"nbformat": 4,
Expand Down
6 changes: 3 additions & 3 deletions examples/8_trapping_sindy_examples/example_dysts.py
Original file line number Diff line number Diff line change
Expand Up @@ -255,7 +255,7 @@
plt.grid(True)
plt.legend()

check_local_stability(r, Xi, sindy_opt, 1.0)
check_local_stability(Xi, sindy_opt, 1.0)
Xi_true = (true_coefficients[i].T)[: Xi.shape[0], :]

# run simulated annealing on the true system to make sure the system is amenable to trapping theorem
Expand Down Expand Up @@ -356,7 +356,7 @@
x_test_pred = model.simulate(x_test[0, :], t, integrator_kws=integrator_keywords)

# Check stability and try simulated annealing with the IDENTIFIED model
check_local_stability(r, Xi, sindy_opt, 1.0)
check_local_stability(Xi, sindy_opt, 1.0)
PL_tensor = sindy_opt.PL_
PM_tensor = sindy_opt.PM_
L = np.tensordot(PL_tensor, Xi, axes=([3, 2], [0, 1]))
Expand Down Expand Up @@ -483,7 +483,7 @@ def rhs(t, x):

model.fit(x_train, t=t_train)
Xi = model.coefficients().T
check_local_stability(r, Xi, sindy_opt, 1.0)
check_local_stability(Xi, sindy_opt, 1.0)

# Fit a baseline model -- this is almost always an unstable model!
model_baseline = ps.SINDy(
Expand Down
12 changes: 6 additions & 6 deletions examples/8_trapping_sindy_examples/trapping_extended.py
Original file line number Diff line number Diff line change
Expand Up @@ -159,7 +159,7 @@
mean_val = np.sqrt(np.sum(mean_val**2))
E_pred = np.linalg.norm(x_test - x_test_pred) / np.linalg.norm(x_test)
print("Frobenius error = ", E_pred)
check_local_stability(r, Xi, sindy_opt, mean_val)
check_local_stability(Xi, sindy_opt, mean_val)

# compute relative Frobenius error in the model coefficients
sigma = 10
Expand Down Expand Up @@ -327,7 +327,7 @@
Qenergy = np.tensordot(PQ_tensor, Xi, axes=([4, 3], [0, 1]))
mean_val = np.mean(x_test_pred, axis=0)
mean_val = np.sqrt(np.sum(mean_val**2))
check_local_stability(r, Xi, sindy_opt, mean_val)
check_local_stability(Xi, sindy_opt, mean_val)
Q = np.tensordot(sindy_opt.PQ_, Xi, axes=([4, 3], [0, 1]))

# %% [markdown]
Expand Down Expand Up @@ -358,7 +358,7 @@
mean_val = np.sqrt(np.sum(mean_val**2))
E_pred = np.linalg.norm(x_test - x_test_pred) / np.linalg.norm(x_test)
print("Frobenius error = ", E_pred)
check_local_stability(r, Xi, sindy_opt, mean_val)
check_local_stability(Xi, sindy_opt, mean_val)

# compute relative Frobenius error in the model coefficients
coef_pred = np.linalg.norm(Xi_lorenz - Xi) / np.linalg.norm(Xi_lorenz)
Expand Down Expand Up @@ -412,7 +412,7 @@
Qenergy = np.tensordot(PQ_tensor, Xi, axes=([4, 3], [0, 1]))
mean_val = np.mean(x_test_pred, axis=0)
mean_val = np.sqrt(np.sum(mean_val**2))
check_local_stability(r, Xi, sindy_opt, mean_val)
check_local_stability(Xi, sindy_opt, mean_val)
Q = np.tensordot(sindy_opt.PQ_, Xi, axes=([4, 3], [0, 1]))
print(
"Maximum deviation from having zero totally symmetric part: ",
Expand Down Expand Up @@ -473,7 +473,7 @@
Qenergy = np.tensordot(PQ_tensor, Xi, axes=([4, 3], [0, 1]))
mean_val = np.mean(x_test_pred, axis=0)
mean_val = np.sqrt(np.sum(mean_val**2))
check_local_stability(r, Xi, sindy_opt, mean_val)
check_local_stability(Xi, sindy_opt, mean_val)
Q = np.tensordot(sindy_opt.PQ_, Xi, axes=([4, 3], [0, 1]))
print(
"Maximum deviation from having zero totally symmetric part: ",
Expand Down Expand Up @@ -512,7 +512,7 @@
mean_val = np.sqrt(np.sum(mean_val**2))
E_pred = np.linalg.norm(x_test - x_test_pred) / np.linalg.norm(x_test)
print("Frobenius error = ", E_pred)
check_local_stability(r, Xi, sindy_opt, mean_val)
check_local_stability(Xi, sindy_opt, mean_val)

# compute relative Frobenius error in the model coefficients
coef_pred = np.linalg.norm(Xi_lorenz - Xi) / np.linalg.norm(Xi_lorenz)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -610,7 +610,7 @@
"Qenstrophy = np.tensordot(PQ_tensor, Xi, axes=([4, 3], [0, 1]))\n",
"mean_val = np.mean(a, axis=0)\n",
"mean_val = np.sqrt(np.sum(mean_val**2))\n",
"check_local_stability(r, Xi, sindy_opt, mean_val)\n",
"check_local_stability(Xi, sindy_opt, mean_val)\n",
"enstrophy_model = model\n",
"# Q = np.tensordot(sindy_opt.PQ_, Xi, axes=([4, 3], [0, 1]))\n",
"# Q_sum = np.max(np.abs((Q + np.transpose(Q, [1, 2, 0]) + np.transpose(Q, [2, 0, 1]))))\n",
Expand Down Expand Up @@ -837,7 +837,7 @@
"Qenstrophy = np.tensordot(PQ_tensor, Xi, axes=([4, 3], [0, 1]))\n",
"mean_val = np.mean(a, axis=0)\n",
"mean_val = np.sqrt(np.sum(mean_val**2))\n",
"check_local_stability(r, Xi, sindy_opt, mean_val)\n",
"check_local_stability(Xi, sindy_opt, mean_val)\n",
"enstrophy_model = model\n",
"Q = np.tensordot(sindy_opt.PQ_, Xi, axes=([4, 3], [0, 1]))\n",
"Q = np.tensordot(mod_matrix, Q, axes=([1], [0]))\n",
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -604,7 +604,7 @@
mean_val = np.sqrt(np.sum(mean_val**2))
check_stability(r, Xi, sindy_opt, mean_val, mod_matrix)

check_local_stability(r, Xi, sindy_opt, mean_val, mod_matrix)
check_local_stability(Xi, sindy_opt, mean_val, mod_matrix)
enstrophy_model = model
Q = np.tensordot(sindy_opt.PQ_, Xi, axes=([4, 3], [0, 1]))
Q = np.tensordot(mod_matrix, Q, axes=([1], [0]))
Expand Down

0 comments on commit c1bcc28

Please sign in to comment.