diff --git a/CHANGES.rst b/CHANGES.rst index ddae52f5..6849038a 100644 --- a/CHANGES.rst +++ b/CHANGES.rst @@ -1,3 +1,12 @@ +0.2.2 (Unreleased) +================== + +ramp_fitting +------------ + +- Added DQ flag parameter to `ramp_fit` [#25] + + 0.2.1 (2021-05-20) ================== diff --git a/src/stcal/ramp_fitting/constants.py b/src/stcal/ramp_fitting/constants.py index 47b1b450..f45b4d06 100644 --- a/src/stcal/ramp_fitting/constants.py +++ b/src/stcal/ramp_fitting/constants.py @@ -1,5 +1,15 @@ -DO_NOT_USE = 2**0 # Bad pixel. Do not use. -SATURATED = 2**1 # Pixel saturated during exposure. -JUMP_DET = 2**2 # Jump detected during exposure -NO_GAIN_VALUE = 2**19 # Gain cannot be measured -UNRELIABLE_SLOPE = 2**24 # Slope variance large (i.e., noisy pixel) +dqflags = { + "DO_NOT_USE": None, + "SATURATED": None, + "JUMP_DET": None, + "NO_GAIN_VALUE": None, + "UNRELIABLE_SLOPE": None, +} + + +def update_dqflags(input_flags): + dqflags["DO_NOT_USE"] = input_flags["DO_NOT_USE"] + dqflags["SATURATED"] = input_flags["SATURATED"] + dqflags["JUMP_DET"] = input_flags["JUMP_DET"] + dqflags["NO_GAIN_VALUE"] = input_flags["NO_GAIN_VALUE"] + dqflags["UNRELIABLE_SLOPE"] = input_flags["UNRELIABLE_SLOPE"] diff --git a/src/stcal/ramp_fitting/ols_fit.py b/src/stcal/ramp_fitting/ols_fit.py index de1a2e5f..6c580d65 100644 --- a/src/stcal/ramp_fitting/ols_fit.py +++ b/src/stcal/ramp_fitting/ols_fit.py @@ -14,12 +14,6 @@ log = logging.getLogger(__name__) log.setLevel(logging.DEBUG) -# TODO Should figure out a better way to do this -DO_NOT_USE = constants.DO_NOT_USE -SATURATED = constants.SATURATED -JUMP_DET = constants.JUMP_DET -UNRELIABLE_SLOPE = constants.UNRELIABLE_SLOPE - BUFSIZE = 1024 * 300000 # 300Mb cache size for data section @@ -653,12 +647,13 @@ def discard_miri_groups(input_model): data = input_model.data err = input_model.err groupdq = input_model.groupdq + jump_flag = constants.dqflags["JUMP_DET"] n_int, ngroups, nrows, ncols = data.shape num_bad_slices = 0 # number of initial groups that are all DO_NOT_USE - while np.all(np.bitwise_and(groupdq[:, 0, :, :], DO_NOT_USE)): + while np.all(np.bitwise_and(groupdq[:, 0, :, :], constants.dqflags["DO_NOT_USE"])): num_bad_slices += 1 ngroups -= 1 @@ -673,11 +668,11 @@ def discard_miri_groups(input_model): # Where the initial group of the just-truncated data is a cosmic ray, # remove the JUMP_DET flag from the group dq for those pixels so # that those groups will be included in the fit. - wh_cr = np.where(np.bitwise_and(groupdq[:, 0, :, :], JUMP_DET)) + wh_cr = np.where(np.bitwise_and(groupdq[:, 0, :, :], jump_flag)) num_cr_1st = len(wh_cr[0]) for ii in range(num_cr_1st): - groupdq[wh_cr[0][ii], 0, wh_cr[1][ii], wh_cr[2][ii]] -= JUMP_DET + groupdq[wh_cr[0][ii], 0, wh_cr[1][ii], wh_cr[2][ii]] -= jump_flag if num_bad_slices > 0: data = data[:, num_bad_slices:, :, :] @@ -689,7 +684,7 @@ def discard_miri_groups(input_model): # in the while loop above, ngroups would have been set to 0, and Nones # would have been returned. If execution has gotten here, there must # be at least 1 remaining group that is not all flagged. - if np.all(np.bitwise_and(groupdq[:, -1, :, :], DO_NOT_USE)): + if np.all(np.bitwise_and(groupdq[:, -1, :, :], constants.dqflags["DO_NOT_USE"])): ngroups -= 1 # Check if there are remaining groups before accessing data @@ -782,6 +777,9 @@ def ramp_fit_slopes(input_model, gain_2d, readnoise_2d, save_opt, weighting): Rate array """ + sat_flag = constants.dqflags["SATURATED"] + jump_flag = constants.dqflags["JUMP_DET"] + # Get image data information data = input_model.data err = input_model.err @@ -805,7 +803,7 @@ def ramp_fit_slopes(input_model, gain_2d, readnoise_2d, save_opt, weighting): # the output products are returned to ramp_fit(). If the initial group of # a ramp is saturated, it is assumed that all groups are saturated. first_gdq = groupdq[:, 0, :, :] - if np.all(np.bitwise_and(first_gdq, SATURATED)): + if np.all(np.bitwise_and(first_gdq, sat_flag)): image_info, integ_info, opt_info = utils.do_all_sat( inpixeldq, groupdq, imshape, n_int, save_opt) @@ -876,7 +874,7 @@ def ramp_fit_slopes(input_model, gain_2d, readnoise_2d, save_opt, weighting): gain_sect = gain_2d[rlo:rhi, :] # Reset all saturated groups in the input data array to NaN - where_sat = np.where(np.bitwise_and(gdq_sect, SATURATED)) + where_sat = np.where(np.bitwise_and(gdq_sect, sat_flag)) data_sect[where_sat] = np.NaN del where_sat @@ -902,7 +900,7 @@ def ramp_fit_slopes(input_model, gain_2d, readnoise_2d, save_opt, weighting): # starting at group 1. The purpose of starting at index 1 is # to shift all the indices down by 1, so they line up with the # indices in first_diffs. - i_group, i_yy, i_xx, = np.where(np.bitwise_and(gdq_sect[1:, :, :], JUMP_DET)) + i_group, i_yy, i_xx, = np.where(np.bitwise_and(gdq_sect[1:, :, :], jump_flag)) first_diffs_sect[i_group, i_yy, i_xx] = np.NaN del i_group, i_yy, i_xx @@ -944,7 +942,7 @@ def ramp_fit_slopes(input_model, gain_2d, readnoise_2d, save_opt, weighting): num_seg_per_int[num_int, rlo:rhi, :] = num_seg.reshape(sect_shape) # Populate integ-spec slice which is set if 0th group has SAT - wh_sat0 = np.where(np.bitwise_and(gdq_sect[0, :, :], SATURATED)) + wh_sat0 = np.where(np.bitwise_and(gdq_sect[0, :, :], sat_flag)) if len(wh_sat0[0]) > 0: sat_0th_group_int[num_int, rlo:rhi, :][wh_sat0] = 1 @@ -964,7 +962,7 @@ def ramp_fit_slopes(input_model, gain_2d, readnoise_2d, save_opt, weighting): # as approximation to cosmic ray amplitude for those pixels # having their DQ set for cosmic rays data_diff = data_sect - utils.shift_z(data_sect, -1) - dq_cr = np.bitwise_and(JUMP_DET, gdq_sect) + dq_cr = np.bitwise_and(jump_flag, gdq_sect) opt_res.cr_mag_seg[num_int, :, rlo:rhi, :] = data_diff * (dq_cr != 0) @@ -3274,7 +3272,8 @@ def calc_num_seg(gdq, n_int): # ramps, to use as a surrogate for the number of segments along the ramps # Note that we only care about flags that are NOT in the first or last groups, # because exclusion of a first or last group won't result in an additional segment. - max_cr = np.count_nonzero(np.bitwise_and(gdq[:, 1:-1], JUMP_DET | DO_NOT_USE), axis=1).max() + check_flag = constants.dqflags["JUMP_DET"] | constants.dqflags["DO_NOT_USE"] + max_cr = np.count_nonzero(np.bitwise_and(gdq[:, 1:-1], check_flag), axis=1).max() # Do not want to return a value > the number of groups, which can occur if # this is a MIRI dataset in which the first or last group was flagged as diff --git a/src/stcal/ramp_fitting/ramp_fit.py b/src/stcal/ramp_fitting/ramp_fit.py index b26e9468..6329706d 100755 --- a/src/stcal/ramp_fitting/ramp_fit.py +++ b/src/stcal/ramp_fitting/ramp_fit.py @@ -16,6 +16,7 @@ import numpy as np import logging +from . import constants # from . import gls_fit # used only if algorithm is "GLS" from . import ols_fit # used only if algorithm is "OLS" @@ -26,7 +27,7 @@ def ramp_fit(model, buffsize, save_opt, readnoise_2d, gain_2d, - algorithm, weighting, max_cores): + algorithm, weighting, max_cores, dqflags): """ Calculate the count rate for each pixel in all data cube sections and all integrations, equal to the slope for all sections (intervals between @@ -67,6 +68,10 @@ def ramp_fit(model, buffsize, save_opt, readnoise_2d, gain_2d, to use for multi-proc. The total number of cores includes the SMT cores (Hyper Threading for Intel). + dqflags: dict + A dictionary with at least the following keywords: + DO_NOT_USE, SATURATED, JUMP_DET, NO_GAIN_VALUE, UNRELIABLE_SLOPE + Returns ------- image_info: tuple @@ -82,6 +87,11 @@ def ramp_fit(model, buffsize, save_opt, readnoise_2d, gain_2d, Object containing optional GLS-specific ramp fitting data for the exposure """ + + constants.update_dqflags(dqflags) + if None in constants.dqflags.values(): + raise ValueError("Some of the DQ flags required for ramp_fitting are None.") + if algorithm.upper() == "GLS": # !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!! # !!!!! Reference to ReadModel and GainModel changed to simple ndarrays !!!!! diff --git a/src/stcal/ramp_fitting/utils.py b/src/stcal/ramp_fitting/utils.py index 80f27874..d69ec49f 100644 --- a/src/stcal/ramp_fitting/utils.py +++ b/src/stcal/ramp_fitting/utils.py @@ -14,13 +14,6 @@ # Replace zero or negative variances with this: LARGE_VARIANCE = 1.e8 -# TODO Should figure out a better way to do this -DO_NOT_USE = constants.DO_NOT_USE -SATURATED = constants.SATURATED -JUMP_DET = constants.JUMP_DET -NO_GAIN_VALUE = constants.NO_GAIN_VALUE -UNRELIABLE_SLOPE = constants.UNRELIABLE_SLOPE - class OptRes: """ @@ -225,7 +218,7 @@ def shrink_crmag(self, n_int, dq_cube, imshape, nreads): max_cr = 0 for ii_int in range(0, n_int): dq_int = dq_cube[ii_int, :, :, :] - dq_cr = np.bitwise_and(JUMP_DET, dq_int) + dq_cr = np.bitwise_and(constants.dqflags["JUMP_DET"], dq_int) max_cr_int = (dq_cr > 0.).sum(axis=0).max() max_cr = max(max_cr, max_cr_int) @@ -514,7 +507,7 @@ def calc_slope_vars(rn_sect, gain_sect, gdq_sect, group_time, max_seg): gdq_2d_nan = gdq_2d.copy() # group dq with SATS will be replaced by nans gdq_2d_nan = gdq_2d_nan.astype(np.float32) - wh_sat = np.where(np.bitwise_and(gdq_2d, SATURATED)) + wh_sat = np.where(np.bitwise_and(gdq_2d, constants.dqflags["SATURATED"])) if len(wh_sat[0]) > 0: gdq_2d_nan[wh_sat] = np.nan # set all SAT groups to nan @@ -539,7 +532,8 @@ def calc_slope_vars(rn_sect, gain_sect, gdq_sect, group_time, max_seg): del wh_good # Locate any CRs that appear before the first SAT group... - wh_cr = np.where(gdq_2d_nan[i_read, :].astype(np.int32) & JUMP_DET > 0) + wh_cr = np.where( + gdq_2d_nan[i_read, :].astype(np.int32) & constants.dqflags["JUMP_DET"] > 0) # ... but not on final read: if (len(wh_cr[0]) > 0 and (i_read < nreads - 1)): @@ -656,7 +650,8 @@ def calc_pedestal(num_int, slope_int, firstf_int, dq_first, nframes, groupgap, ped = ff_all - slope_int[num_int, ::] * \ (((nframes + 1.) / 2. + dropframes1) / (nframes + groupgap)) - ped[np.bitwise_and(dq_first, SATURATED) == SATURATED] = 0 + sat_flag = constants.dqflags["SATURATED"] + ped[np.bitwise_and(dq_first, sat_flag) == sat_flag] = 0 ped[np.isnan(ped)] = 0. return ped @@ -1049,8 +1044,8 @@ def get_more_info(model): # pragma: no cover group_time = model.meta.exposure.group_time nframes_used = model.meta.exposure.nframes - saturated_flag = SATURATED - jump_flag = JUMP_DET + saturated_flag = constants.dqflags["SATURATED"] + jump_flag = constants.dqflags["JUMP_DET"] return (group_time, nframes_used, saturated_flag, jump_flag) @@ -1108,13 +1103,13 @@ def reset_bad_gain(pdq, gain): ''' wh_g = np.where(gain <= 0.) if len(wh_g[0]) > 0: - pdq[wh_g] = np.bitwise_or(pdq[wh_g], NO_GAIN_VALUE) - pdq[wh_g] = np.bitwise_or(pdq[wh_g], DO_NOT_USE) + pdq[wh_g] = np.bitwise_or(pdq[wh_g], constants.dqflags["NO_GAIN_VALUE"]) + pdq[wh_g] = np.bitwise_or(pdq[wh_g], constants.dqflags["DO_NOT_USE"]) wh_g = np.where(np.isnan(gain)) if len(wh_g[0]) > 0: - pdq[wh_g] = np.bitwise_or(pdq[wh_g], NO_GAIN_VALUE) - pdq[wh_g] = np.bitwise_or(pdq[wh_g], DO_NOT_USE) + pdq[wh_g] = np.bitwise_or(pdq[wh_g], constants.dqflags["NO_GAIN_VALUE"]) + pdq[wh_g] = np.bitwise_or(pdq[wh_g], constants.dqflags["DO_NOT_USE"]) return pdq @@ -1237,7 +1232,7 @@ def fix_sat_ramps(sat_0th_group_int, var_p3, var_both3, slope_int, dq_int): var_both3[sat_0th_group_int > 0] = LARGE_VARIANCE slope_int[sat_0th_group_int > 0] = 0. dq_int[sat_0th_group_int > 0] = np.bitwise_or( - dq_int[sat_0th_group_int > 0], DO_NOT_USE) + dq_int[sat_0th_group_int > 0], constants.dqflags["DO_NOT_USE"]) return var_p3, var_both3, slope_int, dq_int @@ -1275,8 +1270,8 @@ def do_all_sat(pixeldq, groupdq, imshape, n_int, save_opt): """ # Create model for the primary output. Flag all pixels in the pixiel DQ # extension as SATURATED and DO_NOT_USE. - pixeldq = np.bitwise_or(pixeldq, SATURATED) - pixeldq = np.bitwise_or(pixeldq, DO_NOT_USE) + pixeldq = np.bitwise_or(pixeldq, constants.dqflags["SATURATED"]) + pixeldq = np.bitwise_or(pixeldq, constants.dqflags["DO_NOT_USE"]) data = np.zeros(imshape, dtype=np.float32) dq = pixeldq @@ -1297,7 +1292,7 @@ def do_all_sat(pixeldq, groupdq, imshape, n_int, save_opt): groupdq_3d[ii, :, :] = np.bitwise_or.reduce(groupdq[ii, :, :, :], axis=0) - groupdq_3d = np.bitwise_or(groupdq_3d, DO_NOT_USE) + groupdq_3d = np.bitwise_or(groupdq_3d, constants.dqflags["DO_NOT_USE"]) data = np.zeros((n_int,) + imshape, dtype=np.float32) dq = groupdq_3d @@ -1444,12 +1439,15 @@ def dq_compress_sect(gdq_sect, pixeldq_sect): flags, 2-D flag """ - sat_loc_r = np.bitwise_and(gdq_sect, SATURATED) + sat_flag = constants.dqflags["SATURATED"] + jump_flag = constants.dqflags["JUMP_DET"] + + sat_loc_r = np.bitwise_and(gdq_sect, sat_flag) sat_loc_im = np.where(sat_loc_r.sum(axis=0) > 0) - pixeldq_sect[sat_loc_im] = np.bitwise_or(pixeldq_sect[sat_loc_im], SATURATED) + pixeldq_sect[sat_loc_im] = np.bitwise_or(pixeldq_sect[sat_loc_im], sat_flag) - cr_loc_r = np.bitwise_and(gdq_sect, JUMP_DET) + cr_loc_r = np.bitwise_and(gdq_sect, jump_flag) cr_loc_im = np.where(cr_loc_r.sum(axis=0) > 0) - pixeldq_sect[cr_loc_im] = np.bitwise_or(pixeldq_sect[cr_loc_im], JUMP_DET) + pixeldq_sect[cr_loc_im] = np.bitwise_or(pixeldq_sect[cr_loc_im], jump_flag) return pixeldq_sect