Skip to content

Commit

Permalink
I tentatively may have fixed the jacobian calculation by just directl…
Browse files Browse the repository at this point in the history
…y computing A() in the coil_currents function. Will test this and clean this up tomorrow.
  • Loading branch information
akaptano committed Dec 28, 2024
1 parent d08042b commit 15df425
Show file tree
Hide file tree
Showing 3 changed files with 172 additions and 8 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
import numpy as np
from scipy.optimize import minimize
from simsopt.field import BiotSavart, Current, coils_via_symmetries

Check failure on line 11 in examples/3_Advanced/coil_force_optimization/passive_coils_debug.py

View workflow job for this annotation

GitHub Actions / CI (3.9)

Ruff (F401)

examples/3_Advanced/coil_force_optimization/passive_coils_debug.py:11:27: F401 `simsopt.field.BiotSavart` imported but unused

Check failure on line 11 in examples/3_Advanced/coil_force_optimization/passive_coils_debug.py

View workflow job for this annotation

GitHub Actions / CI (3.9)

Ruff (F401)

examples/3_Advanced/coil_force_optimization/passive_coils_debug.py:11:39: F401 `simsopt.field.Current` imported but unused
from simsopt.field import regularization_rect, PSCArray
from simsopt.field import regularization_rect, PSCArray, PSCArray2
from simsopt.field.force import coil_force, coil_torque, coil_net_torques, coil_net_forces, LpCurveForce, \
SquaredMeanForce, \
SquaredMeanTorque, LpCurveTorque
Expand Down Expand Up @@ -175,7 +175,7 @@ def pointData_forces_torques(coils, allcoils, aprimes, bprimes, nturns_list):
return point_data

eval_points = s.gamma().reshape(-1, 3)
psc_array = PSCArray(base_curves, coils_TF, eval_points, a_list, b_list, nfp=s.nfp, stellsym=s.stellsym)
psc_array = PSCArray2(base_curves, coils_TF, eval_points, a_list, b_list, nfp=s.nfp, stellsym=s.stellsym)
# # Calculate average, approximate on-axis B field strength
calculate_on_axis_B(psc_array.biot_savart_TF, s)
psc_array.biot_savart_TF.set_points(eval_points)
Expand Down
157 changes: 151 additions & 6 deletions src/simsopt/field/coil.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,11 +9,11 @@
from simsopt.geo.curvexyzfourier import CurveXYZFourier
from simsopt.geo.curve import RotatedCurve
from simsopt.geo.jit import jit
from .force import coil_currents_pure
from .force import coil_currents_pure, coil_currents_barebones
import simsoptpp as sopp


__all__ = ['Coil', 'JaxCurrent', 'PSCArray',
__all__ = ['Coil', 'JaxCurrent', 'PSCArray', 'PSCArray2',
'Current', 'coils_via_symmetries',
'load_coils_from_makegrid_file',
'apply_symmetries_to_currents', 'apply_symmetries_to_curves',
Expand Down Expand Up @@ -190,9 +190,9 @@ def vjp_setup(self, v_currents):
print(self.dI_dgammas(*args).shape)
dJ_dgammas = np.sum(v_currents[:, None, None, None] * self.dI_dgammas(*args), axis=0)
dJ_dgammadashs = np.sum(v_currents[:, None, None, None] * self.dI_dgammadashs(*args), axis=0)
# dI_dA = np.sum(v_currents[:, None, None, None] * self.dI_dA(*args).reshape(
# dJ_dA = np.sum(v_currents[:, None, None, None] * self.dI_dA(*args).reshape(
# len(self.psc_curves), len(self.psc_curves), -1, 3), axis=0)
dJ_dA = v_currents[:, None, None] * self.dI_dA(*args)
dJ_dA = np.sum(v_currents[:, None, None] * self.dI_dA(*args), axis=0)

#### Note: I think that the A() term also depends on the gammas,
# so there should also be a dA_dgammas (gammas being the PSC gammas),
Expand All @@ -204,8 +204,22 @@ def vjp_setup(self, v_currents):
vjp2 = [c.dgammadash_by_dcoeff_vjp(dJ_dgammadashs[i]) for i, c in enumerate(self.psc_curves)]
# loop over PSC curves here I think, since loop over TF coils is in A_vjp
vjp3 = [self.biot_savart_TF.A_vjp(dJ_dA[i]) for i, c in enumerate(self.psc_curves)]

######## dA_by_dgammas very challenging to obtain -- easiest
# way forward might be to write my own jax calculation of
# A_ext, so that I can just get the derivative of
# dI/dgammas and dI/dgammadashs, and dI/dI, where now I
# need derivatives with respect to all the TF quantities too.
# Let's fix the test_coil test to make sense before we save the
# current state of the code and try to implement this.

# Need dA_by_dgammas below
# print(dJ_dA.shape, self.biot_savart_TF.dA_by_dX().shape)
# dA_dX = dJ_dA @ self.biot_savart_TF.dA_by_dX()
# print(dA_dX.shape)
# vjp4 = [c.dgamma_by_dcoeff_vjp() for i, c in enumerate(self.psc_curves)]
self.biot_savart_TF.set_points(self.eval_points)
return sum(vjp1 + vjp2 + vjp3)
return sum(vjp1 + vjp2 + vjp3) # + vjp4)

def recompute_currents(self):
gammas = np.array([c.gamma() for c in self.psc_curves])
Expand All @@ -224,7 +238,138 @@ def recompute_currents(self):
c.current.set_dofs(currents[i])
# print('currents2 = ', [c.current.get_value() for c in self.coils])
self.biot_savart_TF.set_points(self.eval_points)
# return currents


class PSCArray2():
"""
A class that represents an array of passive superconducting
coils (PSCs).
"""
def __init__(self, base_psc_curves, coils_TF, eval_points, a_list, b_list, nfp=1, stellsym=False, downsample=1, cross_section='circular', dofs=None, **kwargs):
from .biotsavart import BiotSavart
self.base_psc_curves = base_psc_curves # not the symmetrized ones
self.nfp = nfp
self.stellsym = stellsym
psc_curves = apply_symmetries_to_curves(base_psc_curves, nfp, stellsym)
self.coils_TF = coils_TF
ncoils = len(psc_curves)

# save original TF evaluation points!
self.biot_savart_TF = BiotSavart(coils_TF)
self.eval_points = eval_points
self.biot_savart_TF.set_points(eval_points)
self.a_list = a_list[0] * np.ones(ncoils)
self.b_list = b_list[0] * np.ones(ncoils)
self.downsample = downsample
self.cross_section = cross_section

# Uses jacrev since # of inputs >> # of outputs
args = {"static_argnums": (5,)}
self.I_jax = jit(
lambda gammas, gammadashs, gammas_TF, gammadashs_TF, currents_TF, downsample:
coil_currents_barebones(gammas, gammadashs, gammas_TF, gammadashs_TF, currents_TF, self.a_list, self.b_list, downsample, cross_section),
**args
)
self.dI_dgammas = jit(
lambda gammas, gammadashs, gammas_TF, gammadashs_TF, currents_TF, downsample:
jacrev(self.I_jax, argnums=0)(gammas, gammadashs, gammas_TF, gammadashs_TF, currents_TF, downsample),
**args
)
self.dI_dgammadashs = jit(
lambda gammas, gammadashs, gammas_TF, gammadashs_TF, currents_TF, downsample:
jacrev(self.I_jax, argnums=1)(gammas, gammadashs, gammas_TF, gammadashs_TF, currents_TF, downsample),
**args
)
self.dI_dgammasTF = jit(
lambda gammas, gammadashs, gammas_TF, gammadashs_TF, currents_TF, downsample:
jacrev(self.I_jax, argnums=2)(gammas, gammadashs, gammas_TF, gammadashs_TF, currents_TF, downsample),
**args
)
self.dI_dgammadashsTF = jit(
lambda gammas, gammadashs, gammas_TF, gammadashs_TF, currents_TF, downsample:
jacrev(self.I_jax, argnums=3)(gammas, gammadashs, gammas_TF, gammadashs_TF, currents_TF, downsample),
**args
)
self.dI_dcurrentsTF = jit(
lambda gammas, gammadashs, gammas_TF, gammadashs_TF, currents_TF, downsample:
jacrev(self.I_jax, argnums=4)(gammas, gammadashs, gammas_TF, gammadashs_TF, currents_TF, downsample),
**args
)
gammas = np.array([c.gamma() for c in psc_curves])
# self.biot_savart_TF.set_points(gammas[:, ::self.downsample, :].reshape(-1, 3))
# A_ext = np.array(self.biot_savart_TF.A())
gammadashs = np.array([c.gammadash() for c in psc_curves])
gammas_TF = np.array([c.curve.gamma() for c in self.coils_TF])
gammadashs_TF = np.array([c.curve.gamma() for c in self.coils_TF])
currents_TF = np.array([c.current.get_value() for c in self.coils_TF])
args = [
gammas,
gammadashs,
gammas_TF,
gammadashs_TF,
currents_TF,
self.downsample
]
currents = self.I_jax(*args)

psc_currents = [Current(currents[i] * 1e-6) * 1e6 for i in range(ncoils)]
self.base_psc_currents = psc_currents[:ncoils // (int(stellsym) + 1) // nfp]
[c.fix_all() for c in self.base_psc_currents] # Fix all the current dofs which are fake anyways
self.coils = coils_via_symmetries(self.base_psc_curves, self.base_psc_currents, nfp, stellsym)
self.psc_curves = [c.curve for c in self.coils]
self.biot_savart = BiotSavart(self.coils, self)
self.biot_savart_total = self.biot_savart + self.biot_savart_TF
self.biot_savart_total.set_points(self.eval_points)

def vjp_setup(self, v_currents):
gammas = np.array([c.gamma() for c in self.psc_curves])
gammadashs = np.array([c.gammadash() for c in self.psc_curves])
# self.biot_savart_TF.set_points(gammas[:, ::self.downsample, :].reshape(-1, 3))
# A_ext = self.biot_savart_TF.A()
gammas_TF = np.array([c.curve.gamma() for c in self.coils_TF])
gammadashs_TF = np.array([c.curve.gamma() for c in self.coils_TF])
currents_TF = np.array([c.current.get_value() for c in self.coils_TF])
args = [
gammas,
gammadashs,
gammas_TF,
gammadashs_TF,
currents_TF,
self.downsample
]
dJ_dgammas = np.sum(v_currents[:, None, None, None] * self.dI_dgammas(*args), axis=0)
dJ_dgammadashs = np.sum(v_currents[:, None, None, None] * self.dI_dgammadashs(*args), axis=0)
dJ_dgammas2 = np.sum(v_currents[:, None, None, None] * self.dI_dgammasTF(*args), axis=0)
dJ_dgammadashs2 = np.sum(v_currents[:, None, None, None] * self.dI_dgammadashsTF(*args), axis=0)
dJ_dcurrents2 = np.sum(v_currents[:, None, None, None] * self.dI_dcurrentsTF(*args), axis=0)
vjp1 = [c.dgamma_by_dcoeff_vjp(dJ_dgammas[i]) for i, c in enumerate(self.psc_curves)]
vjp2 = [c.dgammadash_by_dcoeff_vjp(dJ_dgammadashs[i]) for i, c in enumerate(self.psc_curves)]
vjp3 = [c.curve.dgamma_by_dcoeff_vjp(dJ_dgammas2[i]) for i, c in enumerate(self.coils_TF)]
vjp4 = [c.curve.dgammadash_by_dcoeff_vjp(dJ_dgammadashs2[i]) for i, c in enumerate(self.coils_TF)]
vjp5 = [c.current.vjp(dJ_dcurrents2[i]) for i, c in enumerate(self.coils_TF)]
self.biot_savart_TF.set_points(self.eval_points)
return sum(vjp1 + vjp2 + vjp3 + vjp4 + vjp5)

def recompute_currents(self):
gammas = np.array([c.gamma() for c in self.psc_curves])
gammadashs = np.array([c.gammadash() for c in self.psc_curves])
gammas_TF = np.array([c.curve.gamma() for c in self.coils_TF])
gammadashs_TF = np.array([c.curve.gamma() for c in self.coils_TF])
currents_TF = np.array([c.current.get_value() for c in self.coils_TF])
args = [
gammas,
gammadashs,
gammas_TF,
gammadashs_TF,
currents_TF,
self.downsample
]
currents = self.I_jax(*args)
# print('currents = ', currents)
for i, c in enumerate(self.coils):
c.current.set_dofs(currents[i])
# print('currents2 = ', [c.current.get_value() for c in self.coils])
# self.biot_savart_TF.set_points(self.eval_points)

class ScaledCurrent(sopp.CurrentBase, CurrentBase):
"""
Expand Down
19 changes: 19 additions & 0 deletions src/simsopt/field/force.py
Original file line number Diff line number Diff line change
Expand Up @@ -1970,6 +1970,9 @@ def coil_coil_inductances_inv_pure(gammas, gammadashs, a_list, b_list, downsampl
def coil_currents_pure(gammas, gammadashs, A_ext, a_list, b_list, downsample, cross_section):
return -coil_coil_inductances_inv_pure(gammas, gammadashs, a_list, b_list, downsample, cross_section) @ net_ext_fluxes_pure(gammadashs, A_ext, downsample)

def coil_currents_barebones(gammas, gammadashs, gammas_TF, gammadashs_TF, currents_TF, a_list, b_list, downsample, cross_section):
return -coil_coil_inductances_inv_pure(gammas, gammadashs, a_list, b_list, downsample, cross_section) @ net_fluxes_pure(gammas, gammadashs, gammas_TF, gammadashs_TF, currents_TF, downsample)

def tve_pure(gamma, gammadash, gammas2, gammadashs2, current, currents2,
quadpoints, quadpoints2, a, b, downsample, cross_section):
r"""Pure function for minimizing the total vacuum energy on a coil.
Expand Down Expand Up @@ -2105,6 +2108,22 @@ def dJ(self):

return_fn_map = {'J': J, 'dJ': dJ}

def net_fluxes_pure(gammas, gammadashs, gammas2, gammadashs2, currents2, downsample):
"""
"""
# Downsample if desired
gammas = gammas[:, ::downsample, :]
gammadashs = gammadashs[:, ::downsample, :]
gammas2 = gammas2[:, ::downsample, :]
gammadashs2 = gammadashs2[:, ::downsample, :]
rij_norm = jnp.linalg.norm(gammas[:, :, None, None, :] - gammas2[None, None, :, :, :], axis=-1)
# sum over the currents, and sum over the biot savart integral
A_ext = jnp.sum(currents2[None, None, :, None] * jnp.sum(gammadashs2[None, None, :, :, :] / rij_norm[:, :, :, :, None], axis=-2), axis=-2)
# print(A_ext.shape, gammadashs.shape, gammadashs2.shape, rij_norm.shape, currents2.shape)
# Now sum over the PSC coil loops
return 1e-7 * jnp.sum(jnp.sum(A_ext * gammadashs, axis=-1), axis=-1) / jnp.shape(gammadashs)[1]


def net_ext_fluxes_pure(gammadashs, A_ext, downsample):
r""" Pure function to compute the net fluxes through a set of closed loop,
Expand Down

0 comments on commit 15df425

Please sign in to comment.