diff --git a/docs/source/simsopt.geo.rst b/docs/source/simsopt.geo.rst index 83ad9655c..9bb006c3c 100644 --- a/docs/source/simsopt.geo.rst +++ b/docs/source/simsopt.geo.rst @@ -76,6 +76,14 @@ simsopt.geo.finitebuild module :undoc-members: :show-inheritance: +simsopt.geo.framedcurve module +------------------------------ + +.. automodule:: simsopt.geo.framedcurve + :members: + :undoc-members: + :show-inheritance: + simsopt.geo.jit module ---------------------- @@ -108,6 +116,14 @@ simsopt.geo.qfmsurface module :undoc-members: :show-inheritance: +simsopt.geo.strain_optimization module +-------------------------------------- + +.. automodule:: simsopt.geo.strain_optimization + :members: + :undoc-members: + :show-inheritance: + simsopt.geo.surface module -------------------------- diff --git a/examples/2_Intermediate/stage_two_optimization.py b/examples/2_Intermediate/stage_two_optimization.py index 66a500613..3dd147fcf 100755 --- a/examples/2_Intermediate/stage_two_optimization.py +++ b/examples/2_Intermediate/stage_two_optimization.py @@ -25,7 +25,6 @@ from pathlib import Path import numpy as np from scipy.optimize import minimize - from simsopt.field import BiotSavart, Current, coils_via_symmetries from simsopt.geo import (SurfaceRZFourier, curves_to_vtk, create_equally_spaced_curves, CurveLength, CurveCurveDistance, MeanSquaredCurvature, diff --git a/examples/2_Intermediate/strain_optimization.py b/examples/2_Intermediate/strain_optimization.py new file mode 100755 index 000000000..cb04ba2b7 --- /dev/null +++ b/examples/2_Intermediate/strain_optimization.py @@ -0,0 +1,60 @@ +#!/usr/bin/env python + +""" +This script performs an optimization of the HTS tape winding angle +with respect to binormal curvature and torsional strain cost functions as defined in + + Paz Soldan, "Non-planar coil winding angle optimization for compatibility with + non-insulated high-temperature superconducting magnets", Journal of Plasma Physics + 86 (2020), doi:10.1017/S0022377820001208. + +The orientation of the tape is defined with respect to the Frenet-Serret Frame +""" + +import numpy as np +from scipy.optimize import minimize +from simsopt.geo import CoilStrain, LPTorsionalStrainPenalty, LPBinormalCurvatureStrainPenalty +from simsopt.geo import FrameRotation, FramedCurveFrenet, CurveXYZFourier +from simsopt.configs import get_hsx_data +from simsopt.util import in_github_actions + +MAXITER = 50 if in_github_actions else 400 + +curves, currents, ma = get_hsx_data(Nt_coils=10, ppp=10) +curve = curves[1] +scale_factor = 0.1 +curve_scaled = CurveXYZFourier(curve.quadpoints, curve.order) +curve_scaled.x = curve.x * scale_factor # scale coil to magnify the strains +rot_order = 10 # order of the Fourier expression for the rotation of the filament pack +width = 1e-3 # tape width + +curve_scaled.fix_all() # fix curve DOFs -> only optimize winding angle +rotation = FrameRotation(curve_scaled.quadpoints, rot_order) + +framedcurve = FramedCurveFrenet(curve_scaled, rotation) + +tor_threshold = 0.02 # Threshold for strain parameters +cur_threshold = 0.02 + +Jtor = LPTorsionalStrainPenalty(framedcurve, p=2, threshold=tor_threshold) +Jbin = LPBinormalCurvatureStrainPenalty( + framedcurve, p=2, threshold=cur_threshold) + +strain = CoilStrain(framedcurve, width) +JF = Jtor + Jbin + + +def fun(dofs): + JF.x = dofs + J = JF.J() + grad = JF.dJ() + outstr = f"Max torsional strain={np.max(strain.torsional_strain()):.1e}, Max curvature strain={np.max(strain.binormal_curvature_strain()):.1e}" + print(outstr) + return J, grad + + +f = fun +dofs = JF.x + +res = minimize(fun, dofs, jac=True, method='L-BFGS-B', + options={'maxiter': MAXITER, 'maxcor': 10, 'gtol': 1e-20, 'ftol': 1e-20}, tol=1e-20) diff --git a/examples/run_serial_examples b/examples/run_serial_examples index 9930f9b94..163ff6ad4 100755 --- a/examples/run_serial_examples +++ b/examples/run_serial_examples @@ -18,6 +18,7 @@ set -ex ./2_Intermediate/stage_two_optimization.py ./2_Intermediate/stage_two_optimization_stochastic.py ./2_Intermediate/stage_two_optimization_finite_beta.py +./2_Intermediate/strain_optimization.py ./2_Intermediate/permanent_magnet_MUSE.py ./2_Intermediate/permanent_magnet_QA.py ./2_Intermediate/permanent_magnet_PM4Stell.py diff --git a/src/simsopt/field/tracing.py b/src/simsopt/field/tracing.py index 21b34a019..7cf0c4020 100644 --- a/src/simsopt/field/tracing.py +++ b/src/simsopt/field/tracing.py @@ -798,7 +798,8 @@ class IterationStoppingCriterion(sopp.IterationStoppingCriterion): pass -def plot_poincare_data(fieldlines_phi_hits, phis, filename, mark_lost=False, aspect='equal', dpi=300, xlims=None, ylims=None, surf=None): +def plot_poincare_data(fieldlines_phi_hits, phis, filename, mark_lost=False, aspect='equal', dpi=300, xlims=None, + ylims=None, surf=None, s=2, marker='o'): """ Create a poincare plot. Usage: diff --git a/src/simsopt/geo/__init__.py b/src/simsopt/geo/__init__.py index de9ce4046..ad270e8e5 100644 --- a/src/simsopt/geo/__init__.py +++ b/src/simsopt/geo/__init__.py @@ -8,7 +8,7 @@ from .curvexyzfourier import * from .curveperturbed import * from .curveobjectives import * - +from .framedcurve import * from .finitebuild import * from .plotting import * @@ -21,6 +21,7 @@ from .surfacerzfourier import * from .surfacexyzfourier import * from .surfacexyztensorfourier import * +from .strain_optimization import * from .permanent_magnet_grid import * from .current_voxels_grid import * @@ -35,4 +36,5 @@ surfacerzfourier.__all__ + surfacexyzfourier.__all__ + surfacexyztensorfourier.__all__ + surfaceobjectives.__all__ + permanent_magnet_grid.__all__ + + strain_optimization.__all__ + framedcurve.__all__ + current_voxels_grid.__all__) diff --git a/src/simsopt/geo/finitebuild.py b/src/simsopt/geo/finitebuild.py index 739cf3df1..87a4cc60b 100644 --- a/src/simsopt/geo/finitebuild.py +++ b/src/simsopt/geo/finitebuild.py @@ -1,85 +1,28 @@ import numpy as np -import jax.numpy as jnp -from jax import vjp, jvp -import simsoptpp as sopp -from .._core.optimizable import Optimizable -from .._core.derivative import Derivative -from .curve import Curve -from .jit import jit +from .framedcurve import FramedCurve, FrameRotation, ZeroRotation, FramedCurveCentroid, FramedCurveFrenet """ The functions and classes in this model are used to deal with multifilament approximation of finite build coils. """ -__all__ = ['create_multifilament_grid', - 'CurveFilament', 'FilamentRotation', 'ZeroRotation'] +__all__ = ['create_multifilament_grid', 'CurveFilament'] -def create_multifilament_grid(curve, numfilaments_n, numfilaments_b, gapsize_n, gapsize_b, rotation_order=None, rotation_scaling=None): - """ - Create a regular grid of ``numfilaments_n * numfilaments_b`` many - filaments to approximate a finite-build coil. - - Note that "normal" and "binormal" in the function arguments here - refer not to the Frenet frame but rather to the "coil centroid - frame" defined by Singh et al., before rotation. - - Args: - curve: The underlying curve. - numfilaments_n: number of filaments in normal direction. - numfilaments_b: number of filaments in bi-normal direction. - gapsize_n: gap between filaments in normal direction. - gapsize_b: gap between filaments in bi-normal direction. - rotation_order: Fourier order (maximum mode number) to use in the expression for the rotation - of the filament pack. ``None`` means that the rotation is not optimized. - rotation_scaling: scaling for the rotation degrees of freedom. good - scaling improves the convergence of first order optimization - algorithms. If ``None``, then the default of ``1 / max(gapsize_n, gapsize_b)`` - is used. - """ - if numfilaments_n % 2 == 1: - shifts_n = np.arange(numfilaments_n) - numfilaments_n//2 - else: - shifts_n = np.arange(numfilaments_n) - numfilaments_n/2 + 0.5 - shifts_n = shifts_n * gapsize_n - if numfilaments_b % 2 == 1: - shifts_b = np.arange(numfilaments_b) - numfilaments_b//2 - else: - shifts_b = np.arange(numfilaments_b) - numfilaments_b/2 + 0.5 - shifts_b = shifts_b * gapsize_b - - if rotation_scaling is None: - rotation_scaling = 1/max(gapsize_n, gapsize_b) - if rotation_order is None: - rotation = ZeroRotation(curve.quadpoints) - else: - rotation = FilamentRotation(curve.quadpoints, rotation_order, scale=rotation_scaling) - filaments = [] - for i in range(numfilaments_n): - for j in range(numfilaments_b): - filaments.append(CurveFilament(curve, shifts_n[i], shifts_b[j], rotation)) - return filaments +class CurveFilament(FramedCurve): + def __init__(self, framedcurve, dn, db): + """ + Given a FramedCurve, defining a normal and + binormal vector, create a grid of curves by shifting + along the normal and binormal vector. -class CurveFilament(sopp.Curve, Curve): + The idea is explained well in Figure 1 in the reference: - def __init__(self, curve, dn, db, rotation=None): - """ - Implementation of the centroid frame introduced in Singh et al, "Optimization of finite-build stellarator coils", Journal of Plasma Physics 86 (2020), - doi:10.1017/S0022377820000756. Given a curve, one defines a normal and - binormal vector and then creates a grid of curves by shifting along the - normal and binormal vector. In addition, we specify an angle along the - curve that allows us to optimise for the rotation of the winding pack. - - The idea is explained well in Figure 1 in the reference above. - - Note that "normal" and "binormal" in the function arguments here - refer not to the Frenet frame but rather to the "coil centroid - frame" defined by Singh et al., before rotation. + doi:10.1017/S0022377820000756. Args: curve: the underlying curve @@ -87,18 +30,12 @@ def __init__(self, curve, dn, db, rotation=None): db: how far to move in binormal direction rotation: angle along the curve to rotate the frame. """ - self.curve = curve - sopp.Curve.__init__(self, curve.quadpoints) - deps = [curve] - if rotation is not None: - deps.append(rotation) - Curve.__init__(self, depends_on=deps) - self.curve = curve + self.curve = framedcurve.curve self.dn = dn self.db = db - if rotation is None: - rotation = ZeroRotation(curve.quadpoints) - self.rotation = rotation + self.rotation = framedcurve.rotation + self.framedcurve = framedcurve + FramedCurve.__init__(self, self.curve, self.rotation) def recompute_bell(self, parent=None): self.invalidate_cache() @@ -106,195 +43,93 @@ def recompute_bell(self, parent=None): def gamma_impl(self, gamma, quadpoints): assert quadpoints.shape[0] == self.curve.quadpoints.shape[0] assert np.linalg.norm(quadpoints - self.curve.quadpoints) < 1e-15 - c = self.curve - t, n, b = rotated_centroid_frame(c.gamma(), c.gammadash(), self.rotation.alpha(c.quadpoints)) + t, n, b = self.framedcurve.rotated_frame() gamma[:] = self.curve.gamma() + self.dn * n + self.db * b def gammadash_impl(self, gammadash): - c = self.curve - td, nd, bd = rotated_centroid_frame_dash( - c.gamma(), c.gammadash(), c.gammadashdash(), - self.rotation.alpha(c.quadpoints), self.rotation.alphadash(c.quadpoints) - ) + td, nd, bd = self.framedcurve.rotated_frame_dash() gammadash[:] = self.curve.gammadash() + self.dn * nd + self.db * bd def dgamma_by_dcoeff_vjp(self, v): - g = self.curve.gamma() - gd = self.curve.gammadash() - a = self.rotation.alpha(self.curve.quadpoints) - zero = np.zeros_like(v) - vg = rotated_centroid_frame_dcoeff_vjp0(g, gd, a, (zero, self.dn*v, self.db*v)) - vgd = rotated_centroid_frame_dcoeff_vjp1(g, gd, a, (zero, self.dn*v, self.db*v)) - va = rotated_centroid_frame_dcoeff_vjp2(g, gd, a, (zero, self.dn*v, self.db*v)) - return self.curve.dgamma_by_dcoeff_vjp(v + vg) \ + vg = self.framedcurve.rotated_frame_dcoeff_vjp(v, self.dn, self.db, 0) + vgd = self.framedcurve.rotated_frame_dcoeff_vjp(v, self.dn, self.db, 1) + vgdd = self.framedcurve.rotated_frame_dcoeff_vjp(v, self.dn, self.db, 2) + va = self.framedcurve.rotated_frame_dcoeff_vjp(v, self.dn, self.db, 3) + out = self.curve.dgamma_by_dcoeff_vjp(v + vg) \ + self.curve.dgammadash_by_dcoeff_vjp(vgd) \ - + self.rotation.dalpha_by_dcoeff_vjp(self.curve.quadpoints, va) + + self.rotation.dalpha_by_dcoeff_vjp(self.curve.quadpoints, va) + if vgdd is not None: + out += self.curve.dgammadashdash_by_dcoeff_vjp(vgdd) + return out def dgammadash_by_dcoeff_vjp(self, v): - g = self.curve.gamma() - gd = self.curve.gammadash() - gdd = self.curve.gammadashdash() - a = self.rotation.alpha(self.curve.quadpoints) - ad = self.rotation.alphadash(self.curve.quadpoints) - zero = np.zeros_like(v) - vg = rotated_centroid_frame_dash_dcoeff_vjp0(g, gd, gdd, a, ad, (zero, self.dn*v, self.db*v)) - vgd = rotated_centroid_frame_dash_dcoeff_vjp1(g, gd, gdd, a, ad, (zero, self.dn*v, self.db*v)) - vgdd = rotated_centroid_frame_dash_dcoeff_vjp2(g, gd, gdd, a, ad, (zero, self.dn*v, self.db*v)) - va = rotated_centroid_frame_dash_dcoeff_vjp3(g, gd, gdd, a, ad, (zero, self.dn*v, self.db*v)) - vad = rotated_centroid_frame_dash_dcoeff_vjp4(g, gd, gdd, a, ad, (zero, self.dn*v, self.db*v)) - return self.curve.dgamma_by_dcoeff_vjp(vg) \ + vg = self.framedcurve.rotated_frame_dash_dcoeff_vjp(v, self.dn, self.db, 0) + vgd = self.framedcurve.rotated_frame_dash_dcoeff_vjp(v, self.dn, self.db, 1) + vgdd = self.framedcurve.rotated_frame_dash_dcoeff_vjp(v, self.dn, self.db, 2) + vgddd = self.framedcurve.rotated_frame_dash_dcoeff_vjp(v, self.dn, self.db, 3) + va = self.framedcurve.rotated_frame_dash_dcoeff_vjp(v, self.dn, self.db, 4) + vad = self.framedcurve.rotated_frame_dash_dcoeff_vjp(v, self.dn, self.db, 5) + out = self.curve.dgamma_by_dcoeff_vjp(vg) \ + self.curve.dgammadash_by_dcoeff_vjp(v+vgd) \ + self.curve.dgammadashdash_by_dcoeff_vjp(vgdd) \ + self.rotation.dalpha_by_dcoeff_vjp(self.curve.quadpoints, va) \ + self.rotation.dalphadash_by_dcoeff_vjp(self.curve.quadpoints, vad) + if vgddd is not None: + out += self.curve.dgammadashdashdash_by_dcoeff_vjp(vgddd) + return out -class FilamentRotation(Optimizable): - - def __init__(self, quadpoints, order, scale=1., dofs=None): - """ - The rotation of the multifilament pack; alpha in Figure 1 of - Singh et al, "Optimization of finite-build stellarator coils", - Journal of Plasma Physics 86 (2020), - doi:10.1017/S0022377820000756 - """ - self.order = order - if dofs is None: - super().__init__(x0=np.zeros((2*order+1, ))) - else: - super().__init__(dofs=dofs) - self.quadpoints = quadpoints - self.scale = scale - self.jac = rotation_dcoeff(quadpoints, order) - self.jacdash = rotationdash_dcoeff(quadpoints, order) - self.jax_alpha = jit(lambda dofs, points: jaxrotation_pure(dofs, points, self.order)) - self.jax_alphadash = jit(lambda dofs, points: jaxrotationdash_pure(dofs, points, self.order)) - - def alpha(self, quadpoints): - return self.scale * self.jax_alpha(self._dofs.full_x, quadpoints) - - def alphadash(self, quadpoints): - return self.scale * self.jax_alphadash(self._dofs.full_x, quadpoints) - - def dalpha_by_dcoeff_vjp(self, quadpoints, v): - return Derivative({self: self.scale * sopp.vjp(v, self.jac)}) - - def dalphadash_by_dcoeff_vjp(self, quadpoints, v): - return Derivative({self: self.scale * sopp.vjp(v, self.jacdash)}) - - -class ZeroRotation(Optimizable): - - def __init__(self, quadpoints): - """ - Dummy class that just returns zero for the rotation angle. Equivalent to using - - .. code-block:: python - - rot = FilamentRotation(...) - rot.fix_all() - - """ - super().__init__() - self.zero = np.zeros((quadpoints.size, )) - - def alpha(self, quadpoints): - return self.zero - - def alphadash(self, quadpoints): - return self.zero - - def dalpha_by_dcoeff_vjp(self, quadpoints, v): - return Derivative({}) - - def dalphadash_by_dcoeff_vjp(self, quadpoints, v): - return Derivative({}) - - -@jit -def rotated_centroid_frame(gamma, gammadash, alpha): - t = gammadash - t *= 1./jnp.linalg.norm(gammadash, axis=1)[:, None] - R = jnp.mean(gamma, axis=0) # centroid - delta = gamma - R[None, :] - n = delta - jnp.sum(delta * t, axis=1)[:, None] * t - n *= 1./jnp.linalg.norm(n, axis=1)[:, None] - b = jnp.cross(t, n, axis=1) - - # now rotate the frame by alpha - nn = jnp.cos(alpha)[:, None] * n - jnp.sin(alpha)[:, None] * b - bb = jnp.sin(alpha)[:, None] * n + jnp.cos(alpha)[:, None] * b - return t, nn, bb - - -rotated_centroid_frame_dash = jit( - lambda gamma, gammadash, gammadashdash, alpha, alphadash: jvp(rotated_centroid_frame, - (gamma, gammadash, alpha), - (gammadash, gammadashdash, alphadash))[1]) - -rotated_centroid_frame_dcoeff_vjp0 = jit( - lambda gamma, gammadash, alpha, v: vjp( - lambda g: rotated_centroid_frame(g, gammadash, alpha), gamma)[1](v)[0]) - -rotated_centroid_frame_dcoeff_vjp1 = jit( - lambda gamma, gammadash, alpha, v: vjp( - lambda gd: rotated_centroid_frame(gamma, gd, alpha), gammadash)[1](v)[0]) - -rotated_centroid_frame_dcoeff_vjp2 = jit( - lambda gamma, gammadash, alpha, v: vjp( - lambda a: rotated_centroid_frame(gamma, gammadash, a), alpha)[1](v)[0]) - -rotated_centroid_frame_dash_dcoeff_vjp0 = jit( - lambda gamma, gammadash, gammadashdash, alpha, alphadash, v: vjp( - lambda g: rotated_centroid_frame_dash(g, gammadash, gammadashdash, alpha, alphadash), gamma)[1](v)[0]) - -rotated_centroid_frame_dash_dcoeff_vjp1 = jit( - lambda gamma, gammadash, gammadashdash, alpha, alphadash, v: vjp( - lambda gd: rotated_centroid_frame_dash(gamma, gd, gammadashdash, alpha, alphadash), gammadash)[1](v)[0]) - -rotated_centroid_frame_dash_dcoeff_vjp2 = jit( - lambda gamma, gammadash, gammadashdash, alpha, alphadash, v: vjp( - lambda gdd: rotated_centroid_frame_dash(gamma, gammadash, gdd, alpha, alphadash), gammadashdash)[1](v)[0]) - -rotated_centroid_frame_dash_dcoeff_vjp3 = jit( - lambda gamma, gammadash, gammadashdash, alpha, alphadash, v: vjp( - lambda a: rotated_centroid_frame_dash(gamma, gammadash, gammadashdash, a, alphadash), alpha)[1](v)[0]) - -rotated_centroid_frame_dash_dcoeff_vjp4 = jit( - lambda gamma, gammadash, gammadashdash, alpha, alphadash, v: vjp( - lambda ad: rotated_centroid_frame_dash(gamma, gammadash, gammadashdash, alpha, ad), alphadash)[1](v)[0]) - - -def jaxrotation_pure(dofs, points, order): - rotation = jnp.zeros((len(points), )) - rotation += dofs[0] - for j in range(1, order+1): - rotation += dofs[2*j-1] * jnp.sin(2*np.pi*j*points) - rotation += dofs[2*j] * jnp.cos(2*np.pi*j*points) - return rotation - +def create_multifilament_grid(curve, numfilaments_n, numfilaments_b, gapsize_n, gapsize_b, + rotation_order=None, rotation_scaling=None, frame='centroid'): + """ + Create a regular grid of ``numfilaments_n * numfilaments_b`` many + filaments to approximate a finite-build coil. -def jaxrotationdash_pure(dofs, points, order): - rotation = jnp.zeros((len(points), )) - for j in range(1, order+1): - rotation += dofs[2*j-1] * 2*np.pi*j*jnp.cos(2*np.pi*j*points) - rotation -= dofs[2*j] * 2*np.pi*j*jnp.sin(2*np.pi*j*points) - return rotation + Note that "normal" and "binormal" in the function arguments here + refer to either the Frenet frame or the "coil centroid + frame" defined by Singh et al., before rotation. + Args: + curve: The underlying curve. + numfilaments_n: number of filaments in normal direction. + numfilaments_b: number of filaments in bi-normal direction. + gapsize_n: gap between filaments in normal direction. + gapsize_b: gap between filaments in bi-normal direction. + rotation_order: Fourier order (maximum mode number) to use in the expression for the rotation + of the filament pack. ``None`` means that the rotation is not optimized. + rotation_scaling: scaling for the rotation degrees of freedom. good + scaling improves the convergence of first order optimization + algorithms. If ``None``, then the default of ``1 / max(gapsize_n, gapsize_b)`` + is used. + frame: orthonormal frame to define normal and binormal before rotation (either 'centroid' or 'frenet') + """ + assert frame in ['centroid', 'frenet'] + if numfilaments_n % 2 == 1: + shifts_n = np.arange(numfilaments_n) - numfilaments_n//2 + else: + shifts_n = np.arange(numfilaments_n) - numfilaments_n/2 + 0.5 + shifts_n = shifts_n * gapsize_n + if numfilaments_b % 2 == 1: + shifts_b = np.arange(numfilaments_b) - numfilaments_b//2 + else: + shifts_b = np.arange(numfilaments_b) - numfilaments_b/2 + 0.5 + shifts_b = shifts_b * gapsize_b -def rotation_dcoeff(points, order): - jac = np.zeros((len(points), 2*order+1)) - jac[:, 0] = 1 - for j in range(1, order+1): - jac[:, 2*j-1] = np.sin(2*np.pi*j*points) - jac[:, 2*j+0] = np.cos(2*np.pi*j*points) - return jac + if rotation_scaling is None: + rotation_scaling = 1/max(gapsize_n, gapsize_b) + if rotation_order is None: + rotation = ZeroRotation(curve.quadpoints) + else: + rotation = FrameRotation(curve.quadpoints, rotation_order, scale=rotation_scaling) + if frame == 'frenet': + framedcurve = FramedCurveFrenet(curve, rotation) + else: + framedcurve = FramedCurveCentroid(curve, rotation) + filaments = [] + for i in range(numfilaments_n): + for j in range(numfilaments_b): + filaments.append(CurveFilament(framedcurve, shifts_n[i], shifts_b[j])) + return filaments -def rotationdash_dcoeff(points, order): - jac = np.zeros((len(points), 2*order+1)) - for j in range(1, order+1): - jac[:, 2*j-1] = +2*np.pi*j*np.cos(2*np.pi*j*points) - jac[:, 2*j+0] = -2*np.pi*j*np.sin(2*np.pi*j*points) - return jac diff --git a/src/simsopt/geo/framedcurve.py b/src/simsopt/geo/framedcurve.py new file mode 100644 index 000000000..dc883418f --- /dev/null +++ b/src/simsopt/geo/framedcurve.py @@ -0,0 +1,650 @@ +import numpy as np +import jax.numpy as jnp +from jax import vjp, jvp +import simsoptpp as sopp +from .._core.optimizable import Optimizable +from .._core.derivative import Derivative +from .curve import Curve +from .jit import jit + +__all__ = ['FramedCurve', 'FramedCurveFrenet', 'FramedCurveCentroid', + 'FrameRotation', 'ZeroRotation', 'FramedCurve'] + + +class FramedCurve(sopp.Curve, Curve): + + def __init__(self, curve, rotation=None): + """ + A FramedCurve defines an orthonormal basis around a Curve, + where one basis is taken to be the tangent along the Curve. + The frame is defined with respect to a reference frame, + either centroid or frenet. A rotation angle defines the rotation + with respect to this reference frame. + """ + self.curve = curve + sopp.Curve.__init__(self, curve.quadpoints) + deps = [curve] + if rotation is not None: + deps.append(rotation) + if rotation is None: + rotation = ZeroRotation(curve.quadpoints) + self.rotation = rotation + Curve.__init__(self, depends_on=deps) + + +class FramedCurveFrenet(FramedCurve): + r""" + Given a curve, one defines a reference frame using the Frenet normal and + binormal vectors: + + tangent = dr/dl + + normal = (dtangent/dl)/||dtangent/dl|| + + binormal = tangent x normal + + In addition, we specify an angle along the curve that + defines the rotation with respect to this reference frame. + """ + + def __init__(self, curve, rotation=None): + FramedCurve.__init__(self, curve, rotation) + + self.binorm = jit(lambda gamma, gammadash, gammadashdash, gammadashdashdash, alpha, alphadash: binormal_curvature_pure_frenet( + gamma, gammadash, gammadashdash, gammadashdashdash, alpha, alphadash)) + self.binormgrad_vjp0 = jit(lambda gamma, gammadash, gammadashdash, gammadashdashdash, alpha, alphadash, v: vjp( + lambda g: self.binorm(g, gammadash, gammadashdash, gammadashdashdash, alpha, alphadash), gamma)[1](v)[0]) + self.binormgrad_vjp1 = jit(lambda gamma, gammadash, gammadashdash, gammadashdashdash, alpha, alphadash, v: vjp( + lambda g: self.binorm(gamma, g, gammadashdash, gammadashdashdash, alpha, alphadash), gammadash)[1](v)[0]) + self.binormgrad_vjp2 = jit(lambda gamma, gammadash, gammadashdash, gammadashdashdash, alpha, alphadash, v: vjp( + lambda g: self.binorm(gamma, gammadash, g, gammadashdashdash, alpha, alphadash), gammadashdash)[1](v)[0]) + self.binormgrad_vjp3 = jit(lambda gamma, gammadash, gammadashdash, gammadashdashdash, alpha, alphadash, v: vjp( + lambda g: self.binorm(gamma, gammadash, gammadashdash, g, alpha, alphadash), gammadashdashdash)[1](v)[0]) + self.binormgrad_vjp4 = jit(lambda gamma, gammadash, gammadashdash, gammadashdashdash, alpha, alphadash, v: vjp( + lambda g: self.binorm(gamma, gammadash, gammadashdash, gammadashdashdash, g, alphadash), alpha)[1](v)[0]) + self.binormgrad_vjp5 = jit(lambda gamma, gammadash, gammadashdash, gammadashdashdash, alpha, alphadash, v: vjp( + lambda g: self.binorm(gamma, gammadash, gammadashdash, gammadashdashdash, alpha, g), alphadash)[1](v)[0]) + + self.torsion = jit(lambda gamma, gammadash, gammadashdash, gammadashdashdash, alpha, alphadash: torsion_pure_frenet( + gamma, gammadash, gammadashdash, gammadashdashdash, alpha, alphadash)) + self.torsiongrad_vjp0 = jit(lambda gamma, gammadash, gammadashdash, gammadashdashdash, alpha, alphadash, v: vjp( + lambda g: self.torsion(g, gammadash, gammadashdash, gammadashdashdash, alpha, alphadash), gamma)[1](v)[0]) + self.torsiongrad_vjp1 = jit(lambda gamma, gammadash, gammadashdash, gammadashdashdash, alpha, alphadash, v: vjp( + lambda g: self.torsion(gamma, g, gammadashdash, gammadashdashdash, alpha, alphadash), gammadash)[1](v)[0]) + self.torsiongrad_vjp2 = jit(lambda gamma, gammadash, gammadashdash, gammadashdashdash, alpha, alphadash, v: vjp( + lambda g: self.torsion(gamma, gammadash, g, gammadashdashdash, alpha, alphadash), gammadashdash)[1](v)[0]) + self.torsiongrad_vjp3 = jit(lambda gamma, gammadash, gammadashdash, gammadashdashdash, alpha, alphadash, v: vjp( + lambda g: self.torsion(gamma, gammadash, gammadashdash, g, alpha, alphadash), gammadashdashdash)[1](v)[0]) + self.torsiongrad_vjp4 = jit(lambda gamma, gammadash, gammadashdash, gammadashdashdash, alpha, alphadash, v: vjp( + lambda g: self.torsion(gamma, gammadash, gammadashdash, gammadashdashdash, g, alphadash), alpha)[1](v)[0]) + self.torsiongrad_vjp5 = jit(lambda gamma, gammadash, gammadashdash, gammadashdashdash, alpha, alphadash, v: vjp( + lambda g: self.torsion(gamma, gammadash, gammadashdash, gammadashdashdash, alpha, g), alphadash)[1](v)[0]) + + def rotated_frame(self): + return rotated_frenet_frame(self.curve.gamma(), self.curve.gammadash(), self.curve.gammadashdash(), self.rotation.alpha(self.curve.quadpoints)) + + def rotated_frame_dash(self): + return rotated_frenet_frame_dash( + self.curve.gamma(), self.curve.gammadash(), self.curve.gammadashdash(), self.curve.gammadashdashdash(), + self.rotation.alpha(self.curve.quadpoints), self.rotation.alphadash(self.curve.quadpoints) + ) + + def frame_torsion(self): + """Exports frame torsion along a curve""" + gamma = self.curve.gamma() + d1gamma = self.curve.gammadash() + d2gamma = self.curve.gammadashdash() + d3gamma = self.curve.gammadashdashdash() + alpha = self.rotation.alpha(self.curve.quadpoints) + alphadash = self.rotation.alphadash(self.curve.quadpoints) + return self.torsion(gamma, d1gamma, d2gamma, d3gamma, alpha, alphadash) + + def frame_binormal_curvature(self): + gamma = self.curve.gamma() + d1gamma = self.curve.gammadash() + d2gamma = self.curve.gammadashdash() + d3gamma = self.curve.gammadashdashdash() + alpha = self.rotation.alpha(self.curve.quadpoints) + alphadash = self.rotation.alphadash(self.curve.quadpoints) + return self.binorm(gamma, d1gamma, d2gamma, d3gamma, alpha, alphadash) + + def dframe_torsion_by_dcoeff_vjp(self, v): + gamma = self.curve.gamma() + d1gamma = self.curve.gammadash() + d2gamma = self.curve.gammadashdash() + d3gamma = self.curve.gammadashdashdash() + alpha = self.rotation.alpha(self.curve.quadpoints) + alphadash = self.rotation.alphadash(self.curve.quadpoints) + + grad0 = self.torsiongrad_vjp0(gamma, d1gamma, d2gamma, + d3gamma, alpha, alphadash, v) + grad1 = self.torsiongrad_vjp1(gamma, d1gamma, d2gamma, + d3gamma, alpha, alphadash, v) + grad2 = self.torsiongrad_vjp2(gamma, d1gamma, d2gamma, + d3gamma, alpha, alphadash, v) + grad3 = self.torsiongrad_vjp3(gamma, d1gamma, d2gamma, + d3gamma, alpha, alphadash, v) + grad4 = self.torsiongrad_vjp4(gamma, d1gamma, d2gamma, + d3gamma, alpha, alphadash, v) + grad5 = self.torsiongrad_vjp5(gamma, d1gamma, d2gamma, + d3gamma, alpha, alphadash, v) + + return self.curve.dgamma_by_dcoeff_vjp(grad0) \ + + self.curve.dgammadash_by_dcoeff_vjp(grad1) \ + + self.curve.dgammadashdash_by_dcoeff_vjp(grad2) \ + + self.curve.dgammadashdashdash_by_dcoeff_vjp(grad3) \ + + self.rotation.dalpha_by_dcoeff_vjp(self.curve.quadpoints, grad4) \ + + self.rotation.dalphadash_by_dcoeff_vjp(self.curve.quadpoints, grad5) + + def dframe_binormal_curvature_by_dcoeff_vjp(self, v): + gamma = self.curve.gamma() + d1gamma = self.curve.gammadash() + d2gamma = self.curve.gammadashdash() + d3gamma = self.curve.gammadashdashdash() + alpha = self.rotation.alpha(self.curve.quadpoints) + alphadash = self.rotation.alphadash(self.curve.quadpoints) + + grad0 = self.binormgrad_vjp0(gamma, d1gamma, d2gamma, + d3gamma, alpha, alphadash, v) + grad1 = self.binormgrad_vjp1(gamma, d1gamma, d2gamma, + d3gamma, alpha, alphadash, v) + grad2 = self.binormgrad_vjp2(gamma, d1gamma, d2gamma, + d3gamma, alpha, alphadash, v) + grad3 = self.binormgrad_vjp3(gamma, d1gamma, d2gamma, + d3gamma, alpha, alphadash, v) + grad4 = self.binormgrad_vjp4(gamma, d1gamma, d2gamma, + d3gamma, alpha, alphadash, v) + grad5 = self.binormgrad_vjp5(gamma, d1gamma, d2gamma, + d3gamma, alpha, alphadash, v) + + return self.curve.dgamma_by_dcoeff_vjp(grad0) \ + + self.curve.dgammadash_by_dcoeff_vjp(grad1) \ + + self.curve.dgammadashdash_by_dcoeff_vjp(grad2) \ + + self.curve.dgammadashdashdash_by_dcoeff_vjp(grad3) \ + + self.rotation.dalpha_by_dcoeff_vjp(self.curve.quadpoints, grad4) \ + + self.rotation.dalphadash_by_dcoeff_vjp(self.curve.quadpoints, grad5) + + def rotated_frame_dcoeff_vjp(self, v, dn, db, arg=0): + assert arg in [0, 1, 2, 3] + g = self.curve.gamma() + gd = self.curve.gammadash() + gdd = self.curve.gammadashdash() + a = self.rotation.alpha(self.curve.quadpoints) + zero = np.zeros_like(v) + if arg == 0: + return rotated_frenet_frame_dcoeff_vjp0( + g, gd, gdd, a, (zero, dn*v, db*v)) + elif arg == 1: + return rotated_frenet_frame_dcoeff_vjp1( + g, gd, gdd, a, (zero, dn*v, db*v)) + elif arg == 2: + return rotated_frenet_frame_dcoeff_vjp2( + g, gd, gdd, a, (zero, dn*v, db*v)) + elif arg == 3: + return rotated_frenet_frame_dcoeff_vjp3( + g, gd, gdd, a, (zero, dn*v, db*v)) + + def rotated_frame_dash_dcoeff_vjp(self, v, dn, db, arg=0): + assert arg in [0, 1, 2, 3, 4, 5] + g = self.curve.gamma() + gd = self.curve.gammadash() + gdd = self.curve.gammadashdash() + gddd = self.curve.gammadashdashdash() + a = self.rotation.alpha(self.curve.quadpoints) + ad = self.rotation.alphadash(self.curve.quadpoints) + zero = np.zeros_like(v) + if arg == 0: + return rotated_frenet_frame_dash_dcoeff_vjp0( + g, gd, gdd, gddd, a, ad, (zero, dn*v, db*v)) + if arg == 1: + return rotated_frenet_frame_dash_dcoeff_vjp1( + g, gd, gdd, gddd, a, ad, (zero, dn*v, db*v)) + if arg == 2: + return rotated_frenet_frame_dash_dcoeff_vjp2( + g, gd, gdd, gddd, a, ad, (zero, dn*v, db*v)) + if arg == 3: + return rotated_frenet_frame_dash_dcoeff_vjp3( + g, gd, gdd, gddd, a, ad, (zero, dn*v, db*v)) + if arg == 4: + return rotated_frenet_frame_dash_dcoeff_vjp4( + g, gd, gdd, gddd, a, ad, (zero, dn*v, db*v)) + if arg == 5: + return rotated_frenet_frame_dash_dcoeff_vjp5( + g, gd, gdd, gddd, a, ad, (zero, dn*v, db*v)) + + +class FramedCurveCentroid(FramedCurve): + """ + Implementation of the centroid frame introduced in + Singh et al, "Optimization of finite-build stellarator coils", + Journal of Plasma Physics 86 (2020), + doi:10.1017/S0022377820000756. + Given a curve, one defines a reference frame using the normal and + binormal vector based on the centoid of the coil. In addition, we specify an + angle along the curve that defines the rotation with respect to this + reference frame. + + The idea is explained well in Figure 1 in the reference above. + """ + + def __init__(self, curve, rotation=None): + FramedCurve.__init__(self, curve, rotation) + + self.torsion = jit(lambda gamma, gammadash, gammadashdash, alpha, alphadash: torsion_pure_centroid( + gamma, gammadash, gammadashdash, alpha, alphadash)) + self.torsiongrad_vjp0 = jit(lambda gamma, gammadash, gammadashdash, alpha, alphadash, v: vjp( + lambda g: self.torsion(g, gammadash, gammadashdash, alpha, alphadash), gamma)[1](v)[0]) + self.torsiongrad_vjp1 = jit(lambda gamma, gammadash, gammadashdash, alpha, alphadash, v: vjp( + lambda g: self.torsion(gamma, g, gammadashdash, alpha, alphadash), gammadash)[1](v)[0]) + self.torsiongrad_vjp2 = jit(lambda gamma, gammadash, gammadashdash, alpha, alphadash, v: vjp( + lambda g: self.torsion(gamma, gammadash, g, alpha, alphadash), gammadashdash)[1](v)[0]) + self.torsiongrad_vjp4 = jit(lambda gamma, gammadash, gammadashdash, alpha, alphadash, v: vjp( + lambda g: self.torsion(gamma, gammadash, gammadashdash, g, alphadash), alpha)[1](v)[0]) + self.torsiongrad_vjp5 = jit(lambda gamma, gammadash, gammadashdash, alpha, alphadash, v: vjp( + lambda g: self.torsion(gamma, gammadash, gammadashdash, alpha, g), alphadash)[1](v)[0]) + + self.binorm = jit(lambda gamma, gammadash, gammadashdash, alpha, alphadash: binormal_curvature_pure_centroid( + gamma, gammadash, gammadashdash, alpha, alphadash)) + self.binormgrad_vjp0 = jit(lambda gamma, gammadash, gammadashdash, alpha, alphadash, v: vjp( + lambda g: self.binorm(g, gammadash, gammadashdash, alpha, alphadash), gamma)[1](v)[0]) + self.binormgrad_vjp1 = jit(lambda gamma, gammadash, gammadashdash, alpha, alphadash, v: vjp( + lambda g: self.binorm(gamma, g, gammadashdash, alpha, alphadash), gammadash)[1](v)[0]) + self.binormgrad_vjp2 = jit(lambda gamma, gammadash, gammadashdash, alpha, alphadash, v: vjp( + lambda g: self.binorm(gamma, gammadash, g, alpha, alphadash), gammadashdash)[1](v)[0]) + self.binormgrad_vjp4 = jit(lambda gamma, gammadash, gammadashdash, alpha, alphadash, v: vjp( + lambda g: self.binorm(gamma, gammadash, gammadashdash, g, alphadash), alpha)[1](v)[0]) + self.binormgrad_vjp5 = jit(lambda gamma, gammadash, gammadashdash, alpha, alphadash, v: vjp( + lambda g: self.binorm(gamma, gammadash, gammadashdash, alpha, g), alphadash)[1](v)[0]) + + def frame_torsion(self): + """Exports frame torsion along a curve""" + gamma = self.curve.gamma() + d1gamma = self.curve.gammadash() + d2gamma = self.curve.gammadashdash() + alpha = self.rotation.alpha(self.curve.quadpoints) + alphadash = self.rotation.alphadash(self.curve.quadpoints) + return self.torsion(gamma, d1gamma, d2gamma, alpha, alphadash) + + def frame_binormal_curvature(self): + gamma = self.curve.gamma() + d1gamma = self.curve.gammadash() + d2gamma = self.curve.gammadashdash() + alpha = self.rotation.alpha(self.curve.quadpoints) + alphadash = self.rotation.alphadash(self.curve.quadpoints) + return self.binorm(gamma, d1gamma, d2gamma, alpha, alphadash) + + def rotated_frame(self): + return rotated_centroid_frame(self.curve.gamma(), self.curve.gammadash(), + self.rotation.alpha(self.curve.quadpoints)) + + def rotated_frame_dash(self): + return rotated_centroid_frame_dash( + self.curve.gamma(), self.curve.gammadash(), self.curve.gammadashdash(), + self.rotation.alpha(self.curve.quadpoints), self.rotation.alphadash(self.curve.quadpoints) + ) + + def rotated_frame_dcoeff_vjp(self, v, dn, db, arg=0): + assert arg in [0, 1, 2, 3] + g = self.curve.gamma() + gd = self.curve.gammadash() + a = self.rotation.alpha(self.curve.quadpoints) + zero = np.zeros_like(v) + if arg == 0: + return rotated_centroid_frame_dcoeff_vjp0( + g, gd, a, (zero, dn*v, db*v)) + if arg == 1: + return rotated_centroid_frame_dcoeff_vjp1( + g, gd, a, (zero, dn*v, db*v)) + if arg == 2: + return None + if arg == 3: + return rotated_centroid_frame_dcoeff_vjp3( + g, gd, a, (zero, dn*v, db*v)) + + def rotated_frame_dash_dcoeff_vjp(self, v, dn, db, arg=0): + assert arg in [0, 1, 2, 3, 4, 5] + g = self.curve.gamma() + gd = self.curve.gammadash() + gdd = self.curve.gammadashdash() + a = self.rotation.alpha(self.curve.quadpoints) + ad = self.rotation.alphadash(self.curve.quadpoints) + zero = np.zeros_like(v) + if arg == 0: + return rotated_centroid_frame_dash_dcoeff_vjp0( + g, gd, gdd, a, ad, (zero, dn*v, db*v)) + if arg == 1: + return rotated_centroid_frame_dash_dcoeff_vjp1( + g, gd, gdd, a, ad, (zero, dn*v, db*v)) + if arg == 2: + return rotated_centroid_frame_dash_dcoeff_vjp2( + g, gd, gdd, a, ad, (zero, dn*v, db*v)) + if arg == 3: + return None + if arg == 4: + return rotated_centroid_frame_dash_dcoeff_vjp4( + g, gd, gdd, a, ad, (zero, dn*v, db*v)) + if arg == 5: + return rotated_centroid_frame_dash_dcoeff_vjp5( + g, gd, gdd, a, ad, (zero, dn*v, db*v)) + + def dframe_binormal_curvature_by_dcoeff_vjp(self, v): + gamma = self.curve.gamma() + d1gamma = self.curve.gammadash() + d2gamma = self.curve.gammadashdash() + alpha = self.rotation.alpha(self.curve.quadpoints) + alphadash = self.rotation.alphadash(self.curve.quadpoints) + + grad0 = self.binormgrad_vjp0(gamma, d1gamma, d2gamma, + alpha, alphadash, v) + grad1 = self.binormgrad_vjp1(gamma, d1gamma, d2gamma, + alpha, alphadash, v) + grad2 = self.binormgrad_vjp2(gamma, d1gamma, d2gamma, + alpha, alphadash, v) + grad4 = self.binormgrad_vjp4(gamma, d1gamma, d2gamma, + alpha, alphadash, v) + grad5 = self.binormgrad_vjp5(gamma, d1gamma, d2gamma, + alpha, alphadash, v) + + return self.curve.dgamma_by_dcoeff_vjp(grad0) \ + + self.curve.dgammadash_by_dcoeff_vjp(grad1) \ + + self.curve.dgammadashdash_by_dcoeff_vjp(grad2) \ + + self.rotation.dalpha_by_dcoeff_vjp(self.curve.quadpoints, grad4) \ + + self.rotation.dalphadash_by_dcoeff_vjp(self.curve.quadpoints, grad5) + + def dframe_torsion_by_dcoeff_vjp(self, v): + gamma = self.curve.gamma() + d1gamma = self.curve.gammadash() + d2gamma = self.curve.gammadashdash() + alpha = self.rotation.alpha(self.curve.quadpoints) + alphadash = self.rotation.alphadash(self.curve.quadpoints) + + grad0 = self.torsiongrad_vjp0(gamma, d1gamma, d2gamma, + alpha, alphadash, v) + grad1 = self.torsiongrad_vjp1(gamma, d1gamma, d2gamma, + alpha, alphadash, v) + grad2 = self.torsiongrad_vjp2(gamma, d1gamma, d2gamma, + alpha, alphadash, v) + grad4 = self.torsiongrad_vjp4(gamma, d1gamma, d2gamma, + alpha, alphadash, v) + grad5 = self.torsiongrad_vjp5(gamma, d1gamma, d2gamma, + alpha, alphadash, v) + + return self.curve.dgamma_by_dcoeff_vjp(grad0) \ + + self.curve.dgammadash_by_dcoeff_vjp(grad1) \ + + self.curve.dgammadashdash_by_dcoeff_vjp(grad2) \ + + self.rotation.dalpha_by_dcoeff_vjp(self.curve.quadpoints, grad4) \ + + self.rotation.dalphadash_by_dcoeff_vjp(self.curve.quadpoints, grad5) + + +class FrameRotation(Optimizable): + + def __init__(self, quadpoints, order, scale=1., dofs=None): + """ + Defines the rotation angle with respect to a reference orthonormal + frame (either frenet or centroid). For example, can be used to + define the rotation of a multifilament pack; alpha in Figure 1 of + Singh et al, "Optimization of finite-build stellarator coils", + Journal of Plasma Physics 86 (2020), + doi:10.1017/S0022377820000756 + """ + self.order = order + if dofs is None: + super().__init__(x0=np.zeros((2*order+1, ))) + else: + super().__init__(dofs=dofs) + self.quadpoints = quadpoints + self.scale = scale + self.jac = rotation_dcoeff(quadpoints, order) + self.jacdash = rotationdash_dcoeff(quadpoints, order) + self.jax_alpha = jit(lambda dofs, points: jaxrotation_pure(dofs, points, self.order)) + self.jax_alphadash = jit(lambda dofs, points: jaxrotationdash_pure(dofs, points, self.order)) + + def alpha(self, quadpoints): + return self.scale * self.jax_alpha(self._dofs.full_x, quadpoints) + + def alphadash(self, quadpoints): + return self.scale * self.jax_alphadash(self._dofs.full_x, quadpoints) + + def dalpha_by_dcoeff_vjp(self, quadpoints, v): + return Derivative({self: self.scale * sopp.vjp(v, self.jac)}) + + def dalphadash_by_dcoeff_vjp(self, quadpoints, v): + return Derivative({self: self.scale * sopp.vjp(v, self.jacdash)}) + + +class ZeroRotation(Optimizable): + + def __init__(self, quadpoints): + """ + Dummy class that just returns zero for the rotation angle. Equivalent to using + + .. code-block:: python + + rot = FrameRotation(...) + rot.fix_all() + + """ + super().__init__() + self.zero = np.zeros((quadpoints.size, )) + + def alpha(self, quadpoints): + return self.zero + + def alphadash(self, quadpoints): + return self.zero + + def dalpha_by_dcoeff_vjp(self, quadpoints, v): + return Derivative({}) + + def dalphadash_by_dcoeff_vjp(self, quadpoints, v): + return Derivative({}) + + +@jit +def rotated_centroid_frame(gamma, gammadash, alpha): + t = gammadash + t *= 1./jnp.linalg.norm(gammadash, axis=1)[:, None] + R = jnp.mean(gamma, axis=0) # centroid + delta = gamma - R[None, :] + n = delta - jnp.sum(delta * t, axis=1)[:, None] * t + n *= 1./jnp.linalg.norm(n, axis=1)[:, None] + b = jnp.cross(t, n, axis=1) + + # now rotate the frame by alpha + nn = jnp.cos(alpha)[:, None] * n - jnp.sin(alpha)[:, None] * b + bb = jnp.sin(alpha)[:, None] * n + jnp.cos(alpha)[:, None] * b + return t, nn, bb + + +rotated_centroid_frame_dash = jit( + lambda gamma, gammadash, gammadashdash, alpha, alphadash: jvp(rotated_centroid_frame, + (gamma, gammadash, alpha), + (gammadash, gammadashdash, alphadash))[1]) + +rotated_centroid_frame_dcoeff_vjp0 = jit( + lambda gamma, gammadash, alpha, v: vjp( + lambda g: rotated_centroid_frame(g, gammadash, alpha), gamma)[1](v)[0]) + +rotated_centroid_frame_dcoeff_vjp1 = jit( + lambda gamma, gammadash, alpha, v: vjp( + lambda gd: rotated_centroid_frame(gamma, gd, alpha), gammadash)[1](v)[0]) + +rotated_centroid_frame_dcoeff_vjp3 = jit( + lambda gamma, gammadash, alpha, v: vjp( + lambda a: rotated_centroid_frame(gamma, gammadash, a), alpha)[1](v)[0]) + +rotated_centroid_frame_dash_dcoeff_vjp0 = jit( + lambda gamma, gammadash, gammadashdash, alpha, alphadash, v: vjp( + lambda g: rotated_centroid_frame_dash(g, gammadash, gammadashdash, alpha, alphadash), gamma)[1](v)[0]) + +rotated_centroid_frame_dash_dcoeff_vjp1 = jit( + lambda gamma, gammadash, gammadashdash, alpha, alphadash, v: vjp( + lambda gd: rotated_centroid_frame_dash(gamma, gd, gammadashdash, alpha, alphadash), gammadash)[1](v)[0]) + +rotated_centroid_frame_dash_dcoeff_vjp2 = jit( + lambda gamma, gammadash, gammadashdash, alpha, alphadash, v: vjp( + lambda gdd: rotated_centroid_frame_dash(gamma, gammadash, gdd, alpha, alphadash), gammadashdash)[1](v)[0]) + +rotated_centroid_frame_dash_dcoeff_vjp4 = jit( + lambda gamma, gammadash, gammadashdash, alpha, alphadash, v: vjp( + lambda a: rotated_centroid_frame_dash(gamma, gammadash, gammadashdash, a, alphadash), alpha)[1](v)[0]) + +rotated_centroid_frame_dash_dcoeff_vjp5 = jit( + lambda gamma, gammadash, gammadashdash, alpha, alphadash, v: vjp( + lambda ad: rotated_centroid_frame_dash(gamma, gammadash, gammadashdash, alpha, ad), alphadash)[1](v)[0]) + + +@jit +def rotated_frenet_frame(gamma, gammadash, gammadashdash, alpha): + """Frenet frame of a curve rotated by a angle that varies along the coil path""" + + N = gamma.shape[0] + t, n, b = (np.zeros((N, 3)), np.zeros((N, 3)), np.zeros((N, 3))) + t = gammadash + t *= 1./jnp.linalg.norm(gammadash, axis=1)[:, None] + + tdash = (1./jnp.linalg.norm(gammadash, axis=1)[:, None])**2 * (jnp.linalg.norm(gammadash, axis=1)[:, None] * gammadashdash + - (inner(gammadash, gammadashdash)/jnp.linalg.norm(gammadash, axis=1))[:, None] * gammadash) + + n = tdash + n *= 1/jnp.linalg.norm(tdash, axis=1)[:, None] + b = jnp.cross(t, n, axis=1) + # now rotate the frame by alpha + nn = jnp.cos(alpha)[:, None] * n - jnp.sin(alpha)[:, None] * b + bb = jnp.sin(alpha)[:, None] * n + jnp.cos(alpha)[:, None] * b + + return t, nn, bb + + +rotated_frenet_frame_dash = jit( + lambda gamma, gammadash, gammadashdash, gammadashdashdash, alpha, alphadash: jvp(rotated_frenet_frame, + (gamma, gammadash, + gammadashdash, alpha), + (gammadash, gammadashdash, gammadashdashdash, alphadash))[1]) + +rotated_frenet_frame_dcoeff_vjp0 = jit( + lambda gamma, gammadash, gammadashdash, alpha, v: vjp( + lambda g: rotated_frenet_frame(g, gammadash, gammadashdash, alpha), gamma)[1](v)[0]) + +rotated_frenet_frame_dcoeff_vjp1 = jit( + lambda gamma, gammadash, gammadashdash, alpha, v: vjp( + lambda gd: rotated_frenet_frame(gamma, gd, gammadashdash, alpha), gammadash)[1](v)[0]) + +rotated_frenet_frame_dcoeff_vjp2 = jit( + lambda gamma, gammadash, gammadashdash, alpha, v: vjp( + lambda gdd: rotated_frenet_frame(gamma, gammadash, gdd, alpha), gammadashdash)[1](v)[0]) + +rotated_frenet_frame_dcoeff_vjp3 = jit( + lambda gamma, gammadash, gammadashdash, alpha, v: vjp( + lambda a: rotated_frenet_frame(gamma, gammadash, gammadashdash, a), alpha)[1](v)[0]) + +rotated_frenet_frame_dash_dcoeff_vjp0 = jit( + lambda gamma, gammadash, gammadashdash, gammadashdashdash, alpha, alphadash, v: vjp( + lambda g: rotated_frenet_frame_dash(g, gammadash, gammadashdash, gammadashdashdash, alpha, alphadash), gamma)[1](v)[0]) + +rotated_frenet_frame_dash_dcoeff_vjp1 = jit( + lambda gamma, gammadash, gammadashdash, gammadashdashdash, alpha, alphadash, v: vjp( + lambda gd: rotated_frenet_frame_dash(gamma, gd, gammadashdash, gammadashdashdash, alpha, alphadash), gammadash)[1](v)[0]) + +rotated_frenet_frame_dash_dcoeff_vjp2 = jit( + lambda gamma, gammadash, gammadashdash, gammadashdashdash, alpha, alphadash, v: vjp( + lambda gdd: rotated_frenet_frame_dash(gamma, gammadash, gdd, gammadashdashdash, alpha, alphadash), gammadashdash)[1](v)[0]) + +rotated_frenet_frame_dash_dcoeff_vjp3 = jit( + lambda gamma, gammadash, gammadashdash, gammadashdashdash, alpha, alphadash, v: vjp( + lambda gddd: rotated_frenet_frame_dash(gamma, gammadash, gammadashdash, gddd, alpha, alphadash), gammadashdashdash)[1](v)[0]) + +rotated_frenet_frame_dash_dcoeff_vjp4 = jit( + lambda gamma, gammadash, gammadashdash, gammadashdashdash, alpha, alphadash, v: vjp( + lambda a: rotated_frenet_frame_dash(gamma, gammadash, gammadashdash, gammadashdashdash, a, alphadash), alpha)[1](v)[0]) + +rotated_frenet_frame_dash_dcoeff_vjp5 = jit( + lambda gamma, gammadash, gammadashdash, gammadashdashdash, alpha, alphadash, v: vjp( + lambda ad: rotated_frenet_frame_dash(gamma, gammadash, gammadashdash, gammadashdashdash, alpha, ad), alphadash)[1](v)[0]) + + +def jaxrotation_pure(dofs, points, order): + rotation = jnp.zeros((len(points), )) + rotation += dofs[0] + for j in range(1, order+1): + rotation += dofs[2*j-1] * jnp.sin(2*np.pi*j*points) + rotation += dofs[2*j] * jnp.cos(2*np.pi*j*points) + return rotation + + +def jaxrotationdash_pure(dofs, points, order): + rotation = jnp.zeros((len(points), )) + for j in range(1, order+1): + rotation += dofs[2*j-1] * 2*np.pi*j*jnp.cos(2*np.pi*j*points) + rotation -= dofs[2*j] * 2*np.pi*j*jnp.sin(2*np.pi*j*points) + return rotation + + +def rotation_dcoeff(points, order): + jac = np.zeros((len(points), 2*order+1)) + jac[:, 0] = 1 + for j in range(1, order+1): + jac[:, 2*j-1] = np.sin(2*np.pi*j*points) + jac[:, 2*j+0] = np.cos(2*np.pi*j*points) + return jac + + +def rotationdash_dcoeff(points, order): + jac = np.zeros((len(points), 2*order+1)) + for j in range(1, order+1): + jac[:, 2*j-1] = +2*np.pi*j*np.cos(2*np.pi*j*points) + jac[:, 2*j+0] = -2*np.pi*j*np.sin(2*np.pi*j*points) + return jac + + +def inner(a, b): + """Inner product for arrays of shape (N, 3)""" + return np.sum(a*b, axis=1) + + +def torsion_pure_frenet(gamma, gammadash, gammadashdash, gammadashdashdash, + alpha, alphadash): + """Torsion function for export/evaulate coil sets""" + + _, _, b = rotated_frenet_frame(gamma, gammadash, gammadashdash, alpha) + _, ndash, _ = rotated_frenet_frame_dash( + gamma, gammadash, gammadashdash, gammadashdashdash, alpha, alphadash) + + ndash *= 1/jnp.linalg.norm(gammadash, axis=1)[:, None] + return inner(ndash, b) + + +def binormal_curvature_pure_frenet(gamma, gammadash, gammadashdash, gammadashdashdash, + alpha, alphadash): + """Binormal curvature function for export/evaulate coil sets.""" + + _, _, b = rotated_frenet_frame(gamma, gammadash, gammadashdash, alpha) + tdash, _, _ = rotated_frenet_frame_dash( + gamma, gammadash, gammadashdash, gammadashdashdash, alpha, alphadash) + + tdash *= 1/jnp.linalg.norm(gammadash, axis=1)[:, None] + return inner(tdash, b) + + +def torsion_pure_centroid(gamma, gammadash, gammadashdash, + alpha, alphadash): + """Torsion function for export/evaulate coil sets""" + + _, _, b = rotated_centroid_frame(gamma, gammadash, alpha) + _, ndash, _ = rotated_centroid_frame_dash( + gamma, gammadash, gammadashdash, alpha, alphadash) + + ndash *= 1/jnp.linalg.norm(gammadash, axis=1)[:, None] + return inner(ndash, b) + + +def binormal_curvature_pure_centroid(gamma, gammadash, gammadashdash, + alpha, alphadash): + """Binormal curvature function for export/evaulate coil sets.""" + + _, _, b = rotated_centroid_frame(gamma, gammadash, alpha) + tdash, _, _ = rotated_centroid_frame_dash( + gamma, gammadash, gammadashdash, alpha, alphadash) + + tdash *= 1/jnp.linalg.norm(gammadash, axis=1)[:, None] + return inner(tdash, b) diff --git a/src/simsopt/geo/strain_optimization.py b/src/simsopt/geo/strain_optimization.py new file mode 100644 index 000000000..59400c107 --- /dev/null +++ b/src/simsopt/geo/strain_optimization.py @@ -0,0 +1,192 @@ +import jax.numpy as jnp +from jax import vjp, grad +from simsopt.geo.jit import jit +from simsopt._core import Optimizable +from simsopt._core.derivative import derivative_dec +from simsopt.geo.curveobjectives import Lp_torsion_pure + +__all__ = ['LPBinormalCurvatureStrainPenalty', + 'LPTorsionalStrainPenalty', 'CoilStrain'] + + +class LPBinormalCurvatureStrainPenalty(Optimizable): + r""" + This class computes a penalty term based on the :math:`L_p` norm + of the binormal curvature strain, and penalizes where the local strain exceeds a threshold + + .. math:: + J = \frac{1}{p} \int_{\text{curve}} \text{max}(\epsilon_{\text{bend}} - \epsilon_0, 0)^p ~dl, + + where + + .. math:: + \epsilon_{\text{bend}} = \frac{w |\hat{\textbf{b}} \cdot \boldsymbol{\kappa}|}{2}, + + :math:`w` is the width of the tape, :math:`\hat{\textbf{b}}` is the + frame binormal vector, :math:`\boldsymbol{\kappa}` is the curvature vector of the + filamentary coil, and :math:`\epsilon_0` is a threshold strain, given by the argument ``threshold``. + """ + + def __init__(self, framedcurve, width=1e-3, p=2, threshold=0): + self.framedcurve = framedcurve + self.strain = CoilStrain(framedcurve, width) + self.width = width + self.p = p + self.threshold = threshold + self.J_jax = jit(lambda binorm, gammadash: Lp_torsion_pure( + binorm, gammadash, p, threshold)) + self.grad0 = jit(lambda binorm, gammadash: grad( + self.J_jax, argnums=0)(binorm, gammadash)) + self.grad1 = jit(lambda binorm, gammadash: grad( + self.J_jax, argnums=1)(binorm, gammadash)) + super().__init__(depends_on=[framedcurve]) + + def J(self): + """ + This returns the value of the quantity. + """ + return self.J_jax(self.strain.binormal_curvature_strain(), self.framedcurve.curve.gammadash()) + + @derivative_dec + def dJ(self): + """ + This returns the derivative of the quantity with respect to the curve and rotation dofs. + """ + grad0 = self.grad0(self.strain.binormal_curvature_strain(), + self.framedcurve.curve.gammadash()) + grad1 = self.grad1(self.strain.binormal_curvature_strain(), + self.framedcurve.curve.gammadash()) + vjp0 = self.strain.binormstrain_vjp( + self.framedcurve.frame_binormal_curvature(), self.width, grad0) + return self.framedcurve.dframe_binormal_curvature_by_dcoeff_vjp(vjp0) \ + + self.framedcurve.curve.dgammadash_by_dcoeff_vjp(grad1) + + return_fn_map = {'J': J, 'dJ': dJ} + + +class LPTorsionalStrainPenalty(Optimizable): + r""" + This class computes a penalty term based on the :math:`L_p` norm + of the torsional strain, and penalizes where the local strain exceeds a threshold + + .. math:: + J = \frac{1}{p} \int_{\text{curve}} \text{max}(\epsilon_{\text{tor}} - \epsilon_0, 0)^p ~dl + + where + + .. math:: + \epsilon_{\text{tor}} = \frac{\tau^2 w^2}{12}, + + :math:`\tau` is the torsion of the tape frame, :math:`w` is the width of the tape, + and :math:`\epsilon_0` is a threshold strain, given by the argument ``threshold``. + """ + + def __init__(self, framedcurve, width=1e-3, p=2, threshold=0): + self.framedcurve = framedcurve + self.strain = CoilStrain(framedcurve, width) + self.width = width + self.p = p + self.threshold = threshold + self.J_jax = jit(lambda torsion, gammadash: Lp_torsion_pure( + torsion, gammadash, p, threshold)) + self.grad0 = jit(lambda torsion, gammadash: grad( + self.J_jax, argnums=0)(torsion, gammadash)) + self.grad1 = jit(lambda torsion, gammadash: grad( + self.J_jax, argnums=1)(torsion, gammadash)) + super().__init__(depends_on=[framedcurve]) + + def J(self): + """ + This returns the value of the quantity. + """ + return self.J_jax(self.strain.torsional_strain(), self.framedcurve.curve.gammadash()) + + @derivative_dec + def dJ(self): + """ + This returns the derivative of the quantity with respect to the curve and rotation dofs. + """ + grad0 = self.grad0(self.strain.torsional_strain(), + self.framedcurve.curve.gammadash()) + grad1 = self.grad1(self.strain.torsional_strain(), + self.framedcurve.curve.gammadash()) + vjp0 = self.strain.torstrain_vjp( + self.framedcurve.frame_torsion(), self.width, grad0) + return self.framedcurve.dframe_torsion_by_dcoeff_vjp(vjp0) \ + + self.framedcurve.curve.dgammadash_by_dcoeff_vjp(grad1) + + return_fn_map = {'J': J, 'dJ': dJ} + + +class CoilStrain(Optimizable): + r""" + This class evaluates the torsional and binormal curvature strains on HTS, based on + a filamentary model of the coil and the orientation of the HTS tape. + + As defined in, + + Paz Soldan, "Non-planar coil winding angle optimization for compatibility with + non-insulated high-temperature superconducting magnets", Journal of Plasma Physics + 86 (2020), doi:10.1017/S0022377820001208, + + the expressions for the strains are: + + .. math:: + \epsilon_{\text{tor}} = \frac{\tau^2 w^2}{12} + + \epsilon_{\text{bend}} = \frac{w |\hat{\textbf{b}} \cdot \boldsymbol{\kappa}|}{2}, + + where :math:`\tau` is the torsion of the tape frame, :math:`\hat{\textbf{b}}` is the + frame binormal vector, :math:`\boldsymbol{\kappa}` is the curvature vector of the + filamentary coil, and :math:`w` is the width of the tape. + + This class is not intended to be used as an objective function inside + optimization. For that purpose you should instead use + :obj:`LPBinormalCurvatureStrainPenalty` or :obj:`LPTorsionalStrainPenalty`. + Those classes also compute gradients whereas this class does not. + """ + + def __init__(self, framedcurve, width=1e-3): + self.framedcurve = framedcurve + self.width = width + self.torstrain_jax = jit(lambda torsion, width: torstrain_pure( + torsion, width)) + self.binormstrain_jax = jit(lambda binorm, width: binormstrain_pure( + binorm, width)) + self.torstrain_vjp = jit(lambda torsion, width, v: vjp( + lambda g: torstrain_pure(g, width), torsion)[1](v)[0]) + self.binormstrain_vjp = jit(lambda binorm, width, v: vjp( + lambda g: binormstrain_pure(g, width), binorm)[1](v)[0]) + + super().__init__(depends_on=[framedcurve]) + + def torsional_strain(self): + r""" + Returns the value of the torsional strain, :math:`\epsilon_{\text{tor}}`, along + the quadpoints defining the filamentary coil. + """ + return self.torstrain_jax(self.framedcurve.frame_torsion(), self.width) + + def binormal_curvature_strain(self): + r""" + Returns the value of the torsional strain, :math:`\epsilon_{\text{bend}}`, along + the quadpoints defining the filamentary coil. + """ + return self.binormstrain_jax(self.framedcurve.frame_binormal_curvature(), self.width) + + +@jit +def torstrain_pure(torsion, width): + """ + This function is used in a Python+Jax implementation of the LPTorsionalStrainPenalty objective. + """ + return torsion**2 * width**2 / 12 + + +@jit +def binormstrain_pure(binorm, width): + """ + This function is used in a Python+Jax implementation of the LPBinormalCurvatureStrainPenalty + objective. + """ + return (width / 2) * jnp.abs(binorm) diff --git a/tests/geo/test_finitebuild.py b/tests/geo/test_finitebuild.py index 6c389d45f..d1a90ff14 100644 --- a/tests/geo/test_finitebuild.py +++ b/tests/geo/test_finitebuild.py @@ -3,8 +3,8 @@ from simsopt.field.biotsavart import BiotSavart from simsopt.field.coil import Coil, apply_symmetries_to_curves, apply_symmetries_to_currents from simsopt.geo.curveobjectives import CurveLength, CurveCurveDistance -from simsopt.geo.finitebuild import CurveFilament, FilamentRotation, \ - create_multifilament_grid, ZeroRotation +from simsopt.geo import CurveFilament, FrameRotation, \ + create_multifilament_grid, ZeroRotation, FramedCurveCentroid, FramedCurveFrenet from simsopt.geo.qfmsurface import QfmSurface from simsopt.objectives.fluxobjective import SquaredFlux from simsopt.objectives.utilities import QuadraticPenalty @@ -16,25 +16,31 @@ class MultifilamentTesting(unittest.TestCase): def test_multifilament_gammadash(self): - for order in [None, 1]: - with self.subTest(order=order): - self.subtest_multifilament_gammadash(order) + for centroid in [True, False]: + for order in [None, 1]: + with self.subTest(order=order): + self.subtest_multifilament_gammadash(order, centroid) - def subtest_multifilament_gammadash(self, order): + def subtest_multifilament_gammadash(self, order, centroid): assert order in [1, None] - curves, currents, ma = get_ncsx_data(Nt_coils=6, ppp=80) + curves, currents, ma = get_ncsx_data(Nt_coils=6, ppp=120) c = curves[0] if order == 1: - rotation = FilamentRotation(c.quadpoints, order) + rotation = FrameRotation(c.quadpoints, order) rotation.x = np.array([0, 0.1, 0.3]) - rotationShared = FilamentRotation(curves[0].quadpoints, order, dofs=rotation.dofs) + rotationShared = FrameRotation(curves[0].quadpoints, order, dofs=rotation.dofs) assert np.allclose(rotation.x, rotationShared.x) assert np.allclose(rotation.alpha(c.quadpoints), rotationShared.alpha(c.quadpoints)) else: rotation = ZeroRotation(c.quadpoints) - c = CurveFilament(c, 0.01, 0.01, rotation) + if centroid: + framedcurve = FramedCurveCentroid(c, rotation) + else: + framedcurve = FramedCurveFrenet(c, rotation) + + c = CurveFilament(framedcurve, 0.01, 0.01) g = c.gamma() gd = c.gammadash() idx = 16 @@ -51,22 +57,28 @@ def subtest_multifilament_gammadash(self, order): def test_multifilament_coefficient_derivative(self): for order in [None, 1]: - with self.subTest(order=order): - self.subtest_multifilament_coefficient_derivative(order) + for centroid in [True, False]: + with self.subTest(order=order): + self.subtest_multifilament_coefficient_derivative(order, centroid) - def subtest_multifilament_coefficient_derivative(self, order): + def subtest_multifilament_coefficient_derivative(self, order, centroid): assert order in [1, None] curves, currents, ma = get_ncsx_data(Nt_coils=4, ppp=10) c = curves[0] if order == 1: - rotation = FilamentRotation(c.quadpoints, order) + rotation = FrameRotation(c.quadpoints, order) rotation.x = np.array([0, 0.1, 0.3]) else: rotation = ZeroRotation(c.quadpoints) - c = CurveFilament(c, 0.02, 0.02, rotation) + if centroid: + framedcurve = FramedCurveCentroid(c, rotation) + else: + framedcurve = FramedCurveFrenet(c, rotation) + + c = CurveFilament(framedcurve, 0.02, 0.02) dofs = c.x @@ -123,26 +135,27 @@ def check(fils, c, numfilaments_n, numfilaments_b): # check that the coil pack is centered around the underlying curve assert np.linalg.norm(np.mean([f.gamma() for f in fils], axis=0)-c.gamma()) < 1e-13 - numfilaments_n = 2 - numfilaments_b = 3 - fils = create_multifilament_grid( - c, numfilaments_n, numfilaments_b, gapsize_n, gapsize_b, - rotation_order=None, rotation_scaling=None) - check(fils, c, numfilaments_n, numfilaments_b) - - numfilaments_n = 3 - numfilaments_b = 2 - fils = create_multifilament_grid( - c, numfilaments_n, numfilaments_b, gapsize_n, gapsize_b, - rotation_order=None, rotation_scaling=None) - check(fils, c, numfilaments_n, numfilaments_b) - - fils = create_multifilament_grid( - c, numfilaments_n, numfilaments_b, gapsize_n, gapsize_b, - rotation_order=3, rotation_scaling=None) - xr = fils[0].rotation.x - fils[0].rotation.x = xr + 1e-2*np.random.standard_normal(size=xr.shape) - check(fils, c, numfilaments_n, numfilaments_b) + for frame in ['centroid', 'frenet']: + numfilaments_n = 2 + numfilaments_b = 3 + fils = create_multifilament_grid( + c, numfilaments_n, numfilaments_b, gapsize_n, gapsize_b, + rotation_order=None, rotation_scaling=None, frame=frame) + check(fils, c, numfilaments_n, numfilaments_b) + + numfilaments_n = 3 + numfilaments_b = 2 + fils = create_multifilament_grid( + c, numfilaments_n, numfilaments_b, gapsize_n, gapsize_b, + rotation_order=None, rotation_scaling=None, frame=frame) + check(fils, c, numfilaments_n, numfilaments_b) + + fils = create_multifilament_grid( + c, numfilaments_n, numfilaments_b, gapsize_n, gapsize_b, + rotation_order=3, rotation_scaling=None, frame=frame) + xr = fils[0].rotation.x + fils[0].rotation.x = xr + 1e-2*np.random.standard_normal(size=xr.shape) + check(fils, c, numfilaments_n, numfilaments_b) def test_biotsavart_with_symmetries(self): """ @@ -152,46 +165,49 @@ def test_biotsavart_with_symmetries(self): """ np.random.seed(1) base_curves, base_currents, ma = get_ncsx_data(Nt_coils=5) - base_curves_finite_build = sum( - [create_multifilament_grid(c, 2, 2, 0.01, 0.01, rotation_order=1) for c in base_curves], []) - base_currents_finite_build = sum([[c]*4 for c in base_currents], []) - - nfp = 3 - - curves = apply_symmetries_to_curves(base_curves, nfp, True) - curves_fb = apply_symmetries_to_curves(base_curves_finite_build, nfp, True) - currents_fb = apply_symmetries_to_currents(base_currents_finite_build, nfp, True) - - coils_fb = [Coil(c, curr) for (c, curr) in zip(curves_fb, currents_fb)] - - bs = BiotSavart(coils_fb) - s = get_surface("SurfaceXYZFourier", True) - s.fit_to_curve(ma, 0.1) - Jf = SquaredFlux(s, bs) - Jls = [CurveLength(c) for c in base_curves] - Jdist = CurveCurveDistance(curves, 0.5) - LENGTH_PEN = 1e-2 - DIST_PEN = 1e-2 - JF = Jf \ - + LENGTH_PEN * sum(QuadraticPenalty(Jls[i], Jls[i].J()) for i in range(len(base_curves))) \ - + DIST_PEN * Jdist - - def fun(dofs, grad=True): - JF.x = dofs - return (JF.J(), JF.dJ()) if grad else JF.J() - - dofs = JF.x - dofs += 1e-2 * np.random.standard_normal(size=dofs.shape) - np.random.seed(1) - h = np.random.uniform(size=dofs.shape) - J0, dJ0 = fun(dofs) - dJh = sum(dJ0 * h) - err = 1e6 - for i in range(10, 15): - eps = 0.5**i - J1 = fun(dofs + eps*h, grad=False) - J2 = fun(dofs - eps*h, grad=False) - err_new = abs((J1-J2)/(2*eps) - dJh) - assert err_new < 0.55**2 * err - err = err_new - print("err", err) + + for frame in ['centroid', 'frenet']: + + base_curves_finite_build = sum( + [create_multifilament_grid(c, 2, 2, 0.01, 0.01, rotation_order=1, frame=frame) for c in base_curves], []) + base_currents_finite_build = sum([[c]*4 for c in base_currents], []) + + nfp = 3 + + curves = apply_symmetries_to_curves(base_curves, nfp, True) + curves_fb = apply_symmetries_to_curves(base_curves_finite_build, nfp, True) + currents_fb = apply_symmetries_to_currents(base_currents_finite_build, nfp, True) + + coils_fb = [Coil(c, curr) for (c, curr) in zip(curves_fb, currents_fb)] + + bs = BiotSavart(coils_fb) + s = get_surface("SurfaceXYZFourier", True) + s.fit_to_curve(ma, 0.1) + Jf = SquaredFlux(s, bs) + Jls = [CurveLength(c) for c in base_curves] + Jdist = CurveCurveDistance(curves, 0.5) + LENGTH_PEN = 1e-2 + DIST_PEN = 1e-2 + JF = Jf \ + + LENGTH_PEN * sum(QuadraticPenalty(Jls[i], Jls[i].J()) for i in range(len(base_curves))) \ + + DIST_PEN * Jdist + + def fun(dofs, grad=True): + JF.x = dofs + return (JF.J(), JF.dJ()) if grad else JF.J() + + dofs = JF.x + dofs += 1e-2 * np.random.standard_normal(size=dofs.shape) + np.random.seed(1) + h = np.random.uniform(size=dofs.shape) + J0, dJ0 = fun(dofs) + dJh = sum(dJ0 * h) + err = 1e6 + for i in range(10, 15): + eps = 0.5**i + J1 = fun(dofs + eps*h, grad=False) + J2 = fun(dofs - eps*h, grad=False) + err_new = abs((J1-J2)/(2*eps) - dJh) + assert err_new < 0.55**2 * err + err = err_new + print("err", err) diff --git a/tests/geo/test_strainopt.py b/tests/geo/test_strainopt.py new file mode 100644 index 000000000..87069508e --- /dev/null +++ b/tests/geo/test_strainopt.py @@ -0,0 +1,135 @@ +import unittest +from simsopt.geo import FrameRotation, ZeroRotation, FramedCurveCentroid, FramedCurveFrenet +from simsopt.configs.zoo import get_ncsx_data +from simsopt.geo.strain_optimization import LPBinormalCurvatureStrainPenalty, LPTorsionalStrainPenalty +import numpy as np +from simsopt.geo.curvexyzfourier import CurveXYZFourier +from scipy.optimize import minimize + + +class CoilStrainTesting(unittest.TestCase): + + def test_strain_opt(self): + """ + Check that for a circular coil, strains + can be optimized to vanish using rotation + dofs. + """ + for centroid in [True, False]: + quadpoints = np.linspace(0, 1, 10, endpoint=False) + curve = CurveXYZFourier(quadpoints, order=1) + curve.set('xc(1)', 1e-4) + curve.set('ys(1)', 1e-4) + curve.fix_all() + order = 2 + np.random.seed(1) + dofs = np.random.standard_normal(size=(2*order+1,)) + rotation = FrameRotation(quadpoints, order) + rotation.x = np.random.standard_normal(size=(2*order+1,)) + if centroid: + framedcurve = FramedCurveCentroid(curve, rotation) + else: + framedcurve = FramedCurveFrenet(curve, rotation) + Jt = LPTorsionalStrainPenalty(framedcurve, width=1e-3, p=2, threshold=0) + Jb = LPBinormalCurvatureStrainPenalty(framedcurve, width=1e-3, p=2, threshold=0) + J = Jt+Jb + + def fun(dofs): + J.x = dofs + grad = J.dJ() + return J.J(), grad + res = minimize(fun, J.x, jac=True, method='L-BFGS-B', + options={'maxiter': 100, 'maxcor': 10, 'gtol': 1e-20, 'ftol': 1e-20}, tol=1e-20) + assert Jt.J() < 1e-12 + assert Jb.J() < 1e-12 + + def test_torsion(self): + for centroid in [True, False]: + for order in [None, 1]: + with self.subTest(order=order): + self.subtest_torsion(order, centroid) + + def test_binormal_curvature(self): + for centroid in [True, False]: + for order in [None, 1]: + with self.subTest(order=order): + self.subtest_binormal_curvature(order, centroid) + + def subtest_binormal_curvature(self, order, centroid): + assert order in [1, None] + curves, currents, ma = get_ncsx_data(Nt_coils=6, ppp=120) + c = curves[0] + + if order == 1: + rotation = FrameRotation(c.quadpoints, order) + rotation.x = np.array([0, 0.1, 0.3]) + rotationShared = FrameRotation(curves[0].quadpoints, order, dofs=rotation.dofs) + assert np.allclose(rotation.x, rotationShared.x) + assert np.allclose(rotation.alpha(c.quadpoints), rotationShared.alpha(c.quadpoints)) + else: + rotation = None + + if centroid: + framedcurve = FramedCurveCentroid(c, rotation) + else: + framedcurve = FramedCurveFrenet(c, rotation) + + J = LPBinormalCurvatureStrainPenalty(framedcurve, width=1e-3, p=2, threshold=1e-4) + + if (not (not centroid and order is None)): + dofs = J.x + + np.random.seed(1) + h = np.random.standard_normal(size=dofs.shape) + df = np.sum(J.dJ()*h) + + errf_old = 1e10 + for i in range(9, 14): + eps = 0.5**i + J.x = dofs + eps*h + f1 = J.J() + J.x = dofs - eps*h + f2 = J.J() + errf = np.abs((f1-f2)/(2*eps) - df) + errf_old = errf + else: + # Binormal curvature vanishes in Frenet frame + assert J.J() < 1e-12 + + def subtest_torsion(self, order, centroid): + assert order in [1, None] + curves, currents, ma = get_ncsx_data(Nt_coils=6, ppp=120) + c = curves[0] + + if order == 1: + rotation = FrameRotation(c.quadpoints, order) + rotation.x = np.array([0, 0.1, 0.3]) + rotationShared = FrameRotation(curves[0].quadpoints, order, dofs=rotation.dofs) + assert np.allclose(rotation.x, rotationShared.x) + assert np.allclose(rotation.alpha(c.quadpoints), rotationShared.alpha(c.quadpoints)) + else: + rotation = ZeroRotation(c.quadpoints) + + if centroid: + framedcurve = FramedCurveCentroid(c, rotation) + else: + framedcurve = FramedCurveFrenet(c, rotation) + + J = LPTorsionalStrainPenalty(framedcurve, width=1e-3, p=2, threshold=1e-4) + + dofs = J.x + + np.random.seed(1) + h = np.random.standard_normal(size=dofs.shape) + df = np.sum(J.dJ()*h) + + errf_old = 1e10 + for i in range(9, 14): + eps = 0.5**i + J.x = dofs + eps*h + f1 = J.J() + J.x = dofs - eps*h + f2 = J.J() + errf = np.abs((f1-f2)/(2*eps) - df) + errf_old = errf +