AL-522: Adding a DQ flags parameter to ramp_fit to make them pipeline independent #25

Merged: 8 commits, May 25, 2021
9 changes: 9 additions & 0 deletions CHANGES.rst
@@ -1,3 +1,12 @@
0.2.2 (Unreleased)
==================

ramp_fitting
------------

- Added DQ flag parameter to `ramp_fit` [#25]


0.2.1 (2021-05-20)
==================

20 changes: 15 additions & 5 deletions src/stcal/ramp_fitting/constants.py
@@ -1,5 +1,15 @@
DO_NOT_USE = 2**0 # Bad pixel. Do not use.
SATURATED = 2**1 # Pixel saturated during exposure.
JUMP_DET = 2**2 # Jump detected during exposure
NO_GAIN_VALUE = 2**19 # Gain cannot be measured
UNRELIABLE_SLOPE = 2**24 # Slope variance large (i.e., noisy pixel)
dqflags = {
"DO_NOT_USE": None,
"SATURATED": None,
"JUMP_DET": None,
"NO_GAIN_VALUE": None,
"UNRELIABLE_SLOPE": None,
}
Comment on lines +1 to +7
Collaborator

If dqflags is now a required parameter to ramp_fit, what's the purpose of having flags defined as constants in a separate module? Shouldn't all the flag mnemonics be handled in ramp_fit and then passed along directly to any function it calls?

Having these defined in a constants module also implies they are constants, which they clearly no longer are if they can be changed.


nden marked this conversation as resolved.

def update_dqflags(input_flags):
dqflags["DO_NOT_USE"] = input_flags["DO_NOT_USE"]
dqflags["SATURATED"] = input_flags["SATURATED"]
dqflags["JUMP_DET"] = input_flags["JUMP_DET"]
dqflags["NO_GAIN_VALUE"] = input_flags["NO_GAIN_VALUE"]
dqflags["UNRELIABLE_SLOPE"] = input_flags["UNRELIABLE_SLOPE"]
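
For context, a minimal sketch of how a calling pipeline might populate this mapping before running ramp fitting. The bit values are illustrative only, mirroring the module-level constants removed above; a real caller (e.g., the JWST pipeline) would pass its own DQ flag definitions.

```python
from stcal.ramp_fitting import constants

# Illustrative sketch: bit values mirror the removed module-level constants;
# a real caller would supply its own pipeline's DQ definitions.
pipeline_dqflags = {
    "DO_NOT_USE": 2**0,         # bad pixel, do not use
    "SATURATED": 2**1,          # pixel saturated during exposure
    "JUMP_DET": 2**2,           # jump detected during exposure
    "NO_GAIN_VALUE": 2**19,     # gain cannot be measured
    "UNRELIABLE_SLOPE": 2**24,  # slope variance large (i.e., noisy pixel)
}

constants.update_dqflags(pipeline_dqflags)
assert None not in constants.dqflags.values()
```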
31 changes: 15 additions & 16 deletions src/stcal/ramp_fitting/ols_fit.py
@@ -14,12 +14,6 @@
log = logging.getLogger(__name__)
log.setLevel(logging.DEBUG)

# TODO Should figure out a better way to do this
DO_NOT_USE = constants.DO_NOT_USE
SATURATED = constants.SATURATED
JUMP_DET = constants.JUMP_DET
UNRELIABLE_SLOPE = constants.UNRELIABLE_SLOPE

BUFSIZE = 1024 * 300000 # 300Mb cache size for data section


@@ -653,12 +647,13 @@ def discard_miri_groups(input_model):
data = input_model.data
err = input_model.err
groupdq = input_model.groupdq
jump_flag = constants.dqflags["JUMP_DET"]

n_int, ngroups, nrows, ncols = data.shape

num_bad_slices = 0 # number of initial groups that are all DO_NOT_USE

while np.all(np.bitwise_and(groupdq[:, 0, :, :], DO_NOT_USE)):
while np.all(np.bitwise_and(groupdq[:, 0, :, :], constants.dqflags["DO_NOT_USE"])):
Collaborator

jump_flag is a pointer to the mnemonic `dqflags["JUMP_DET"]`, but here you use the mnemonic itself. Consistency would be good, i.e., either always use long-form mnemonics directly from the dict, or always use a variable, preferably defined at the top of the function.

num_bad_slices += 1
ngroups -= 1

@@ -673,11 +668,11 @@ def discard_miri_groups(input_model):
# Where the initial group of the just-truncated data is a cosmic ray,
# remove the JUMP_DET flag from the group dq for those pixels so
# that those groups will be included in the fit.
wh_cr = np.where(np.bitwise_and(groupdq[:, 0, :, :], JUMP_DET))
wh_cr = np.where(np.bitwise_and(groupdq[:, 0, :, :], jump_flag))
num_cr_1st = len(wh_cr[0])

for ii in range(num_cr_1st):
groupdq[wh_cr[0][ii], 0, wh_cr[1][ii], wh_cr[2][ii]] -= JUMP_DET
groupdq[wh_cr[0][ii], 0, wh_cr[1][ii], wh_cr[2][ii]] -= jump_flag

if num_bad_slices > 0:
data = data[:, num_bad_slices:, :, :]
@@ -689,7 +684,7 @@
# in the while loop above, ngroups would have been set to 0, and Nones
# would have been returned. If execution has gotten here, there must
# be at least 1 remaining group that is not all flagged.
if np.all(np.bitwise_and(groupdq[:, -1, :, :], DO_NOT_USE)):
if np.all(np.bitwise_and(groupdq[:, -1, :, :], constants.dqflags["DO_NOT_USE"])):
ngroups -= 1

# Check if there are remaining groups before accessing data
@@ -782,6 +777,9 @@ def ramp_fit_slopes(input_model, gain_2d, readnoise_2d, save_opt, weighting):
Rate array
"""

sat_flag = constants.dqflags["SATURATED"]
jump_flag = constants.dqflags["JUMP_DET"]

# Get image data information
data = input_model.data
err = input_model.err
@@ -805,7 +803,7 @@ def ramp_fit_slopes(input_model, gain_2d, readnoise_2d, save_opt, weighting):
# the output products are returned to ramp_fit(). If the initial group of
# a ramp is saturated, it is assumed that all groups are saturated.
first_gdq = groupdq[:, 0, :, :]
if np.all(np.bitwise_and(first_gdq, SATURATED)):
if np.all(np.bitwise_and(first_gdq, sat_flag)):
image_info, integ_info, opt_info = utils.do_all_sat(
inpixeldq, groupdq, imshape, n_int, save_opt)

@@ -876,7 +874,7 @@ def ramp_fit_slopes(input_model, gain_2d, readnoise_2d, save_opt, weighting):
gain_sect = gain_2d[rlo:rhi, :]

# Reset all saturated groups in the input data array to NaN
where_sat = np.where(np.bitwise_and(gdq_sect, SATURATED))
where_sat = np.where(np.bitwise_and(gdq_sect, sat_flag))

data_sect[where_sat] = np.NaN
del where_sat
@@ -902,7 +900,7 @@ def ramp_fit_slopes(input_model, gain_2d, readnoise_2d, save_opt, weighting):
# starting at group 1. The purpose of starting at index 1 is
# to shift all the indices down by 1, so they line up with the
# indices in first_diffs.
i_group, i_yy, i_xx, = np.where(np.bitwise_and(gdq_sect[1:, :, :], JUMP_DET))
i_group, i_yy, i_xx, = np.where(np.bitwise_and(gdq_sect[1:, :, :], jump_flag))
first_diffs_sect[i_group, i_yy, i_xx] = np.NaN

del i_group, i_yy, i_xx
@@ -944,7 +942,7 @@ def ramp_fit_slopes(input_model, gain_2d, readnoise_2d, save_opt, weighting):
num_seg_per_int[num_int, rlo:rhi, :] = num_seg.reshape(sect_shape)

# Populate integ-spec slice which is set if 0th group has SAT
wh_sat0 = np.where(np.bitwise_and(gdq_sect[0, :, :], SATURATED))
wh_sat0 = np.where(np.bitwise_and(gdq_sect[0, :, :], sat_flag))
if len(wh_sat0[0]) > 0:
sat_0th_group_int[num_int, rlo:rhi, :][wh_sat0] = 1

@@ -964,7 +962,7 @@ def ramp_fit_slopes(input_model, gain_2d, readnoise_2d, save_opt, weighting):
# as approximation to cosmic ray amplitude for those pixels
# having their DQ set for cosmic rays
data_diff = data_sect - utils.shift_z(data_sect, -1)
dq_cr = np.bitwise_and(JUMP_DET, gdq_sect)
dq_cr = np.bitwise_and(jump_flag, gdq_sect)

opt_res.cr_mag_seg[num_int, :, rlo:rhi, :] = data_diff * (dq_cr != 0)

@@ -3274,7 +3272,8 @@ def calc_num_seg(gdq, n_int):
# ramps, to use as a surrogate for the number of segments along the ramps
# Note that we only care about flags that are NOT in the first or last groups,
# because exclusion of a first or last group won't result in an additional segment.
max_cr = np.count_nonzero(np.bitwise_and(gdq[:, 1:-1], JUMP_DET | DO_NOT_USE), axis=1).max()
check_flag = constants.dqflags["JUMP_DET"] | constants.dqflags["DO_NOT_USE"]
max_cr = np.count_nonzero(np.bitwise_and(gdq[:, 1:-1], check_flag), axis=1).max()

# Do not want to return a value > the number of groups, which can occur if
# this is a MIRI dataset in which the first or last group was flagged as
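
As a side note on the calc_num_seg change above, a small illustration of how the combined JUMP_DET | DO_NOT_USE mask counts flagged interior groups. Flag values are the illustrative ones from the constants.py sketch, not pipeline-official.

```python
import numpy as np

# Illustrative flag values (see the constants.py sketch above).
DO_NOT_USE = 2**0
JUMP_DET = 2**2
check_flag = JUMP_DET | DO_NOT_USE

# One ramp of six group-DQ values; two interior groups carry a flag of interest.
gdq_ramp = np.array([0, 0, JUMP_DET, 0, DO_NOT_USE | JUMP_DET, 0])

# Exclude the first and last groups, matching gdq[:, 1:-1] in calc_num_seg.
n_breaks = np.count_nonzero(np.bitwise_and(gdq_ramp[1:-1], check_flag))
print(n_breaks)  # 2 -> the ramp is split into at most 3 segments
```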
12 changes: 11 additions & 1 deletion src/stcal/ramp_fitting/ramp_fit.py
@@ -16,6 +16,7 @@
import numpy as np
import logging

from . import constants
# from . import gls_fit # used only if algorithm is "GLS"
from . import ols_fit # used only if algorithm is "OLS"

@@ -26,7 +27,7 @@


def ramp_fit(model, buffsize, save_opt, readnoise_2d, gain_2d,
algorithm, weighting, max_cores):
algorithm, weighting, max_cores, dqflags):
"""
Calculate the count rate for each pixel in all data cube sections and all
integrations, equal to the slope for all sections (intervals between
@@ -67,6 +68,10 @@ def ramp_fit(model, buffsize, save_opt, readnoise_2d, gain_2d,
to use for multi-proc. The total number of cores includes the SMT cores
(Hyper Threading for Intel).

dqflags: dict
A dictionary with at least the following keys:
DO_NOT_USE, SATURATED, JUMP_DET, NO_GAIN_VALUE, UNRELIABLE_SLOPE

Returns
-------
image_info: tuple
@@ -82,6 +87,11 @@ def ramp_fit(model, buffsize, save_opt, readnoise_2d, gain_2d,
Object containing optional GLS-specific ramp fitting data for the
exposure
"""

constants.update_dqflags(dqflags)
Collaborator

It looks like constants are no longer constants. ;-)

This is a somewhat round-about way of getting the dict to the code being called; it might be better achieved by just passing the dict directly to ols_ramp_fit. The indirection here is a bit confusing.

Clearly this is a dict that needs to be shared by a number of functions, so a class structure would better encapsulate what is going on here. But that may be too much of a leap for this PR.

Collaborator Author


That is the long term goal. In JP-553 I am creating an internal class; the needed flags are there. I will eventually remove this global variable. Not only are these not constants anymore, but using global variables that can change can have bad side effects. This PR will be the last that requires changes to the JWST code. The changes regarding the internal class and the proper passing around of all the needed information, such as the DQ flags, will not affect JWST anymore, i.e., a mirror PR in JWST will not be needed when a PR is opened in STCAL ramp fitting.
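
A rough sketch of the kind of encapsulation being discussed; the class name and layout are illustrative only, not the actual JP-553 implementation.

```python
class RampFitFlags:
    """Per-call container for DQ flag values, replacing the module-level globals."""

    REQUIRED = ("DO_NOT_USE", "SATURATED", "JUMP_DET",
                "NO_GAIN_VALUE", "UNRELIABLE_SLOPE")

    def __init__(self, dqflags):
        # Validate up front instead of checking for None values later.
        missing = [name for name in self.REQUIRED if dqflags.get(name) is None]
        if missing:
            raise ValueError(f"Missing DQ flags for ramp_fitting: {missing}")
        self._flags = {name: dqflags[name] for name in self.REQUIRED}

    def __getitem__(self, name):
        return self._flags[name]
```

An instance like this would be passed explicitly into ols_fit and utils rather than stored in shared module state, which avoids the mutable-global side effects mentioned above.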

if None in constants.dqflags.values():
raise ValueError("Some of the DQ flags required for ramp_fitting are None.")

if algorithm.upper() == "GLS":
# !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!
# !!!!! Reference to ReadModel and GainModel changed to simple ndarrays !!!!!
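
For completeness, a hypothetical call against the new signature. The model, read-noise, and gain arrays are assumed to already exist, pipeline_dqflags is the dict sketched under constants.py, and the argument values shown are illustrative rather than prescribed.

```python
from stcal.ramp_fitting.ramp_fit import ramp_fit

# model, readnoise_2d, and gain_2d are assumed to be prepared by the caller.
results = ramp_fit(
    model, buffsize=1024 * 300000, save_opt=False,
    readnoise_2d=readnoise_2d, gain_2d=gain_2d,
    algorithm="OLS", weighting="optimal",
    max_cores="none", dqflags=pipeline_dqflags)
```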
48 changes: 23 additions & 25 deletions src/stcal/ramp_fitting/utils.py
@@ -14,13 +14,6 @@
# Replace zero or negative variances with this:
LARGE_VARIANCE = 1.e8

# TODO Should figure out a better way to do this
DO_NOT_USE = constants.DO_NOT_USE
SATURATED = constants.SATURATED
JUMP_DET = constants.JUMP_DET
NO_GAIN_VALUE = constants.NO_GAIN_VALUE
UNRELIABLE_SLOPE = constants.UNRELIABLE_SLOPE


class OptRes:
"""
@@ -225,7 +218,7 @@ def shrink_crmag(self, n_int, dq_cube, imshape, nreads):
max_cr = 0
for ii_int in range(0, n_int):
dq_int = dq_cube[ii_int, :, :, :]
dq_cr = np.bitwise_and(JUMP_DET, dq_int)
dq_cr = np.bitwise_and(constants.dqflags["JUMP_DET"], dq_int)
max_cr_int = (dq_cr > 0.).sum(axis=0).max()
max_cr = max(max_cr, max_cr_int)

@@ -514,7 +507,7 @@ def calc_slope_vars(rn_sect, gain_sect, gdq_sect, group_time, max_seg):
gdq_2d_nan = gdq_2d.copy() # group dq with SATS will be replaced by nans
gdq_2d_nan = gdq_2d_nan.astype(np.float32)

wh_sat = np.where(np.bitwise_and(gdq_2d, SATURATED))
wh_sat = np.where(np.bitwise_and(gdq_2d, constants.dqflags["SATURATED"]))
if len(wh_sat[0]) > 0:
gdq_2d_nan[wh_sat] = np.nan # set all SAT groups to nan

@@ -539,7 +532,8 @@ def calc_slope_vars(rn_sect, gain_sect, gdq_sect, group_time, max_seg):
del wh_good

# Locate any CRs that appear before the first SAT group...
wh_cr = np.where(gdq_2d_nan[i_read, :].astype(np.int32) & JUMP_DET > 0)
wh_cr = np.where(
gdq_2d_nan[i_read, :].astype(np.int32) & constants.dqflags["JUMP_DET"] > 0)

# ... but not on final read:
if (len(wh_cr[0]) > 0 and (i_read < nreads - 1)):
@@ -656,7 +650,8 @@ def calc_pedestal(num_int, slope_int, firstf_int, dq_first, nframes, groupgap,
ped = ff_all - slope_int[num_int, ::] * \
(((nframes + 1.) / 2. + dropframes1) / (nframes + groupgap))

ped[np.bitwise_and(dq_first, SATURATED) == SATURATED] = 0
sat_flag = constants.dqflags["SATURATED"]
ped[np.bitwise_and(dq_first, sat_flag) == sat_flag] = 0
ped[np.isnan(ped)] = 0.

return ped
@@ -1049,8 +1044,8 @@ def get_more_info(model): # pragma: no cover

group_time = model.meta.exposure.group_time
nframes_used = model.meta.exposure.nframes
saturated_flag = SATURATED
jump_flag = JUMP_DET
saturated_flag = constants.dqflags["SATURATED"]
jump_flag = constants.dqflags["JUMP_DET"]

return (group_time, nframes_used, saturated_flag, jump_flag)

@@ -1108,13 +1103,13 @@ def reset_bad_gain(pdq, gain):
'''
wh_g = np.where(gain <= 0.)
if len(wh_g[0]) > 0:
pdq[wh_g] = np.bitwise_or(pdq[wh_g], NO_GAIN_VALUE)
pdq[wh_g] = np.bitwise_or(pdq[wh_g], DO_NOT_USE)
pdq[wh_g] = np.bitwise_or(pdq[wh_g], constants.dqflags["NO_GAIN_VALUE"])
pdq[wh_g] = np.bitwise_or(pdq[wh_g], constants.dqflags["DO_NOT_USE"])

wh_g = np.where(np.isnan(gain))
if len(wh_g[0]) > 0:
pdq[wh_g] = np.bitwise_or(pdq[wh_g], NO_GAIN_VALUE)
pdq[wh_g] = np.bitwise_or(pdq[wh_g], DO_NOT_USE)
pdq[wh_g] = np.bitwise_or(pdq[wh_g], constants.dqflags["NO_GAIN_VALUE"])
pdq[wh_g] = np.bitwise_or(pdq[wh_g], constants.dqflags["DO_NOT_USE"])

return pdq

@@ -1237,7 +1232,7 @@ def fix_sat_ramps(sat_0th_group_int, var_p3, var_both3, slope_int, dq_int):
var_both3[sat_0th_group_int > 0] = LARGE_VARIANCE
slope_int[sat_0th_group_int > 0] = 0.
dq_int[sat_0th_group_int > 0] = np.bitwise_or(
dq_int[sat_0th_group_int > 0], DO_NOT_USE)
dq_int[sat_0th_group_int > 0], constants.dqflags["DO_NOT_USE"])

return var_p3, var_both3, slope_int, dq_int

@@ -1275,8 +1270,8 @@ def do_all_sat(pixeldq, groupdq, imshape, n_int, save_opt):
"""
# Create model for the primary output. Flag all pixels in the pixel DQ
# extension as SATURATED and DO_NOT_USE.
pixeldq = np.bitwise_or(pixeldq, SATURATED)
pixeldq = np.bitwise_or(pixeldq, DO_NOT_USE)
pixeldq = np.bitwise_or(pixeldq, constants.dqflags["SATURATED"])
pixeldq = np.bitwise_or(pixeldq, constants.dqflags["DO_NOT_USE"])

data = np.zeros(imshape, dtype=np.float32)
dq = pixeldq
@@ -1297,7 +1292,7 @@ def do_all_sat(pixeldq, groupdq, imshape, n_int, save_opt):
groupdq_3d[ii, :, :] = np.bitwise_or.reduce(groupdq[ii, :, :, :],
axis=0)

groupdq_3d = np.bitwise_or(groupdq_3d, DO_NOT_USE)
groupdq_3d = np.bitwise_or(groupdq_3d, constants.dqflags["DO_NOT_USE"])

data = np.zeros((n_int,) + imshape, dtype=np.float32)
dq = groupdq_3d
@@ -1444,12 +1439,15 @@ def dq_compress_sect(gdq_sect, pixeldq_sect):
flags, 2-D flag

"""
sat_loc_r = np.bitwise_and(gdq_sect, SATURATED)
sat_flag = constants.dqflags["SATURATED"]
jump_flag = constants.dqflags["JUMP_DET"]

sat_loc_r = np.bitwise_and(gdq_sect, sat_flag)
sat_loc_im = np.where(sat_loc_r.sum(axis=0) > 0)
pixeldq_sect[sat_loc_im] = np.bitwise_or(pixeldq_sect[sat_loc_im], SATURATED)
pixeldq_sect[sat_loc_im] = np.bitwise_or(pixeldq_sect[sat_loc_im], sat_flag)

cr_loc_r = np.bitwise_and(gdq_sect, JUMP_DET)
cr_loc_r = np.bitwise_and(gdq_sect, jump_flag)
cr_loc_im = np.where(cr_loc_r.sum(axis=0) > 0)
pixeldq_sect[cr_loc_im] = np.bitwise_or(pixeldq_sect[cr_loc_im], JUMP_DET)
pixeldq_sect[cr_loc_im] = np.bitwise_or(pixeldq_sect[cr_loc_im], jump_flag)

return pixeldq_sect