Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Make Linear Elements Fast #850

Open
wants to merge 30 commits into
base: development
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
30 commits
Select commit Hold shift + click to select a range
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 13 additions & 1 deletion src/elements/Aperture.H
Original file line number Diff line number Diff line change
Expand Up @@ -93,6 +93,18 @@ namespace impactx::elements
/** Push all particles */
using BeamOptic::operator();

/** Compute and cache the constants for the push.
*
* In particular, used to pre-compute and cache variables that are
* independent of the individually tracked particle.
*
* @param refpart reference particle (unused)
*/
void compute_constants ([[maybe_unused]] RefPart const & refpart)
{
Alignment::compute_constants(refpart);
}

/** This is an aperture functor, so that a variable of this type can be used like an
* aperture function.
*
Expand All @@ -114,7 +126,7 @@ namespace impactx::elements
amrex::ParticleReal & AMREX_RESTRICT py,
[[maybe_unused]] amrex::ParticleReal & AMREX_RESTRICT pt,
uint64_t & AMREX_RESTRICT idcpu,
[[maybe_unused]] RefPart const & refpart
[[maybe_unused]] RefPart const & AMREX_RESTRICT refpart
) const
{
using namespace amrex::literals; // for _rt and _prt
Expand Down
59 changes: 42 additions & 17 deletions src/elements/Buncher.H
Original file line number Diff line number Diff line change
Expand Up @@ -64,47 +64,66 @@ namespace impactx::elements
/** Push all particles */
using BeamOptic::operator();

/** Compute and cache the constants for the push.
*
* In particular, used to pre-compute and cache variables that are
* independent of the individually tracked particle.
*
* @param refpart reference particle
*/
void compute_constants (RefPart const & refpart)
{
using namespace amrex::literals; // for _rt and _prt

Alignment::compute_constants(refpart);

// find beta*gamma^2
amrex::ParticleReal const betgam2 = std::pow(refpart.pt, 2) - 1.0_prt;

m_neg_kV = -m_k * m_V;
m_kV_r2bg2 = -m_neg_kV / (2.0_prt * betgam2);
}

/** This is a buncher functor, so that a variable of this type can be used like a
* buncher function.
*
* The @see compute_constants method must be called before pushing particles through this operator.
*
* @param x particle position in x
* @param y particle position in y
* @param t particle position in t
* @param px particle momentum in x
* @param py particle momentum in y
* @param pt particle momentum in t
* @param idcpu particle global index (unused)
* @param refpart reference particle
* @param refpart reference particle (unused)
*/
AMREX_GPU_HOST_DEVICE AMREX_FORCE_INLINE
void operator() (
amrex::ParticleReal & AMREX_RESTRICT x,
amrex::ParticleReal & AMREX_RESTRICT y,
amrex::ParticleReal & AMREX_RESTRICT t,
amrex::ParticleReal & AMREX_RESTRICT px,
amrex::ParticleReal & AMREX_RESTRICT py,
amrex::ParticleReal & AMREX_RESTRICT pt,
[[maybe_unused]] uint64_t & AMREX_RESTRICT idcpu,
RefPart const & refpart) const {

amrex::ParticleReal & AMREX_RESTRICT x,
amrex::ParticleReal & AMREX_RESTRICT y,
amrex::ParticleReal & AMREX_RESTRICT t,
amrex::ParticleReal & AMREX_RESTRICT px,
amrex::ParticleReal & AMREX_RESTRICT py,
amrex::ParticleReal & AMREX_RESTRICT pt,
[[maybe_unused]] uint64_t & AMREX_RESTRICT idcpu,
[[maybe_unused]] RefPart const & AMREX_RESTRICT refpart
) const
{
using namespace amrex::literals; // for _rt and _prt

// shift due to alignment errors of the element
shift_in(x, y, px, py);

// access reference particle values to find (beta*gamma)^2
amrex::ParticleReal const pt_ref = refpart.pt;
amrex::ParticleReal const betgam2 = std::pow(pt_ref, 2) - 1.0_prt;

// intialize output values of momenta
amrex::ParticleReal pxout = px;
amrex::ParticleReal pyout = py;
amrex::ParticleReal ptout = pt;

// advance position and momentum
pxout = px + m_k*m_V/(2.0_prt*betgam2)*x;
pyout = py + m_k*m_V/(2.0_prt*betgam2)*y;
ptout = pt - m_k*m_V*t;
pxout = px + m_kV_r2bg2 * x;
pyout = py + m_kV_r2bg2 * y;
ptout = pt + m_neg_kV * t;

// assign updated momenta
px = pxout;
Expand Down Expand Up @@ -136,6 +155,12 @@ namespace impactx::elements

amrex::ParticleReal m_V; //! normalized (max) RF voltage drop.
amrex::ParticleReal m_k; //! RF wavenumber in 1/m.

private:
// constants that are independent of the individually tracked particle,
// see: compute_constants() to refresh
amrex::ParticleReal m_neg_kV; //! -m_k*m_V
amrex::ParticleReal m_kV_r2bg2; //! m_k*m_V/(2.0_prt*betgam2)
};

} // namespace impactx
Expand Down
137 changes: 76 additions & 61 deletions src/elements/CFbend.H
Original file line number Diff line number Diff line change
Expand Up @@ -81,7 +81,68 @@ namespace impactx::elements
/** Push all particles */
using BeamOptic::operator();

/** Compute and cache the constants for the push.
*
* In particular, used to pre-compute and cache variables that are
* independent of the individually tracked particle.
*
* @param refpart reference particle
*/
void compute_constants (RefPart const & refpart)
{
using namespace amrex::literals; // for _rt and _prt

Alignment::compute_constants(refpart);

// length of the current slice
amrex::ParticleReal const slice_ds = m_ds / nslice();

// find beta*gamma^2, beta
amrex::ParticleReal const betgam2 = std::pow(refpart.pt, 2) - 1_prt;
amrex::ParticleReal const bet = refpart.beta();
amrex::ParticleReal const ibetgam2 = 1_prt / betgam2;
amrex::ParticleReal const b2rc2 = std::pow(bet, 2) * std::pow(m_rc, 2);

// update horizontal and longitudinal phase space variables
amrex::ParticleReal const gx = m_k + std::pow(m_rc,-2);
amrex::ParticleReal const omega_x = std::sqrt(std::abs(gx));

// update vertical phase space variables
amrex::ParticleReal const gy = -m_k;
amrex::ParticleReal const omega_y = std::sqrt(std::abs(gy));

// trigonometry
auto const [sinx, cosx] = amrex::Math::sincos(omega_x * slice_ds);
amrex::ParticleReal const sinhx = std::sinh(omega_x * slice_ds);
amrex::ParticleReal const coshx = std::cosh(omega_x * slice_ds);
auto const [siny, cosy] = amrex::Math::sincos(omega_y * slice_ds);
amrex::ParticleReal const sinhy = std::sinh(omega_y * slice_ds);
amrex::ParticleReal const coshy = std::cosh(omega_y * slice_ds);

amrex::ParticleReal const igbrc = 1_prt / (gx * bet * m_rc);
amrex::ParticleReal const iobrc = 1_prt / (omega_x * bet * m_rc);
amrex::ParticleReal const igobr = 1_prt / (gx * omega_x * b2rc2);

m_R11 = gx > 0_prt ? cosx : coshx;
m_R12 = gx > 0_prt ? sinx / omega_x : sinhx / omega_x;
m_R16 = gx > 0_prt ? -(1_prt - cosx) * igbrc : -(1_prt - coshx) * igbrc;
m_R21 = gx > 0_prt ? -omega_x * sinx : omega_x * sinhx;
m_R22 = gx > 0_prt ? cosx : coshx;
m_R26 = gx > 0_prt ? -sinx * iobrc : -sinhx * iobrc;
m_R33 = gy > 0_prt ? cosy : coshy;
m_R34 = gy > 0_prt ? siny / omega_y : sinhy / omega_y;
m_R43 = gy > 0_prt ? -omega_y * siny : omega_y * sinhy;
m_R44 = gy > 0_prt ? cosy : coshy;
m_R51 = gx > 0_prt ? sinx * iobrc : sinhx * iobrc;
m_R52 = gx > 0_prt ? (1_prt - cosx) * igbrc : (1_prt - coshx) * igbrc;
m_R56 = gx > 0_prt ?
slice_ds * ibetgam2 + (sinx - omega_x * slice_ds) * igobr :
slice_ds * ibetgam2 + (sinhx - omega_x * slice_ds) * igobr;
}

/** This is a cfbend functor, so that a variable of this type can be used like a cfbend function.
*
* The @see compute_constants method must be called before pushing particles through this operator.
*
* @param x particle position in x
* @param y particle position in y
Expand All @@ -90,7 +151,7 @@ namespace impactx::elements
* @param py particle momentum in y
* @param pt particle momentum in t
* @param idcpu particle global index
* @param refpart reference particle
* @param refpart reference particle (unused)
*/
AMREX_GPU_HOST_DEVICE AMREX_FORCE_INLINE
void operator() (
Expand All @@ -101,7 +162,7 @@ namespace impactx::elements
amrex::ParticleReal & AMREX_RESTRICT py,
amrex::ParticleReal & AMREX_RESTRICT pt,
uint64_t & AMREX_RESTRICT idcpu,
RefPart const & refpart
[[maybe_unused]] RefPart const & AMREX_RESTRICT refpart
) const
{
using namespace amrex::literals; // for _rt and _prt
Expand All @@ -119,68 +180,17 @@ namespace impactx::elements
amrex::ParticleReal pyout = py;
amrex::ParticleReal ptout = pt;

// length of the current slice
amrex::ParticleReal const slice_ds = m_ds / nslice();

// access reference particle values to find beta*gamma^2
amrex::ParticleReal const pt_ref = refpart.pt;
amrex::ParticleReal const betgam2 = std::pow(pt_ref, 2) - 1.0_prt;
amrex::ParticleReal const bet = std::sqrt(betgam2/(1.0_prt + betgam2));

// update horizontal and longitudinal phase space variables
amrex::ParticleReal const gx = m_k + std::pow(m_rc,-2);
amrex::ParticleReal const omegax = std::sqrt(std::abs(gx));

if(gx > 0.0) {
// calculate expensive terms once
auto const [sinx, cosx] = amrex::Math::sincos(omegax * slice_ds);
amrex::ParticleReal const r56 = slice_ds/betgam2
+ (sinx - omegax*slice_ds)/(gx*omegax * std::pow(bet,2) * std::pow(m_rc,2));

// advance position and momentum (focusing)
x = cosx*xout + sinx/omegax*px - (1.0_prt - cosx)/(gx*bet*m_rc)*pt;
pxout = -omegax*sinx*xout + cosx*px - sinx/(omegax*bet*m_rc)*pt;

t = sinx/(omegax*bet*m_rc)*xout + (1.0_prt - cosx)/(gx*bet*m_rc)*px
+ tout + r56*pt;
ptout = pt;
} else {
// calculate expensive terms once
amrex::ParticleReal const sinhx = std::sinh(omegax * slice_ds);
amrex::ParticleReal const coshx = std::cosh(omegax * slice_ds);
amrex::ParticleReal const r56 = slice_ds/betgam2
+ (sinhx - omegax*slice_ds)/(gx*omegax * std::pow(bet,2) * std::pow(m_rc,2));

// advance position and momentum (defocusing)
x = coshx*xout + sinhx/omegax*px - (1.0_prt - coshx)/(gx*bet*m_rc)*pt;
pxout = omegax*sinhx*xout + coshx*px - sinhx/(omegax*bet*m_rc)*pt;

t = sinhx/(omegax*bet*m_rc)*xout + (1.0_prt - coshx)/(gx*bet*m_rc)*px
+ tout + r56*pt;
ptout = pt;
}
// advance position and momentum: (de)focusing
x = m_R11 * xout + m_R12 * px + m_R16 * pt;
pxout = m_R21 * xout + m_R22 * px + m_R26 * pt;
t = m_R51 * xout + m_R52 * px + m_R56 * pt + tout;
ptout = pt;

// update vertical phase space variables
amrex::ParticleReal const gy = -m_k;
amrex::ParticleReal const omegay = std::sqrt(std::abs(gy));

if(gy > 0.0) {
// calculate expensive terms once
auto const [siny, cosy] = amrex::Math::sincos(omegay * slice_ds);

// advance position and momentum (focusing)
y = cosy*yout + siny/omegay*py;
pyout = -omegay*siny*yout + cosy*py;

} else {
// calculate expensive terms once
amrex::ParticleReal const sinhy = std::sinh(omegay * slice_ds);
amrex::ParticleReal const coshy = std::cosh(omegay * slice_ds);

// advance position and momentum (defocusing)
y = coshy*yout + sinhy/omegay*py;
pyout = omegay*sinhy*yout + coshy*py;
}
// advance position and momentum (de)focusing
y = m_R33 * yout + m_R34 * py;
pyout = m_R43 * yout + m_R44 * py;

// assign updated momenta
px = pxout;
Expand Down Expand Up @@ -257,6 +267,11 @@ namespace impactx::elements

amrex::ParticleReal m_rc; //! bend radius in m
amrex::ParticleReal m_k; //! quadrupole strength in m^(-2)

private:
// constants that are independent of the individually tracked particle,
// see: compute_constants() to refresh
amrex::ParticleReal m_R11, m_R12, m_R16, m_R21, m_R22, m_R26, m_R33, m_R34, m_R43, m_R44, m_R51, m_R52, m_R56;
};

} // namespace impactx
Expand Down
Loading
Loading