Skip to content

Commit

Permalink
added a draft of the step on device method to the stepper
Browse files Browse the repository at this point in the history
  • Loading branch information
hightower8083 committed Jul 16, 2024
1 parent 96c3cb2 commit fef0ebd
Showing 1 changed file with 52 additions and 0 deletions.
52 changes: 52 additions & 0 deletions axiprop/steppers.py
Original file line number Diff line number Diff line change
Expand Up @@ -77,6 +77,58 @@ def step(self, u, dz, overwrite=False, show_progress=False):

return u_step

def step_on_device(self, u, dz, overwrite=False, show_progress=False):
"""
Propagate wave `u` over the distance `dz`.
Parameters
----------
u: 2darray of complex or double
Spectral-radial distribution of the field to propagate.
dz: float (m)
Distance over which wave should be propagated.
Returns
-------
u: 2darray of complex or double
Overwritten array with the propagated field.
"""
assert u.dtype == self.dtype

if not overwrite:
u_step = self.bcknd.zeros((self.Nkz, *self.shape_trns_new),
dtype=u.dtype)
else:
u_step = u

if tqdm_available and show_progress:
pbar = tqdm(total=self.Nkz, bar_format=bar_format)

for ikz in range(self.Nkz):
if self.kz[ikz] <= 0:
continue

# assuming data is already on device
#self.u_loc = self.bcknd.to_device(u[ikz,:].copy())
self.u_loc = u[ikz,:].copy()
self.TST()

phase_loc = self.kz[ikz]**2 - self.kr2
phase_loc = self.bcknd.sqrt( (phase_loc>=0.)*phase_loc )
self.u_ht *= self.bcknd.exp( 1j * dz * phase_loc )

self.iTST()
# u_step[ikz] = self.bcknd.to_host(self.u_iht)
u_step[ikz] = self.u_iht.copy()
if tqdm_available and show_progress:
pbar.update(1)

if tqdm_available and show_progress:
pbar.close()

return u_step

def steps(self, u, dz=None, z_axis=None, show_progress=True):
"""
Propagate wave `u` over the multiple steps.
Expand Down

0 comments on commit fef0ebd

Please sign in to comment.