From d0399138830e2293350333b74946a38f3e3f1acc Mon Sep 17 00:00:00 2001
From: Jake <37048747+Jacob-Stevens-Haas@users.noreply.github.com>
Date: Thu, 29 Aug 2024 18:15:05 -0700
Subject: [PATCH] WIP: consolidate common trapping operations

This also involves creating a simple EnstrophyMat class.  Many
operations carry around the precomputed root and inverse root of the
enstrophy matrix, and we shouldn't need a whole trapping optimizer just
to test those operations.
---
 .../trapping_utils.py              |  95 +++----
 pysindy/optimizers/trapping_sr3.py | 231 ++++++++++--------
 2 files changed, 155 insertions(+), 171 deletions(-)

diff --git a/examples/8_trapping_sindy_examples/trapping_utils.py b/examples/8_trapping_sindy_examples/trapping_utils.py
index 2805da8e..419bf634 100644
--- a/examples/8_trapping_sindy_examples/trapping_utils.py
+++ b/examples/8_trapping_sindy_examples/trapping_utils.py
@@ -14,6 +14,12 @@
 from dysts.analysis import sample_initial_conditions
 import pymech.neksuite as nek
 
+from pysindy.optimizers.trapping_sr3 import EnstrophyMat
+from pysindy.optimizers.trapping_sr3 import _convert_quad_terms_to_ens_basis
+from pysindy.optimizers.trapping_sr3 import _create_A_symm
+from pysindy.optimizers.trapping_sr3 import _permutation_asymmetry
+
+
 # Initialize quadratic SINDy library, with custom ordering
 # to be consistent with the constraint
 sindy_library = ps.PolynomialLibrary(include_bias=True)
@@ -32,17 +38,12 @@
 # define the objective function to be minimized by simulated annealing
 def obj_function(m, L_obj, Q_obj, P_obj):
-    lsv, sing_vals, rsv = np.linalg.svd(P_obj)
-    P_rt = lsv @ np.diag(np.sqrt(sing_vals)) @ rsv
-    P_rt_inv = lsv @ np.diag(np.sqrt(1 / sing_vals)) @ rsv
-    mQ_full = np.tensordot(Q_obj, m, axes=([2], [0])) + np.tensordot(
-        Q_obj, m, axes=([1], [0])
-    )
-    A_obj = L_obj + mQ_full
-    As = (P_rt @ A_obj @ P_rt_inv + P_rt_inv @ A_obj.T @ P_rt) / 2
+    ens = EnstrophyMat(P_obj)
+    As = _create_A_symm(L_obj, Q_obj + np.transpose(Q_obj, [0, 2, 1]), m, ens)
     eigvals, eigvecs = np.linalg.eigh(As)
     return eigvals[-1]
 
+
 def get_trapping_radius(max_eigval, eps_Q, d):
     x = Symbol("x")
     delta = max_eigval**2 - 4 * eps_Q * np.linalg.norm(d, 2) / 3
@@ -57,48 +58,29 @@ def get_trapping_radius(max_eigval, eps_Q, d):
     return rad_trap, rad_stab
 
 
-def check_local_stability(Xi, sindy_opt, mean_val):
-    mod_matrix = sindy_opt.mod_matrix
-    rt_mod_mat = sindy_opt.rt_mod_mat
-    rt_inv_mod_mat = sindy_opt.rt_inv_mod_mat
+def check_local_stability(Xi, sindy_opt: ps.TrappingSR3, mean_val):
+    mod_matrix = sindy_opt.enstrophy.P
+    rt_mod_mat = sindy_opt.enstrophy.P_root
     opt_m = sindy_opt.m_history_[-1]
     PC_tensor = sindy_opt.PC_
     PL_tensor_unsym = sindy_opt.PL_unsym_
-    PL_tensor = sindy_opt.PL_
     PM_tensor = sindy_opt.PM_
     PQ_tensor = sindy_opt.PQ_
-    mPM = np.tensordot(PM_tensor, opt_m, axes=([2], [0]))
-    P_tensor = PL_tensor_unsym + mPM
-    As = np.tensordot(P_tensor, Xi, axes=([3, 2], [0, 1]))
-    As = (rt_mod_mat @ As @ rt_inv_mod_mat + rt_inv_mod_mat @ As.T @ rt_mod_mat) / 2
+    p_As = _create_A_symm(PL_tensor_unsym, PM_tensor, opt_m, sindy_opt.enstrophy)
+    As = np.tensordot(p_As, Xi, axes=([3, 2], [0, 1]))
     eigvals, _ = np.linalg.eigh(As)
     print("optimal m: ", opt_m)
    print("As eigvals: ", np.sort(eigvals))
     max_eigval = np.sort(eigvals)[-1]
     C = np.tensordot(PC_tensor, Xi, axes=([2, 1], [0, 1]))
     L = np.tensordot(PL_tensor_unsym, Xi, axes=([3, 2], [0, 1]))
-    Q = np.tensordot(
-        mod_matrix, np.tensordot(PQ_tensor, Xi, axes=([4, 3], [0, 1])), axes=([1], [0])
-    )
-    Q = (Q + np.transpose(Q, [1, 2, 0]) + np.transpose(Q, [2, 0, 1]))
-    Q = np.tensordot(
-        rt_inv_mod_mat,
-        np.tensordot(
-            rt_inv_mod_mat,
-            np.tensordot(
-                rt_inv_mod_mat,
-                Q,
-                axes=([1], [0])
-            ),
-            axes=([0], [1])
-        ),
-        axes=([0], [2])
-    )
-    # Q = np.einsum("ya,abc,bd,cf", rt_inv_mod_mat, Q, rt_inv_mod_mat, rt_inv_mod_mat)
-    eps_Q = np.sqrt(np.sum(Q ** 2))
-    print(r'0.5 * |tilde{H}_0|_F = ', 0.5 * eps_Q)
-    print(r'0.5 * |tilde{H}_0|_F^2 / beta = ', 0.5 * eps_Q ** 2 / sindy_opt.beta)
     Q = np.tensordot(PQ_tensor, Xi, axes=([4, 3], [0, 1]))
+    PQ = np.tensordot(mod_matrix, Q, axes=([1], [0]))
+    H0 = _permutation_asymmetry(PQ) * 3
+    H0tilde = _convert_quad_terms_to_ens_basis(H0, sindy_opt.enstrophy)
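+    # eps_Q is the Frobenius norm of tilde{H}_0, the cubic energy-growth
+    # terms expressed in the enstrophy basis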
+    eps_Q = np.sqrt(np.sum(H0tilde**2))
+    print(r"0.5 * |tilde{H}_0|_F = ", 0.5 * eps_Q)
+    print(r"0.5 * |tilde{H}_0|_F^2 / beta = ", 0.5 * eps_Q**2 / sindy_opt.beta)
     d = C + np.dot(L, opt_m) + np.dot(np.tensordot(Q, opt_m, axes=([2], [0])), opt_m)
     d = rt_mod_mat @ d
     Rm, R_ls = get_trapping_radius(max_eigval, eps_Q, d)
@@ -112,8 +94,8 @@ def check_local_stability(Xi, sindy_opt, mean_val):
 
 # use optimal m, calculate and plot the stability radius when the third-order
 # energy-preserving scheme slightly breaks
-def make_trap_progress_plots(r, sindy_opt):
-    mod_matrix = sindy_opt.mod_matrix
+def make_trap_progress_plots(r, sindy_opt: ps.TrappingSR3):
+    mod_matrix = sindy_opt.enstrophy.P
     PC_tensor = sindy_opt.PC_
     PL_tensor_unsym = sindy_opt.PL_unsym_
     PQ_tensor = sindy_opt.PQ_
@@ -132,42 +114,23 @@ def make_trap_progress_plots(r, sindy_opt):
             np.tensordot(PQ_tensor, Xi, axes=([4, 3], [1, 0])),
             axes=([1], [0]),
         )
-        Q_ep = (Q + np.transpose(Q, [1, 2, 0]) + np.transpose(Q, [2, 0, 1]))
-        Qijk_permsum = np.tensordot(
-            sindy_opt.rt_inv_mod_mat,
-            np.tensordot(
-                sindy_opt.rt_inv_mod_mat,
-                np.tensordot(
-                    sindy_opt.rt_inv_mod_mat,
-                    Q_ep,
-                    axes=([1], [0])
-                ),
-                axes=([0], [1])
-            ),
-            axes=([0], [2])
-        )
-        eps_Q = np.sqrt(np.sum(Qijk_permsum ** 2))
+        Q_ep = _permutation_asymmetry(Q) * 3
+        Qijk_permsum = _convert_quad_terms_to_ens_basis(Q_ep, sindy_opt.enstrophy)
+        eps_Q = np.sqrt(np.sum(Qijk_permsum**2))
         Q = np.tensordot(PQ_tensor, Xi, axes=([4, 3], [1, 0]))
         d = (
             C
             + np.dot(L, ms[i])
             + np.dot(np.tensordot(Q, ms[i], axes=([2], [0])), ms[i])
         )
-        d = sindy_opt.rt_mod_mat @ d
-        delta = (
-            eigs[i][-1] ** 2
-            - 4 * eps_Q * np.linalg.norm(d, 2) / 3
-        )
+        d = sindy_opt.enstrophy.P_root @ d
+        delta = eigs[i][-1] ** 2 - 4 * eps_Q * np.linalg.norm(d, 2) / 3
         if delta < 0:
             Rm = 0
             DA = 0
         else:
-            Rm = -(3 / (2 * eps_Q)) * (
-                eigs[i][-1] + np.sqrt(delta)
-            )
-            DA = (3 / (2 * eps_Q)) * (
-                -eigs[i][-1] + np.sqrt(delta)
-            )
+            Rm = -(3 / (2 * eps_Q)) * (eigs[i][-1] + np.sqrt(delta))
+            DA = (3 / (2 * eps_Q)) * (-eigs[i][-1] + np.sqrt(delta))
         rhos_plus.append(DA)
         rhos_minus.append(Rm)
         try:
diff --git a/pysindy/optimizers/trapping_sr3.py b/pysindy/optimizers/trapping_sr3.py
index 119a65d3..198ad17b 100644
--- a/pysindy/optimizers/trapping_sr3.py
+++ b/pysindy/optimizers/trapping_sr3.py
@@ -1,4 +1,5 @@
 import warnings
+from functools import partial
 from itertools import combinations as combo_nr
 from itertools import product
 from itertools import repeat
@@ -6,6 +7,7 @@
 from typing import cast
 from typing import NewType
 from typing import Optional
+from typing import TypeVar
 from typing import Union
 
 import cvxpy as cp
@@ -31,6 +33,24 @@
 NTarget = NewType("NTarget", int)
 
 
+class EnstrophyMat:
+    """Pre-compute some useful factors of an enstrophy matrix
+
+    The matrix, its root, and its root inverse are frequently used in
+    transformations between the original and enstrophy bases.
+    """
+
+    P: Float2D
+    P_root: Float2D
+    P_root_inv: Float2D
+
+    def __init__(self, P):
+        self.P = P
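+        # P is positive definite, so U == V in its SVD, and U sqrt(S) V^T
+        # is the principal matrix square root (cf. scipy.linalg.sqrtm)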
+        lsv, sing_vals, rsv = np.linalg.svd(P)
+        self.P_root = lsv @ np.diag(np.sqrt(sing_vals)) @ rsv
+        self.P_root_inv = lsv @ np.diag(np.sqrt(1 / sing_vals)) @ rsv
+
+
 class TrappingSR3(ConstrainedSR3):
     """
     Generalized trapping variant of sparse relaxed regularized regression.
@@ -218,14 +238,9 @@ def __init__(
             )
             self._n_tgts = 1
         if self.mod_matrix is None:
-            self.mod_matrix = np.eye(self._n_tgts)
-
-        # get U, S, V -- note that mod_matrix is positive definite so U = V
-        lsv, sing_vals, rsv = np.linalg.svd(self.mod_matrix)
-        # scipy.linalg.sqrtm
-        self.rt_mod_mat = lsv @ np.diag(np.sqrt(sing_vals)) @ rsv  # get the square root
-        self.rt_inv_mod_mat = lsv @ np.diag(np.sqrt(1 / sing_vals)) @ rsv  # get the inverse of the square root
+            mod_matrix = np.eye(self._n_tgts)
+        self.enstrophy = EnstrophyMat(mod_matrix)
         if method == "global":
             if hasattr(kwargs, "constraint_separation_index"):
                 constraint_separation_index = kwargs["constraint_separation_index"]
@@ -236,7 +251,7 @@
             constraint_rhs, constraint_lhs = _make_constraints(
                 self._n_tgts, include_bias=_include_bias
             )
-            constraint_lhs = np.tensordot(constraint_lhs, self.mod_matrix, axes=1)
+            constraint_lhs = np.tensordot(constraint_lhs, self.enstrophy.P, axes=1)
             constraint_order = kwargs.pop("constraint_order", "feature")
             if constraint_order == "target":
                 constraint_lhs = np.transpose(constraint_lhs, [0, 2, 1])
@@ -482,31 +497,15 @@ def _objective(self, x, y, coef_sparse, A, PW, k):
         # Compute the errors
         sindy_loss = (y - np.dot(x, coef_sparse)) ** 2
         relax_loss = (A - PW) ** 2
-        Qijk = np.einsum("ya,abcde,ed", self.mod_matrix, self.PQ_, coef_sparse)
-        # This is H0 in the paper
-        Qijk_permsum = (
-            Qijk + np.transpose(Qijk, [1, 2, 0]) + np.transpose(Qijk, [2, 0, 1])
-        )
-        # This is H0tilde in the paper -- the thing we actually need to minimize
-        Qijk_permsum = np.tensordot(
-            self.rt_inv_mod_mat,
-            np.tensordot(
-                self.rt_inv_mod_mat,
-                np.tensordot(
-                    self.rt_inv_mod_mat,
-                    Qijk_permsum,
-                    axes=([1], [0])
-                ),
-                axes=([0], [1])
-            ),
-            axes=([0], [2])
-        )
-        # Qijk_permsum = np.einsum("ya,abc,bd,cf", self.rt_inv_mod_mat, Qijk_permsum, self.rt_inv_mod_mat, self.rt_inv_mod_mat)
+        Qijk = np.einsum("ya,abcde,ed", self.enstrophy.P, self.PQ_, coef_sparse)
+        # Qijk_permsum is H0 in the paper
+        Qijk_permsum = _permutation_asymmetry(Qijk) * 3
+        H0tilde = _convert_quad_terms_to_ens_basis(Qijk_permsum, self.enstrophy)
         L1 = self.threshold * np.sum(np.abs(coef_sparse.flatten()))
         sindy_loss = 0.5 * np.sum(sindy_loss)
         relax_loss = 0.5 * np.sum(relax_loss) / self.eta
         nonlin_ens_loss = 0.5 * np.sum(Qijk**2) / self.alpha
-        cubic_ens_loss = 0.5 * np.sum(Qijk_permsum ** 2) / self.beta
+        cubic_ens_loss = 0.5 * np.sum(H0tilde**2) / self.beta
         obj = sindy_loss + relax_loss + L1
 
         if self.method == "local":
@@ -519,41 +518,19 @@ def _objective(self, x, y, coef_sparse, A, PW, k):
             )
         return obj
 
-    def _update_coef_sparse_rs(
-        self, n_tgts, n_features, var_len, x_expanded, y, Pmatrix, A, coef_prev
-    ):
+    def _update_coef_sparse_rs(self, var_len, x_expanded, y, Pmatrix, A, coef_prev):
         """Solve coefficient update with CVXPY if threshold != 0"""
         xi, cost = self._create_var_and_part_cost(var_len, x_expanded, y)
         cost = cost + cp.sum_squares(Pmatrix @ xi - A.flatten()) / self.eta
-        # new terms minimizing quadratic piece ||P^Q @ xi||_2^2
         if self.method == "local":
-            Q = np.reshape(
-                self.PQ_, (n_tgts * n_tgts * n_tgts, n_features * n_tgts), "F"
-            )
-            cost = cost + 0.5 * cp.sum_squares(Q @ xi) / self.alpha
-            Q = np.reshape(self.PQ_, (n_tgts, n_tgts, n_tgts, n_features * n_tgts), "F")
-            Q = np.tensordot(self.mod_matrix, Q, axes=([1], [0]))
-            Q_ep = Q + np.transpose(Q, [1, 2, 0, 3]) + np.transpose(Q, [2, 0, 1, 3])
-            # This is H0tilde in the paper -- the thing we actually need to minimize
-            Qijk_permsum = np.tensordot(
-                self.rt_inv_mod_mat,
-                np.tensordot(
-                    self.rt_inv_mod_mat,
-                    np.tensordot(
-                        self.rt_inv_mod_mat,
-                        Q_ep,
-                        axes=([1], [0])
-                    ),
-                    axes=([0], [1])
-                ),
-                axes=([0], [2])
-            )
-            # Qijk_permsum = np.einsum("ya,abcd,be,cf", self.rt_inv_mod_mat, Q_ep, self.rt_inv_mod_mat, self.rt_inv_mod_mat)
-            Qijk_permsum = np.reshape(
-                Qijk_permsum, (n_tgts * n_tgts * n_tgts, n_features * n_tgts), "F"
-            )
-            cost = cost + 0.5 * cp.sum_squares(Qijk_permsum @ xi) / self.beta
+            p_Q = np.reshape(self.PQ_, (-1, var_len), "F")
+            p_PQ = np.tensordot(self.enstrophy.P, self.PQ_, axes=([1], [0]))
+            p_PQ_ep = _permutation_asymmetry(p_PQ) * 3
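+            # p_H0tilde is H0tilde in the paper -- the thing we actually
+            # need to minimize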
+            p_H0tilde = _convert_quad_terms_to_ens_basis(p_PQ_ep, self.enstrophy)
+            p_H0tilde = np.reshape(p_H0tilde, (-1, var_len), "F")
+            cost = cost + 0.5 * cp.sum_squares(p_Q @ xi) / self.alpha
+            cost = cost + 0.5 * cp.sum_squares(p_H0tilde @ xi) / self.beta
 
         return self._update_coef_cvxpy(xi, cost, var_len, coef_prev, self.eps_solver)
 
@@ -596,29 +573,10 @@ def _update_coef_nonsparse_rs(
         hess += pTp / self.eta
         if self.method == "local":
             PQTPQ = np.tensordot(self.PQ_, self.PQ_, axes=([0, 1, 2], [0, 1, 2]))
-            # This is H0 in the paper
-            PQ = np.einsum("ya,abcde->ybcde", self.mod_matrix, self.PQ_)
-            # This is H0tilde in the paper -- the thing we actually need to minimize
-            PQ_ep = (
-                PQ
-                + np.transpose(PQ, [1, 2, 0, 3, 4])
-                + np.transpose(PQ, [2, 0, 1, 3, 4])
-            )
-            PQ_ep = np.tensordot(
-                self.rt_inv_mod_mat,
-                np.tensordot(
-                    self.rt_inv_mod_mat,
-                    np.tensordot(
-                        self.rt_inv_mod_mat,
-                        PQ_ep,
-                        axes=([1], [0])
-                    ),
-                    axes=([0], [1])
-                ),
-                axes=([0], [2])
-            )
-            # np.einsum("ya,abcde,bg,ch", self.rt_inv_mod_mat, PQ_ep, self.rt_inv_mod_mat, self.rt_inv_mod_mat)
-            PQTPQ_ep = np.tensordot(PQ_ep, PQ_ep, axes=([0, 1, 2], [0, 1, 2]))
+            p_PQ = np.einsum("ya,abcde->ybcde", self.enstrophy.P, self.PQ_)
+            p_H0 = _permutation_asymmetry(p_PQ) * 3
+            p_H0tilde = _convert_quad_terms_to_ens_basis(p_H0, self.enstrophy)
+            PQTPQ_ep = np.tensordot(p_H0tilde, p_H0tilde, axes=([0, 1, 2], [0, 1, 2]))
             hess += PQTPQ / self.alpha + PQTPQ_ep / self.beta
 
             PaTA = np.einsum("bacd,ab->cd", P_A, quad_energy_coeff_A)
@@ -635,7 +593,7 @@
     def _solve_m_relax_and_split(
         self,
         trap_ctr: Float1D,
-        A: Float2D,
+        prev_A: Float2D,
         coef_sparse: np.ndarray[tuple[NFeat, NTarget], AnyFloat],
     ) -> tuple[Float1D, Float2D]:
         """Solves the (m, A) algorithm update.
@@ -646,25 +604,24 @@ def _solve_m_relax_and_split(
 
         Returns the new trap center (m) and the new A
         """
         # prox-gradient descent for (A, m)
-        # Calculate projection matrix from Quad terms to As
-        mPM = np.tensordot(self.PM_, trap_ctr, axes=([2], [0]))
-        p = self.PL_unsym_ + mPM
-        p = np.einsum("ya,abcd,bz->yzcd", self.rt_mod_mat, p, self.rt_inv_mod_mat)
-        p = (p + np.transpose(p, [1, 0, 2, 3])) / 2
-        PW = np.tensordot(p, coef_sparse, axes=([3, 2], [0, 1]))
-
-        # Calculate As and its quad term components
+        # Calculate As
+        p_AS = _create_A_symm(self.PL_unsym_, self.PM_, trap_ctr, self.enstrophy)
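+        # PW is A^S evaluated at the current coefficients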
+        PW = np.tensordot(p_AS, coef_sparse, axes=([3, 2], [0, 1]))
+
+        # Calculate quadratic terms of As as a function of m
         PMW = np.tensordot(self.PM_, coef_sparse, axes=([4, 3], [0, 1]))
-        PMW = np.einsum("ya,abc,bz->yzc", self.rt_mod_mat, PMW, self.rt_inv_mod_mat)
+        PMW = np.einsum(
+            "ya,abc,bz->yzc", self.enstrophy.P_root, PMW, self.enstrophy.P_root_inv
+        )
         PMW = (PMW + np.transpose(PMW, [1, 0, 2])) / 2
 
         # Calculate error in quadratic balance, and adjust trap center
-        A_b = (A - PW) / self.eta
+        A_b = (prev_A - PW) / self.eta
         # PQWT_PW is gradient of some loss in m
         PMT_PW = np.tensordot(PMW, A_b, axes=([2, 1], [0, 1]))
         trap_new = trap_ctr - self.alpha_m * PMT_PW
 
         # Update A
-        A_new = self._update_A(A - self.alpha_A * A_b, PW)
+        A_new = self._update_A(prev_A - self.alpha_A * A_b, PW)
         return trap_new, A_new
 
     def _solve_nonsparse_relax_and_split(self, hess, gradient_constant):
@@ -748,22 +705,18 @@ def _reduce(self, x, y):
         # Begin optimization loop
         objective_history = []
         for k in range(self.max_iter):
-            # update P tensor from the newest trap center
-            mPM = np.tensordot(self.PM_, trap_ctr, axes=([2], [0]))
-            P_A = np.einsum(
-                "ya,abcd,bz->yzcd", self.rt_mod_mat, self.PL_unsym_ + mPM, self.rt_inv_mod_mat
-            )
-            P_A = (P_A + np.transpose(P_A, [1, 0, 2, 3])) / 2
-            Pmatrix = P_A.reshape(n_tgts * n_tgts, n_tgts * n_features)
-            self.p_history_.append(P_A)
+            # update p_AS tensor from the newest trap center
+            p_AS = _create_A_symm(self.PL_unsym_, self.PM_, trap_ctr, self.enstrophy)
+            Pmatrix = p_AS.reshape(n_tgts * n_tgts, n_tgts * n_features)
+            self.p_history_.append(p_AS)
 
             coef_prev = coef_sparse
             if (self.threshold > 0.0) or self.inequality_constraints:
                 coef_sparse = self._update_coef_sparse_rs(
-                    n_tgts, n_features, var_len, x_expanded, y, Pmatrix, A, coef_prev
+                    var_len, x_expanded, y, Pmatrix, A, coef_prev
                 )
             else:
-                coef_sparse = self._update_coef_nonsparse_rs(x, y, P_A, A)
+                coef_sparse = self._update_coef_nonsparse_rs(x, y, p_AS, A)
 
             # If problem over xi becomes infeasible, break out of the loop
             if coef_sparse is None:
@@ -777,7 +730,7 @@
                 trap_ctr = trap_prev_ctr
                 break
             self.history_.append(coef_sparse.T)
-            PW = np.tensordot(P_A, coef_sparse, axes=([3, 2], [0, 1]))
+            PW = np.tensordot(p_AS, coef_sparse, axes=([3, 2], [0, 1]))
 
             # (m,A) update finished, append the result
             self.m_history_.append(trap_ctr)
@@ -949,3 +902,71 @@ def _equality_constrained_linlsq(
     inv1 = np.linalg.pinv(hess, rcond=1e-10)
     inv2 = np.linalg.pinv(C @ inv1 @ C.T, rcond=1e-10)
     return inv1 @ (grad_const + C.T @ inv2 @ (d - C @ inv1 @ grad_const))
+
+
+TwoOrFourD = TypeVar("TwoOrFourD", Float2D, Float4D)
+
+
+def _create_A_symm(
+    L_obj: TwoOrFourD, M_obj: Float3D | Float5D, trap_ctr: Float1D, ens: EnstrophyMat
+) -> TwoOrFourD:
+    r"""Create the enstrophy/energy growth quadratic form
+
+    In the paper, this is :math:`A^S`.
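+    Concretely, with :math:`A = L + Mm` (the linear terms plus the quadratic
+    form contracted with the trap center along its third axis), it is
+    :math:`A^S = (P^{1/2} A P^{-1/2} + P^{-1/2} A^T P^{1/2}) / 2`.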
+    This function can be used to create either the matrix itself or a
+    projector from SINDy coefficient layout to the matrix. Note that L and Q
+    themselves are the unsymmetrized variants.
+
+    Args:
+        L_obj: The linear terms in the original differential equation. This
+            can either be the coefficients themselves, or a projector onto
+            the coefficients.
+        M_obj: The quadratic form of the original differential equation,
+            plus its transpose in the 2nd and 3rd axes. See eqn 3.8 of
+            Schlegel and Noack 2015. This can be the quadratic form, or a
+            projector onto the quadratic form. If a projector, it must match
+            L_obj.
+        trap_ctr: The posited center of the trapping region.
+        ens: The enstrophy matrix of the system.
+    """
+    mPM = np.einsum("ijk...,k->ij...", M_obj, trap_ctr)
+    A = np.einsum("ya,ab...,bz->yz...", ens.P_root, L_obj + mPM, ens.P_root_inv)
+    A_S = (A + np.einsum("ij...->ji...", A)) / 2
+    return A_S
+
+
+Q_Arr = TypeVar("Q_Arr", Float3D, Float5D)
+
+
+def _permutation_asymmetry(Q_obj: Q_Arr) -> Q_Arr:
+    r"""Calculate the permutation-asymmetric part of the first 3 axes of Q
+
+    This is the average of Q over cyclic permutations of its first three
+    axes, :math:`(Q_{ijk} + Q_{jki} + Q_{kij}) / 3`. In the paper, it
+    defines the directions of cubic energy growth and is used to create
+    :math:`\tilde{Q}'`, its 2D flattening, :math:`H_0`, and its
+    enstrophy-basis (z-space) version, :math:`\tilde{H}_0`.
+
+    This works on both the true quadratic terms and the projector onto the
+    quadratic terms.
+
+    Note: The paper uses three times this quantity.
+    """
+    p1 = partial(np.einsum, "ijk...->jki...")
+    p2 = partial(np.einsum, "ijk...->kij...")
+    return (Q_obj + p1(Q_obj) + p2(Q_obj)) / 3
+
+
+def _convert_quad_terms_to_ens_basis(PQ: Q_Arr, ens: EnstrophyMat) -> Q_Arr:
+    r"""Convert quadratic enstrophy terms to the enstrophy basis.
+
+    In the paper, this captures the change from :math:`\tilde{Q}=PQ`, the
+    quadratic enstrophy terms acting on :math:`y`, to the quadratic terms
+    acting on :math:`z=P^{1/2}y`. It is also used to convert the cubic
+    enstrophy growth terms to cubic growth terms in the enstrophy basis,
+    i.e. :math:`\tilde{H}_0` from :math:`H_0`.
+
+    This works on both the true quadratic terms and the projector onto the
+    quadratic terms.
+    """
+    return np.einsum(
+        "xa,abc...,by,cz->xyz...", ens.P_root_inv, PQ, ens.P_root_inv, ens.P_root_inv
+    )