From 3c74ef78a2ad7ca1c6d633c47f1e98458e7ba006 Mon Sep 17 00:00:00 2001 From: Axel Huebl Date: Wed, 12 Feb 2025 11:23:57 -0800 Subject: [PATCH] [Hack] Make `Drift` Fast --- src/elements/Drift.H | 32 +++++++++++++++++++++----------- src/elements/mixin/beamoptic.H | 4 +++- 2 files changed, 24 insertions(+), 12 deletions(-) diff --git a/src/elements/Drift.H b/src/elements/Drift.H index fadba1e01..27894d7b9 100644 --- a/src/elements/Drift.H +++ b/src/elements/Drift.H @@ -72,6 +72,18 @@ namespace impactx::elements /** Push all particles */ using BeamOptic::operator(); + void calc_constants (RefPart const & refpart) + { + using namespace amrex::literals; // for _rt and _prt + + // length of the current slice + m_slice_ds = m_ds / nslice(); + + amrex::ParticleReal const pt_ref = refpart.pt; + // find beta*gamma^2 + m_betgam2 = std::pow(pt_ref, 2) - 1.0_prt; + } + /** This is a drift functor, so that a variable of this type can be used like a drift function. * * @param x particle position in x @@ -92,7 +104,7 @@ namespace impactx::elements amrex::ParticleReal & AMREX_RESTRICT py, amrex::ParticleReal & AMREX_RESTRICT pt, uint64_t & AMREX_RESTRICT idcpu, - RefPart const & refpart + RefPart const & ) const { using namespace amrex::literals; // for _rt and _prt @@ -108,19 +120,12 @@ namespace impactx::elements amrex::ParticleReal pyout = py; amrex::ParticleReal ptout = pt; - // length of the current slice - amrex::ParticleReal const slice_ds = m_ds / nslice(); - - // access reference particle values to find beta*gamma^2 - amrex::ParticleReal const pt_ref = refpart.pt; - amrex::ParticleReal const betgam2 = std::pow(pt_ref, 2) - 1.0_prt; - // advance position and momentum (drift) - xout = x + slice_ds * px; + xout = x + m_slice_ds * px; // pxout = px; - yout = y + slice_ds * py; + yout = y + m_slice_ds * py; // pyout = py; - tout = t + (slice_ds/betgam2) * pt; + tout = t + (m_slice_ds/m_betgam2) * pt; // ptout = pt; // assign updated values @@ -202,6 +207,11 @@ namespace impactx::elements return R; } + + private: + // constants + amrex::ParticleReal m_slice_ds; //! m_ds / nslice(); + amrex::ParticleReal m_betgam2; //! beta*gamma^2 }; } // namespace impactx diff --git a/src/elements/mixin/beamoptic.H b/src/elements/mixin/beamoptic.H index 534cada8f..cf78807d6 100644 --- a/src/elements/mixin/beamoptic.H +++ b/src/elements/mixin/beamoptic.H @@ -20,6 +20,7 @@ namespace impactx::elements { + struct Drift; struct Quad; } @@ -137,7 +138,8 @@ namespace detail uint64_t* const AMREX_RESTRICT part_idcpu = pti.GetStructOfArrays().GetIdCPUData().dataPtr(); - if constexpr (std::is_same_v, elements::Quad>) + if constexpr (std::is_same_v, elements::Quad> || + std::is_same_v, elements::Drift>) { element.calc_constants(ref_part); }