JP-2581: fix to allow multiprocessing for jump #87

Merged: 4 commits, May 16, 2022
81 changes: 50 additions & 31 deletions src/stcal/jump/jump.py
@@ -13,7 +13,7 @@

 def detect_jumps(frames_per_group, data, gdq, pdq, err,
                  gain_2d, readnoise_2d, rejection_thresh,
-                 three_grp_thresh, four_grp_thresh, max_jump_to_flag_neighbors,
+                 three_grp_thresh, four_grp_thresh, max_cores, max_jump_to_flag_neighbors,
                  min_jump_to_flag_neighbors, flag_4_neighbors, dqflags):
     """
     This is the high-level controlling routine for the jump detection process.
@@ -64,6 +64,11 @@ def detect_jumps(frames_per_group, data, gdq, pdq, err,
     four_grp_thresh : float
         cosmic ray sigma rejection threshold for ramps having 4 groups
 
+    max_cores : str
+        Maximum number of cores to use for multiprocessing. Available choices
+        are 'none' (which will create one process), 'quarter', 'half', and
+        'all' (of the available CPU cores).
+
     max_jump_to_flag_neighbors : float
         value in units of sigma that sets the upper limit for flagging of
         neighbors. Any jump above this cutoff will not have its neighbors
@@ -82,6 +87,8 @@ def detect_jumps(frames_per_group, data, gdq, pdq, err,
         A dictionary with at least the following keywords:
         DO_NOT_USE, SATURATED, JUMP_DET, NO_GAIN_VALUE, GOOD
 
+
+
     Returns
     -------
     gdq : int, 4D array
@@ -124,36 +131,17 @@ def detect_jumps(frames_per_group, data, gdq, pdq, err,
                               dtype=np.uint8)
     row_below_gdq = np.zeros((n_ints, n_groups, n_cols), dtype=np.uint8)
 
-    # 05/18/21 - When multiprocessing is enabled, the input data cube is split
-    # into a number of row slices, based on the number or avalable cores.
-    # Multiprocessing has been disabled for now, so the nunber of slices
-    # is here set to 1. I'm leaving the related code in to ease the eventual
-    # re-enablement of this code.
-    n_slices = 1
-
-    yinc = int(n_rows / n_slices)
-    slices = []
-    # Slice up data, gdq, readnoise_2d into slices
-    # Each element of slices is a tuple of
-    # (data, gdq, readnoise_2d, rejection_thresh, three_grp_thresh,
-    #  four_grp_thresh, nframes)
-    for i in range(n_slices - 1):
-        slices.insert(i, (data[:, :, i * yinc:(i + 1) * yinc, :],
-                          gdq[:, :, i * yinc:(i + 1) * yinc, :],
-                          readnoise_2d[i * yinc:(i + 1) * yinc, :],
-                          rejection_thresh, three_grp_thresh, four_grp_thresh,
-                          frames_per_group, flag_4_neighbors,
-                          max_jump_to_flag_neighbors,
-                          min_jump_to_flag_neighbors))
-
-    # last slice get the rest
-    slices.insert(n_slices - 1, (data[:, :, (n_slices - 1) * yinc:n_rows, :],
-                                 gdq[:, :, (n_slices - 1) * yinc:n_rows, :],
-                                 readnoise_2d[(n_slices - 1) * yinc:n_rows, :],
-                                 rejection_thresh, three_grp_thresh,
-                                 four_grp_thresh, frames_per_group,
-                                 flag_4_neighbors, max_jump_to_flag_neighbors,
-                                 min_jump_to_flag_neighbors))
+    # figure out how many slices to make based on 'max_cores'
+
+    max_available = multiprocessing.cpu_count()
+    if max_cores.lower() == 'none':
+        n_slices = 1
+    elif max_cores == 'quarter':
+        n_slices = max_available // 4 or 1
+    elif max_cores == 'half':
+        n_slices = max_available // 2 or 1
+    elif max_cores == 'all':
+        n_slices = max_available
 
     if n_slices == 1:
         gdq, row_below_dq, row_above_dq = \
@@ -164,6 +152,37 @@ def detect_jumps(frames_per_group, data, gdq, pdq, err,

         elapsed = time.time() - start
     else:
+        yinc = int(n_rows / n_slices)
+        slices = []
+        # Slice up data, gdq, readnoise_2d into row bands.
+        # Each element of slices is a tuple matching find_crs's signature:
+        # (data, gdq, readnoise_2d, rejection_thresh, three_grp_thresh,
+        #  four_grp_thresh, nframes, neighbor-flagging options, dqflags, copy_arrs)
+
+        # Must copy the arrays here: find_crs makes copies itself, but when
+        # the slices are handed to multiprocessing workers the original gdq
+        # would otherwise be modified unless it is copied beforehand.
+        gdq = gdq.copy()
+        data = data.copy()
+        copy_arrs = False  # we don't need to copy the arrays again in find_crs
+
+        for i in range(n_slices - 1):
+            slices.insert(i, (data[:, :, i * yinc:(i + 1) * yinc, :],
+                              gdq[:, :, i * yinc:(i + 1) * yinc, :],
+                              readnoise_2d[i * yinc:(i + 1) * yinc, :],
+                              rejection_thresh, three_grp_thresh, four_grp_thresh,
+                              frames_per_group, flag_4_neighbors,
+                              max_jump_to_flag_neighbors,
+                              min_jump_to_flag_neighbors, dqflags, copy_arrs))
+
+        # last slice gets the rest
+        slices.insert(n_slices - 1, (data[:, :, (n_slices - 1) * yinc:n_rows, :],
+                                     gdq[:, :, (n_slices - 1) * yinc:n_rows, :],
+                                     readnoise_2d[(n_slices - 1) * yinc:n_rows, :],
+                                     rejection_thresh, three_grp_thresh,
+                                     four_grp_thresh, frames_per_group,
+                                     flag_4_neighbors, max_jump_to_flag_neighbors,
+                                     min_jump_to_flag_neighbors, dqflags, copy_arrs))
         log.info("Creating %d processes for jump detection " % n_slices)
         pool = multiprocessing.Pool(processes=n_slices)
         # Starts each slice in its own process. Starmap allows more than one
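The remainder of this hunk is collapsed in the diff view. For orientation only, here is a minimal sketch of how the dispatch-and-reassemble step could continue inside the else branch above; it reuses the variables from the hunk (pool, slices, yinc, n_rows, n_slices, gdq), but real_result and the exact reassembly loop are assumptions, not the PR's code. For scale: on an 8-core machine, max_cores values of 'none', 'quarter', 'half', and 'all' would give 1, 2, 4, and 8 worker processes respectively.

# Sketch only (assumed continuation, not the PR's exact code).
# starmap unpacks each tuple in `slices` into find_crs's positional
# arguments and distributes the row bands across the worker processes.
real_result = pool.starmap(find_crs, slices)
pool.close()
pool.join()

# Each worker returns (gdq_slice, row_below_gdq, row_above_gdq) for its
# band of rows; copy every band back into the full gdq cube.
for i, result in enumerate(real_result):
    row0 = i * yinc
    row1 = n_rows if i == n_slices - 1 else (i + 1) * yinc
    gdq[:, :, row0:row1, :] = result[0]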
13 changes: 10 additions & 3 deletions src/stcal/jump/twopoint_difference.py
@@ -8,7 +8,7 @@
 def find_crs(dataa, group_dq, read_noise, normal_rej_thresh,
              two_diff_rej_thresh, three_diff_rej_thresh, nframes,
              flag_4_neighbors, max_jump_to_flag_neighbors,
-             min_jump_to_flag_neighbors, dqflags):
+             min_jump_to_flag_neighbors, dqflags, copy_arrs=True):
 
     """
     Find CRs/Jumps in each integration within the input data array. The input
@@ -54,6 +54,10 @@ def find_crs(dataa, group_dq, read_noise, normal_rej_thresh,
         neighbors (marginal detections). Any primary jump below this value will
         not have its neighbors flagged.
 
+    copy_arrs : bool
+        Flag for making internal copies of the arrays so that the input isn't
+        modified; defaults to True.
+
     Returns
     -------
     gdq : int, 4D array
@@ -68,8 +72,11 @@ def find_crs(dataa, group_dq, read_noise, normal_rej_thresh,
"""

     # copy data and group DQ array
-    dataa = dataa.copy()
-    gdq = group_dq.copy()
+    if copy_arrs:
+        dataa = dataa.copy()
+        gdq = group_dq.copy()
+    else:
+        gdq = group_dq
 
     # Get data characteristics
     nints, ngroups, nrows, ncols = dataa.shape
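A note on the copy_arrs change above: detect_jumps now copies gdq and data once up front and passes copy_arrs=False so the workers skip redundant copies. The up-front copy matters because NumPy basic slicing returns views, so in-place writes through a slice land in the parent array. A small standalone demo of that behavior (illustrative only, not part of the PR):

import numpy as np

# A toy 4-D group-DQ cube: (integrations, groups, rows, cols).
gdq = np.zeros((1, 4, 6, 6), dtype=np.uint8)

band = gdq[:, :, 0:3, :]          # basic slice -> a view, no data copied
band[0, 2, 1, 1] |= 4             # set a JUMP_DET-style bit in place
print(gdq[0, 2, 1, 1])            # 4 -- the parent array saw the write

safe = gdq[:, :, 3:6, :].copy()   # explicit copy -> independent buffer
safe[0, 2, 2, 2] |= 4
print(gdq[0, 2, 5, 2])            # 0 -- the parent array is untouched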