From adc3ca405ff0f7729833a0e120e81eda1ce0992b Mon Sep 17 00:00:00 2001 From: Mihai Cara Date: Wed, 27 Nov 2024 11:57:44 -0500 Subject: [PATCH] Move common resample code to stcal --- jwst/resample/resample.py | 706 +++++++++++++++++++++- jwst/resample/resample_step.py | 105 ++-- jwst/resample/resample_utils.py | 167 ++++- jwst/resample/tests/test_resample_step.py | 24 +- 4 files changed, 898 insertions(+), 104 deletions(-) diff --git a/jwst/resample/resample.py b/jwst/resample/resample.py index 72b325cbd3..c59b1c8fcc 100644 --- a/jwst/resample/resample.py +++ b/jwst/resample/resample.py @@ -2,29 +2,712 @@ import os import warnings import json +import re +import math +from typing import Any +from copy import deepcopy import numpy as np +from astropy.io import fits import psutil + from drizzle.resample import Drizzle from spherical_geometry.polygon import SphericalPolygon from stdatamodels.jwst import datamodels from stdatamodels.jwst.library.basic_utils import bytes2human +from stdatamodels.jwst.datamodels.dqflags import pixel +from stdatamodels.properties import ObjectNode + +from stcal.resample import ( + Resample, + OutputTooLargeError, + UnsupportedWCSError, +) +from stcal.resample.utils import ( + compute_wcs_pixel_area, + is_imaging_wcs, + resample_range, +) from jwst.datamodels import ModelLibrary from jwst.associations.asn_from_list import asn_from_list from jwst.model_blender.blender import ModelBlender from jwst.resample import resample_utils +from jwst.assign_wcs import util as assign_wcs_util + + +log = logging.getLogger(__name__) +log.setLevel(logging.DEBUG) + + +__all__ = [ + "ResampleImage", + "OutputTooLargeError", + "is_imaging_wcs", +] + +_SUPPORTED_CUSTOM_WCS_PARS = [ + 'pixel_scale_ratio', + 'pixel_scale', + 'output_shape', + 'crpix', + 'crval', + 'rotation', +] log = logging.getLogger(__name__) log.setLevel(logging.DEBUG) -__all__ = ["OutputTooLargeError", "ResampleData"] +class ResampleImage(Resample): + dq_flag_name_map = pixel + + def __init__(self, input_models, pixfrac=1.0, kernel="square", + fillval="NAN", wht_type="ivm", good_bits=0, + blendheaders=True, output_wcs=None, wcs_pars=None, + output=None, enable_ctx=True, enable_var=True, + compute_err=None, allowed_memory=None, asn_id=None, + in_memory=True): + """ + Parameters + ---------- + input_models : LibModelAccessBase + A `LibModelAccessBase`-based object allowing iterating over + all contained models of interest. + + pixfrac : float, optional + The fraction of a pixel that the pixel flux is confined to. The + default value of 1 has the pixel flux evenly spread across the + image. A value of 0.5 confines it to half a pixel in the linear + dimension, so the flux is confined to a quarter of the pixel area + when the square kernel is used. + + kernel: {"square", "gaussian", "point", "turbo", "lanczos2", "lanczos3"}, optional + The name of the kernel used to combine the input. The choice of + kernel controls the distribution of flux over the kernel. + The square kernel is the default. + + .. warning:: + The "gaussian" and "lanczos2/3" kernels **DO NOT** + conserve flux. + + fillval: float, None, str, optional + The value of output pixels that did not have contributions from + input images' pixels. When ``fillval`` is either `None` or + ``"INDEF"`` and ``out_img`` is provided, the values of ``out_img`` + will not be modified. When ``fillval`` is either `None` or + ``"INDEF"`` and ``out_img`` is **not provided**, the values of + ``out_img`` will be initialized to `numpy.nan`. If ``fillval`` + is a string that can be converted to a number, then the output + pixels with no contributions from input images will be set to this + ``fillval`` value. + + wht_type : {"exptime", "ivm"}, optional + The weighting type for adding models' data. For ``wht_type="ivm"`` + (the default), the weighting will be determined per-pixel using + the inverse of the read noise (VAR_RNOISE) array stored in each + input image. If the ``VAR_RNOISE`` array does not exist, + the variance is set to 1 for all pixels (i.e., equal weighting). + If ``weight_type="exptime"``, the weight will be set equal + to the measurement time (``TMEASURE``) when available and to + the exposure time (``EFFEXPTM``) otherwise. + + good_bits : int, str, None, optional + An integer bit mask, `None`, a Python list of bit flags, a comma-, + or ``'|'``-separated, ``'+'``-separated string list of integer + bit flags or mnemonic flag names that indicate what bits in models' + DQ bitfield array should be *ignored* (i.e., zeroed). + + When co-adding models using :py:meth:`add_model`, any pixels with + a non-zero DQ values are assigned a weight of zero and therefore + they do not contribute to the output (resampled) data. + ``good_bits`` provides a mean to ignore some of the DQ bitflags. + + When ``good_bits`` is an integer, it must be + the sum of all the DQ bit values from the input model's + DQ array that should be considered "good" (or ignored). For + example, if pixels in the DQ array can be + combinations of 1, 2, 4, and 8 flags and one wants to consider DQ + "defects" having flags 2 and 4 as being acceptable, then + ``good_bits`` should be set to 2+4=6. Then a pixel with DQ values + 2,4, or 6 will be considered a good pixel, while a pixel with + DQ value, e.g., 1+2=3, 4+8=12, etc. will be flagged as + a "bad" pixel. + + Alternatively, when ``good_bits`` is a string, it can be a + comma-separated or '+' separated list of integer bit flags that + should be summed to obtain the final "good" bits. For example, + both "4,8" and "4+8" are equivalent to integer ``good_bits=12``. + + Finally, instead of integers, ``good_bits`` can be a string of + comma-separated mnemonics. For example, for JWST, all the following + specifications are equivalent: + + `"12" == "4+8" == "4, 8" == "JUMP_DET, DROPOUT"` + + In order to "translate" mnemonic code to integer bit flags, + ``Resample.dq_flag_name_map`` attribute must be set to either + a dictionary (with keys being mnemonc codes and the values being + integer flags) or a `~astropy.nddata.BitFlagNameMap`. + + In order to reverse the meaning of the flags + from indicating values of the "good" DQ flags + to indicating the "bad" DQ flags, prepend '~' to the string + value. For example, in order to exclude pixels with + DQ flags 4 and 8 for computations and to consider + as "good" all other pixels (regardless of their DQ flag), + use a value of ``~4+8``, or ``~4,8``. A string value of + ``~0`` would be equivalent to a setting of ``None``. + + Default value (0) will make *all* pixels with non-zero DQ + values be considered "bad" pixels, and the corresponding data + pixels will be assigned zero weight and thus these pixels + will not contribute to the output resampled data array. + + Set `good_bits` to `None` to turn off the use of model's DQ + array. + + For more details, see documentation for + `astropy.nddata.bitmask.extend_bit_flag_map`. + + blendheaders : bool, optional + Indicates whether to blend metadata from all input models and + store the combined result to the output model. + + output_wcs : dict, WCS object, None, optional + Specifies output WCS either directly as a WCS or a dictionary + with keys ``'wcs'`` (WCS object) and ``'pixel_scale'`` + (pixel scale in arcseconds). ``'pixel_scale'``, when provided, + will be used for computation of drizzle scaling factor. When it is + not provided, output pixel scale will be *estimated* from the + provided WCS object. ``output_wcs`` object is required when + ``output_model`` is `None`. ``output_wcs`` is ignored when + ``output_model`` is provided. + + wcs_pars : dict, None, optional + A dictionary of custom WCS parameters used to define an + output WCS from input models' outlines. This argument is ignored + when ``output_wcs`` is specified. + + List of supported parameters (keywords in the dictionary): + + - ``pixel_scale_ratio`` : float + + Desired pixel scale ratio defined as the ratio of the + desired output pixel scale to the first input model's pixel + scale computed from this model's WCS at the fiducial point + (taken as the ``ref_ra`` and ``ref_dec`` from the + ``wcsinfo`` meta attribute of the first input image). + Ignored when ``pixel_scale`` is specified. Default value + is ``1.0``. + + - ``pixel_scale`` : float, None + + Desired pixel scale (in degrees) of the output WCS. When + provided, overrides ``pixel_scale_ratio``. Default value + is `None`. + + - ``output_shape`` : tuple of two integers (int, int), None + + Shape of the image (data array) using ``np.ndarray`` + convention (``ny`` first and ``nx`` second). This value + will be assigned to ``pixel_shape`` and ``array_shape`` + properties of the returned WCS object. Default value is + `None`. + + - ``rotation`` : float, None + + Position angle of output image's Y-axis relative to North. + A value of ``0.0`` would orient the final output image to + be North up. The default of `None` specifies that the + images will not be rotated, but will instead be resampled + in the default orientation for the camera with the x and y + axes of the resampled image corresponding approximately + to the detector axes. Ignored when ``transform`` is + provided. Default value is `None`. + + - ``crpix`` : tuple of float, None + + Position of the reference pixel in the resampled image + array. If ``crpix`` is not specified, it will be set to + the center of the bounding box of the returned WCS object. + Default value is `None`. + + - ``crval`` : tuple of float, None + + Right ascension and declination of the reference pixel. + Automatically computed if not provided. Default value is + `None`. + + output : str, None, optional + Filename for the output model. + + accumulate : bool, optional + Indicates whether resampled models should be added to the + provided ``output_model`` data or if new arrays should be + created. + + enable_ctx : bool, optional + Indicates whether to create a context image. If ``disable_ctx`` + is set to `True`, parameters ``out_ctx``, ``begin_ctx_id``, and + ``max_ctx_id`` will be ignored. + + enable_var : bool, optional + Indicates whether to resample variance arrays. + + compute_err : {"from_var", "driz_err"}, None, optional + - ``"from_var"``: compute output model's error array from + all (Poisson, flat, readout) resampled variance arrays. + Setting ``compute_err`` to ``"from_var"`` will assume + ``enable_var`` was set to `True` regardless of actual + value of the parameter ``enable_var``. + + - ``"driz_err"``: compute output model's error array by drizzling + together all input models' error arrays. + + Error array will be assigned to ``'err'`` key of the output model. + + .. note:: + At this time, output error array is not equivalent to + error propagation results. + + allowed_memory : float, None + Fraction of memory allowed to be used for resampling. If + ``allowed_memory`` is `None` then no check for available memory + will be performed. + + in_memory : bool, optional + + asn_id : str, None, optional + + """ + self.input_models = input_models + self.output_jwst_model = None + + self.output_dir = None + self.output_filename = output + if output is not None and '.fits' not in str(output): + self.output_dir = output + self.output_filename = None + self.intermediate_suffix = 'outlier_i2d' + + self.blendheaders = blendheaders + if blendheaders: + self._blender = ModelBlender( + blend_ignore_attrs=[ + 'meta.photometry.pixelarea_steradians', + 'meta.photometry.pixelarea_arcsecsq', + 'meta.filename', + ] + ) + + self.in_memory = in_memory + self.asn_id = asn_id + + # check wcs_pars has supported keywords: + if wcs_pars is None: + wcs_pars = {} + elif wcs_pars: + unsup = [] + unsup = set(wcs_pars.keys()).difference(_SUPPORTED_CUSTOM_WCS_PARS) + if unsup: + raise KeyError( + "Unsupported custom WCS parameters: " + f"{','.join(map(repr, unsup))}." + ) + + # determine output WCS: + shape = wcs_pars.get("output_shape") + + if output_wcs is None: + if (pscale := wcs_pars.get("pixel_scale")) is not None: + pscale /= 3600.0 + wcs, _, ps, ps_ratio = resample_utils.resampled_wcs_from_models( + input_models, + pixel_scale_ratio=wcs_pars.get("pixel_scale_ratio", 1.0), + pixel_scale=pscale, + output_shape=shape, + rotation=wcs_pars.get("rotation"), + crpix=wcs_pars.get("crpix"), + crval=wcs_pars.get("crval"), + ) + + output_wcs = { + "wcs": wcs, + "pixel_scale": 3600.0 * ps, + "pixel_scale_ratio": ps_ratio, + } + + else: + if shape is None: + if output_wcs.array_shape is None: + raise ValueError( + "Custom WCS objects must have the 'array_shape' " + "attribute set (defined)." + ) + else: + output_wcs = deepcopy(output_wcs) + output_wcs.array_shape = shape + + super().__init__( + n_input_models=len(input_models), + pixfrac=pixfrac, + kernel=kernel, + fillval=fillval, + wht_type=wht_type, + good_bits=good_bits, + output_wcs=output_wcs, + output_model=None, + accumulate=False, + enable_ctx=enable_ctx, + enable_var=enable_var, + compute_err=compute_err, + allowed_memory=allowed_memory, + ) + + def _input_model_to_dict(self, model): + # wcs = model.meta.wcs + + model_dict = { + # arrays: + "data": model.data, + "dq": model.dq, + + # meta: + "filename": model.meta.filename, + "group_id": model.meta.group_id, + "s_region": model.meta.wcsinfo.s_region, + "wcs": model.meta.wcs, + "wcsinfo": model.meta.wcsinfo, + + "exposure_time": model.meta.exposure.exposure_time, + "start_time": model.meta.exposure.start_time, + "end_time": model.meta.exposure.end_time, + "duration": model.meta.exposure.duration, + "measurement_time": model.meta.exposure.measurement_time, + "effective_exposure_time": model.meta.exposure.effective_exposure_time, + "elapsed_exposure_time": model.meta.exposure.elapsed_exposure_time, + + "pixelarea_steradians": model.meta.photometry.pixelarea_steradians, + "pixelarea_arcsecsq": model.meta.photometry.pixelarea_arcsecsq, + + "level": model.meta.background.level, # sky level + "subtracted": model.meta.background.subtracted, + + # spectroscopy-specific: + "instrument_name": model.meta.instrument.name, + "exposure_type": model.meta.exposure.type, + } + + if self._enable_var: + model_dict["var_flat"] = model.var_flat + model_dict["var_rnoise"] = model.var_rnoise + model_dict["var_poisson"] = model.var_poisson + + elif (self.weight_type is not None and + self.weight_type.startswith('ivm')): + model_dict["var_rnoise"] = model.var_rnoise + + if self._compute_err == "driz_err": + model_dict["err"] = model.err + + return model_dict + + def _create_output_jwst_model(self, ref_input_model=None): + """ Create a new blank model and update it's meta with info from ``ref_input_model``. """ + output_model = datamodels.ImageModel(None) # tuple(self.output_wcs.array_shape)) + + # update meta data and wcs + if ref_input_model is not None: + output_model.update(ref_input_model) + output_model.meta.wcs = self.output_wcs + return output_model + + def _update_output_model(self, model, info_dict): + model.data = info_dict["data"] + model.wht = info_dict["wht"] + if self._enable_ctx: + model.con = info_dict["con"] + if self._compute_err: + model.err = info_dict["err"] + if self._enable_var: + model.var_rnoise = info_dict["var_rnoise"] + model.var_flat = info_dict["var_flat"] + model.var_poisson = info_dict["var_poisson"] + + model.meta.wcs = info_dict["wcs"] + model.meta.photometry.pixelarea_steradians = info_dict["pixelarea_steradians"] + model.meta.photometry.pixelarea_arcsecsq = info_dict["pixelarea_arcsecsq"] + + model.meta.resample.pointings = info_dict["pointings"] + # model.meta.resample.n_coadds = info_dict["n_coadds"] + + model.meta.resample.pixel_scale_ratio = info_dict["pixel_scale_ratio"] + model.meta.resample.pixfrac = info_dict["pixfrac"] + model.meta.resample.kernel = info_dict["kernel"] + model.meta.resample.fillval = info_dict["fillval"] + model.meta.resample.weight_type = info_dict["weight_type"] + + model.meta.exposure.exposure_time = info_dict["exposure_time"] + model.meta.exposure.start_time = info_dict["start_time"] + model.meta.exposure.end_time = info_dict["end_time"] + model.meta.exposure.duration = info_dict["duration"] + model.meta.exposure.measurement_time = info_dict["measurement_time"] + model.meta.exposure.effective_exposure_time = info_dict["effective_exposure_time"] + model.meta.exposure.elapsed_exposure_time = info_dict["elapsed_exposure_time"] + + model.meta.cal_step.resample = 'COMPLETE' + + def add_model(self, model): + """ Resamples model image and either variance data (if ``enable_var`` + was `True`) or error data (if ``enable_err`` was `True`) and adds + them using appropriate weighting to the corresponding + arrays of the output model. It also updates resampled data weight, + the context array (if ``enable_ctx`` is `True`), relevant output + model's values such as "n_coadds". + + Whenever ``model`` has a unique group ID that was never processed + before, the "pointings" value of the output model is incremented and + the "group_id" attribute is updated. Also, time counters are updated + with new values from the input ``model`` by calling + :py:meth:`~Resample.update_time`. + + Parameters + ---------- + model : dict + A dictionary containing data arrays and other meta attributes + and values of actual models used by pipelines. + + """ + super().add_model(self._input_model_to_dict(model)) + if self.output_jwst_model is None: + self.output_jwst_model = self._create_output_jwst_model( + ref_input_model=model + ) + if self.blendheaders: + self._blender.accumulate(model) + + def finalize(self, free_memory=True): + """ Finalizes all computations and frees temporary objects. + + ``finalize`` calls :py:meth:`~Resample.finalize_resample_variance` and + :py:meth:`~Resample.finalize_time_info`. + + .. warning:: + If ``enable_var=True`` and :py:meth:`~Resample.finalize` is called + with ``free_memory=True`` then intermediate arrays holding variance + weights will be lost and so continuing adding new models after + a call to :py:meth:`~Resample.finalize` will result in incorrect + variance. + + """ + if self.blendheaders: + self._blender.finalize_model(self.output_jwst_model) + super().finalize(free_memory=True) + + self._update_output_model( + self.output_jwst_model, + self.output_model, + ) + + if is_imaging_wcs(self.output_jwst_model.meta.wcs): + # only for an imaging WCS: + self.update_fits_wcsinfo(self.output_jwst_model) + assign_wcs_util.update_s_region_imaging(self.output_jwst_model) + else: + assign_wcs_util.update_s_region_spectral(self.output_jwst_model) + + self.output_jwst_model.meta.cal_step.resample = 'COMPLETE' + + def reset_arrays(self, reset_output=True, n_input_models=None): + """ Initialize/reset `Drizzle` objects, `ModelBlender`, output model + and arrays, and time counters. Output WCS and shape are not modified + from `Resample` object initialization. This method needs to be called + before calling :py:meth:`add_model` for the first time if + :py:meth:`finalize` was previously called. + + Parameters + ---------- + reset_output : bool, optional + When `True` a new output model will be created. Otherwise new + models will be resampled and added to existing output data arrays. + + n_input_models : int, None, optional + Number of input models expected to be resampled. When provided, + this is used to estimate memory requirements and optimize memory + allocation for the context array. + + """ + super().reset_arrays( + reset_output=reset_output, + n_input_models=n_input_models + ) + if self.blendheaders: + self._blender = ModelBlender( + blend_ignore_attrs=[ + 'meta.photometry.pixelarea_steradians', + 'meta.photometry.pixelarea_arcsecsq', + 'meta.filename', + ] + ) + self.output_jwst_model = None + + def _create_output_model(self, ref_input_model=None): + """ Create a new blank model and update it's meta with info from + ``ref_input_model``. + """ + output_model = datamodels.ImageModel(None) + + # update meta data and wcs + if ref_input_model is not None: + output_model.update(ref_input_model) + output_model.meta.wcs = self._output_wcs + + return output_model -class OutputTooLargeError(RuntimeError): - """Raised when the output is too large for in-memory instantiation""" + def resample_group(self, indices, compute_error=False): + """ Resample multiple input images that belong to a single + ``group_id`` as specified by ``indices``. + + Parameters + ---------- + indices : list + + compute_error : bool, optional + If True, an approximate error image will be resampled + alongside the science image. + """ + if self.output_jwst_model is not None: + self.reset_arrays(reset_output=True) + + output_model_filename = '' + + log.info(f"{len(indices)} exposures to drizzle together") + for index in indices: + model = self.input_models.borrow(index) + if self.output_jwst_model is None: + # Determine output file type from input exposure filenames + # Use this for defining the output filename + indx = model.meta.filename.rfind('.') + output_type = model.meta.filename[indx:] + output_root = '_'.join(model.meta.filename.replace( + output_type, + '' + ).split('_')[:-1]) + output_model_filename = ( + f'{output_root}_' + f'{self.intermediate_suffix}{output_type}' + ) + + if isinstance(model, datamodels.SlitModel): + # must call this explicitly to populate area extension + # although the existence of this extension may not be necessary + model.area = model.area + + self.add_model(model) + self.input_models.shelve(model, index, modify=False) + del model + + self.finalize() + copy_asn_info_from_library(self.input_models, self.output_jwst_model) + self.output_jwst_model.meta.filename = output_model_filename + return self.output_jwst_model + + def resample_many_to_many(self): + """Resample many inputs to many outputs where outputs have a common frame. + + Coadd only different detectors of the same exposure, i.e. map NRCA5 and + NRCB5 onto the same output image, as they image different areas of the + sky. + + Used for outlier detection + """ + output_models = [] + for group_id, indices in self.input_models.group_indices.items(): + + output_model = self.resample_group(self.input_models, indices) + + if not self.in_memory: + # Write out model to disk, then return filename + output_name = output_model.meta.filename + if self.output_dir is not None: + output_name = os.path.join(self.output_dir, output_name) + output_model.save(output_name) + log.info(f"Saved model in {output_name}") + output_models.append(output_name) + else: + output_models.append(output_model) + + if self.in_memory: + # build ModelLibrary as a list of in-memory models + return ModelLibrary(output_models, on_disk=False) + else: + # build ModelLibrary as an association from the output files + # this saves memory if there are multiple groups + asn = asn_from_list(output_models, product_name='outlier_i2d') + asn_dict = json.loads(asn.dump()[1]) # serializes the asn and converts to dict + return ModelLibrary(asn_dict, on_disk=True) + + def resample_many_to_one(self): + """Resample and coadd many inputs to a single output. + + Used for stage 3 resampling + """ + # if self.output_jwst_model is not None: + # self.reset_arrays(reset_output=True) + + log.info("Resampling science and variance data") + + with self.input_models: + for model in self.input_models: + self.add_model(model) + self.input_models.shelve(model) + + self.finalize() + self.output_jwst_model.meta.filename = self.output_filename + copy_asn_info_from_library(self.input_models, self.output_jwst_model) + + return self.output_jwst_model + + @staticmethod + def update_fits_wcsinfo(model): + """ + Update FITS WCS keywords of the resampled image. + """ + # Delete any SIP-related keywords first + pattern = r"^(cd[12]_[12]|[ab]p?_\d_\d|[ab]p?_order)$" + regex = re.compile(pattern) + + keys = list(model.meta.wcsinfo.instance.keys()) + for key in keys: + if regex.match(key): + del model.meta.wcsinfo.instance[key] + + # Write new PC-matrix-based WCS based on GWCS model + transform = model.meta.wcs.forward_transform + model.meta.wcsinfo.crpix1 = -transform[0].offset.value + 1 + model.meta.wcsinfo.crpix2 = -transform[1].offset.value + 1 + model.meta.wcsinfo.cdelt1 = transform[3].factor.value + model.meta.wcsinfo.cdelt2 = transform[4].factor.value + model.meta.wcsinfo.ra_ref = transform[6].lon.value + model.meta.wcsinfo.dec_ref = transform[6].lat.value + model.meta.wcsinfo.crval1 = model.meta.wcsinfo.ra_ref + model.meta.wcsinfo.crval2 = model.meta.wcsinfo.dec_ref + model.meta.wcsinfo.pc1_1 = transform[2].matrix.value[0][0] + model.meta.wcsinfo.pc1_2 = transform[2].matrix.value[0][1] + model.meta.wcsinfo.pc2_1 = transform[2].matrix.value[1][0] + model.meta.wcsinfo.pc2_2 = transform[2].matrix.value[1][1] + model.meta.wcsinfo.ctype1 = "RA---TAN" + model.meta.wcsinfo.ctype2 = "DEC--TAN" + + # Remove no longer relevant WCS keywords + rm_keys = ['v2_ref', 'v3_ref', 'ra_ref', 'dec_ref', 'roll_ref', + 'v3yangle', 'vparity'] + for key in rm_keys: + if key in model.meta.wcsinfo.instance: + del model.meta.wcsinfo.instance[key] class ResampleData: @@ -93,6 +776,12 @@ def __init__(self, input_models, output=None, single=False, blendheaders=True, log.info(f"Driz parameter fillval: {self.fillval}") log.info(f"Driz parameter weight_type: {self.weight_type}") + # output_wcs = kwargs["wcs_pars"].get('output_wcs', None) + # output_shape = kwargs["wcs_pars"].get('output_shape', None) + # crpix = kwargs["wcs_pars"].get('crpix', None) + # crval = kwargs["wcs_pars"].get('crval', None) + # rotation = kwargs["wcs_pars"].get('rotation', None) + output_wcs = kwargs.get('output_wcs', None) output_shape = kwargs.get('output_shape', None) crpix = kwargs.get('crpix', None) @@ -114,7 +803,7 @@ def __init__(self, input_models, output=None, single=False, blendheaders=True, self.output_wcs.array_shape = output_shape[::-1] if output_wcs.pixel_area is None: - self.output_pix_area = compute_image_pixel_area(self.output_wcs) + self.output_pix_area = compute_wcs_pixel_area(self.output_wcs) if self.output_pix_area is None: raise ValueError( "Unable to compute output pixel area from 'output_wcs'." @@ -238,7 +927,10 @@ def _get_intensity_scale(self, img): input_pixel_area *= self.pscale_ratio else: img.meta.wcs.array_shape = img.data.shape - input_pixel_area = compute_image_pixel_area(img.meta.wcs) + input_pixel_area = compute_wcs_pixel_area( + img.meta.wcs, + img.data.shape, + ) if input_pixel_area is None: raise ValueError( "Unable to compute input pixel area from WCS of input " @@ -327,7 +1019,7 @@ def resample_group(self, input_models, indices, compute_error=False): else: data = img.data - xmin, xmax, ymin, ymax = resample_utils._resample_range( + xmin, xmax, ymin, ymax = resample_range( data.shape, img.meta.wcs.bounding_box ) @@ -485,7 +1177,7 @@ def resample_many_to_one(self, input_models): else: data = img.data.copy() - in_image_limits = resample_utils._resample_range( + in_image_limits = resample_range( data.shape, img.meta.wcs.bounding_box ) diff --git a/jwst/resample/resample_step.py b/jwst/resample/resample_step.py index 4132850918..0c0804f7eb 100755 --- a/jwst/resample/resample_step.py +++ b/jwst/resample/resample_step.py @@ -99,29 +99,30 @@ def process(self, input): kwargs = self.get_drizpars() # Call the resampling routine - resamp = resample.ResampleData(input_models, output=output, **kwargs) - result = resamp.do_drizzle(input_models) - - with result: - for model in result: - model.meta.cal_step.resample = 'COMPLETE' - self.update_fits_wcs(model) - util.update_s_region_imaging(model) - - # if pixel_scale exists, it will override pixel_scale_ratio. - # calculate the actual value of pixel_scale_ratio based on pixel_scale - # because source_catalog uses this value from the header. - if self.pixel_scale is None: - model.meta.resample.pixel_scale_ratio = self.pixel_scale_ratio - else: - model.meta.resample.pixel_scale_ratio = resamp.pscale_ratio - model.meta.resample.pixfrac = kwargs['pixfrac'] - result.shelve(model) - - if len(result) == 1: - model = result.borrow(0) - result.shelve(model, 0, modify=False) - return model + if self.single: + resamp = resample.ResampleImage( + input_models, + output=output, + enable_var=False, + compute_err="driz_err", + **kwargs + ) + result = resamp.resample_many_to_many() + else: + resamp = resample.ResampleImage( + input_models, + output=output, + enable_var=True, + compute_err="from_var", + **kwargs + ) + result = resamp.resample_many_to_one() + + # with result: + # if len(result) == 1: + # model = result.borrow(0) + # result.shelve(model, 0, modify=False) + # return model return result @@ -232,67 +233,35 @@ def get_drizpars(self): fillval=self.fillval, wht_type=self.weight_type, good_bits=GOOD_BITS, - single=self.single, blendheaders=self.blendheaders, allowed_memory=self.allowed_memory, in_memory=self.in_memory ) # Custom output WCS parameters. - kwargs['output_shape'] = self.check_list_pars( + output_shape = self.check_list_pars( self.output_shape, 'output_shape', min_vals=[1, 1] ) kwargs['output_wcs'] = self.load_custom_wcs( self.output_wcs, - kwargs['output_shape'] + output_shape ) - kwargs['crpix'] = self.check_list_pars(self.crpix, 'crpix') - kwargs['crval'] = self.check_list_pars(self.crval, 'crval') - kwargs['rotation'] = self.rotation - kwargs['pscale'] = self.pixel_scale - kwargs['pscale_ratio'] = self.pixel_scale_ratio + + wcs_pars = { + 'crpix': self.check_list_pars(self.crpix, 'crpix'), + 'crval': self.check_list_pars(self.crval, 'crval'), + 'rotation': self.rotation, + 'pixel_scale': self.pixel_scale, + 'pixel_scale_ratio': self.pixel_scale_ratio, + 'output_shape': None if output_shape is None else output_shape[::-1], + } + + kwargs['wcs_pars'] = wcs_pars # Report values to processing log for k, v in kwargs.items(): self.log.debug(' {}={}'.format(k, v)) return kwargs - - def update_fits_wcs(self, model): - """ - Update FITS WCS keywords of the resampled image. - """ - # Delete any SIP-related keywords first - pattern = r"^(cd[12]_[12]|[ab]p?_\d_\d|[ab]p?_order)$" - regex = re.compile(pattern) - - keys = list(model.meta.wcsinfo.instance.keys()) - for key in keys: - if regex.match(key): - del model.meta.wcsinfo.instance[key] - - # Write new PC-matrix-based WCS based on GWCS model - transform = model.meta.wcs.forward_transform - model.meta.wcsinfo.crpix1 = -transform[0].offset.value + 1 - model.meta.wcsinfo.crpix2 = -transform[1].offset.value + 1 - model.meta.wcsinfo.cdelt1 = transform[3].factor.value - model.meta.wcsinfo.cdelt2 = transform[4].factor.value - model.meta.wcsinfo.ra_ref = transform[6].lon.value - model.meta.wcsinfo.dec_ref = transform[6].lat.value - model.meta.wcsinfo.crval1 = model.meta.wcsinfo.ra_ref - model.meta.wcsinfo.crval2 = model.meta.wcsinfo.dec_ref - model.meta.wcsinfo.pc1_1 = transform[2].matrix.value[0][0] - model.meta.wcsinfo.pc1_2 = transform[2].matrix.value[0][1] - model.meta.wcsinfo.pc2_1 = transform[2].matrix.value[1][0] - model.meta.wcsinfo.pc2_2 = transform[2].matrix.value[1][1] - model.meta.wcsinfo.ctype1 = "RA---TAN" - model.meta.wcsinfo.ctype2 = "DEC--TAN" - - # Remove no longer relevant WCS keywords - rm_keys = ['v2_ref', 'v3_ref', 'ra_ref', 'dec_ref', 'roll_ref', - 'v3yangle', 'vparity'] - for key in rm_keys: - if key in model.meta.wcsinfo.instance: - del model.meta.wcsinfo.instance[key] diff --git a/jwst/resample/resample_utils.py b/jwst/resample/resample_utils.py index 4d126ed28c..05f3a7add6 100644 --- a/jwst/resample/resample_utils.py +++ b/jwst/resample/resample_utils.py @@ -1,5 +1,6 @@ from copy import deepcopy import logging +import math import warnings import numpy as np @@ -9,15 +10,153 @@ from stdatamodels.dqflags import interpret_bit_flags from stdatamodels.jwst.datamodels.dqflags import pixel -from jwst.assign_wcs.util import wcs_bbox_from_shape -from stcal.alignment import util +from stcal.alignment.util import ( + compute_scale, + wcs_bbox_from_shape, + wcs_from_sregions, +) +from stcal.resample.utils import compute_wcs_pixel_area +__all__ = ["decode_context", "make_output_wcs", "resampled_wcs_from_models"] + log = logging.getLogger(__name__) log.setLevel(logging.DEBUG) -__all__ = ['decode_context'] +def resampled_wcs_from_models( + input_models, + ref_wcs=None, + pixel_scale_ratio=1.0, + pixel_scale=None, + output_shape=None, + rotation=None, + crpix=None, + crval=None, +): + """ + Computes the WCS of the resampled image from input models and + specified WCS parameters. + + Parameters + ---------- + + input_models : `~jwst.datamodel.ModelLibrary` + Each datamodel must have a ``model.meta.wcs`` set to a ~gwcs.WCS object. + + ref_wcs : WCS object + A WCS used as reference for the creation of the output + coordinate frame, projection, and scaling and rotation transforms. + + pixel_scale_ratio : float, optional + Desired pixel scale ratio defined as the ratio of the desired output + pixel scale to the first input model's pixel scale computed from this + model's WCS at the fiducial point (taken as the ``ref_ra`` and + ``ref_dec`` from the ``wcsinfo`` meta attribute of the first input + image). Ignored when ``pixel_scale`` is specified. + + pixel_scale : float, None, optional + Desired pixel scale (in degrees) of the output WCS. When provided, + overrides ``pixel_scale_ratio``. + + output_shape : tuple of two integers (int, int), None, optional + Shape of the image (data array) using ``np.ndarray`` convention + (``ny`` first and ``nx`` second). This value will be assigned to + ``pixel_shape`` and ``array_shape`` properties of the returned + WCS object. + + rotation : float, None, optional + Position angle of output image's Y-axis relative to North. + A value of 0.0 would orient the final output image to be North up. + The default of `None` specifies that the images will not be rotated, + but will instead be resampled in the default orientation for the + camera with the x and y axes of the resampled image corresponding + approximately to the detector axes. Ignored when ``transform`` is + provided. + + crpix : tuple of float, None, optional + Position of the reference pixel in the resampled image array. + If ``crpix`` is not specified, it will be set to the center of the + bounding box of the returned WCS object. + + crval : tuple of float, None, optional + Right ascension and declination of the reference pixel. + Automatically computed if not provided. + + Returns + ------- + wcs : ~gwcs.wcs.WCS + The WCS object corresponding to the combined input footprints. + + pscale_in : float + Computed pixel scale (in degrees) of the first input image. + + pscale_out : float + Computed pixel scale (in degrees) of the output image. + + pixel_scale_ratio : float + Pixel scale ratio (output to input). + + """ + # build a list of WCS of all input models: + sregion_list = [] + ref_wcs = None + ref_wcsinfo = None + shape = None + + with input_models: + for model in input_models: + w = model.meta.wcs + if ref_wcsinfo is None: + ref_wcsinfo = model.meta.wcsinfo.instance + shape = model.data.shape + if ref_wcs is None: + ref_wcs = w + # make sure all WCS objects have the bounding_box defined: + if w.bounding_box is None: + w.bounding_box = wcs_bbox_from_shape(shape) + sregion_list.append(model.meta.wcsinfo.s_region) + input_models.shelve(model) + + if not sregion_list: + raise ValueError("No input models.") + + if pixel_scale is None: + # TODO: at some point we should switch to compute_wcs_pixel_area + # instead of compute_scale. + pscale_in0 = compute_scale( + ref_wcs, + fiducial=np.array([ref_wcsinfo["ra_ref"], ref_wcsinfo["dec_ref"]]) + ) + pixel_scale = pscale_in0 * pixel_scale_ratio + log.info( + f"Pixel scale ratio (pscale_out/pscale_in): {pixel_scale_ratio}" + ) + log.info(f"Computed output pixel scale: {3600 * pixel_scale} arcsec.") + else: + pscale_in0 = np.rad2deg( + math.sqrt(compute_wcs_pixel_area(ref_wcs, shape=shape)) + ) + + pixel_scale_ratio = pixel_scale / pscale_in0 + log.info(f"Output pixel scale: {3600 * pixel_scale} arcsec.") + log.info( + "Computed pixel scale ratio (pscale_out/pscale_in): " + f"{pixel_scale_ratio}." + ) + + wcs = wcs_from_sregions( + sregion_list, + ref_wcs=ref_wcs, + ref_wcsinfo=ref_wcsinfo, + pscale_ratio=pixel_scale_ratio, + pscale=pixel_scale, + rotation=rotation, + shape=output_shape, + crpix=crpix, + crval=crval + ) + return wcs, pscale_in0, pixel_scale, pixel_scale_ratio def make_output_wcs(input_models, ref_wcs=None, @@ -89,10 +228,10 @@ def make_output_wcs(input_models, ref_wcs=None, f"but the supplied WCS has {naxes} axes.") raise RuntimeError(msg) - output_wcs = util.wcs_from_sregions( + output_wcs = wcs_from_sregions( sregion_list, - ref_wcs, - ref_wcsinfo, + ref_wcs=ref_wcs, + ref_wcsinfo=ref_wcsinfo, pscale_ratio=pscale_ratio, pscale=pscale, rotation=rotation, @@ -335,22 +474,6 @@ def decode_context(context, x, y): return idx -def _resample_range(data_shape, bbox=None): - # Find range of input pixels to resample: - if bbox is None: - xmin = ymin = 0 - xmax = data_shape[1] - 1 - ymax = data_shape[0] - 1 - else: - ((x1, x2), (y1, y2)) = bbox - xmin = max(0, int(x1 + 0.5)) - ymin = max(0, int(y1 + 0.5)) - xmax = min(data_shape[1] - 1, int(x2 + 0.5)) - ymax = min(data_shape[0] - 1, int(y2 + 0.5)) - - return xmin, xmax, ymin, ymax - - def check_for_tmeasure(model): ''' Check if the measurement_time keyword is present in the datamodel diff --git a/jwst/resample/tests/test_resample_step.py b/jwst/resample/tests/test_resample_step.py index 75eddb8530..dc463ec5d0 100644 --- a/jwst/resample/tests/test_resample_step.py +++ b/jwst/resample/tests/test_resample_step.py @@ -6,6 +6,7 @@ import asdf from stdatamodels.jwst.datamodels import ImageModel +from stcal.resample.resample import compute_wcs_pixel_area from jwst.datamodels import ModelContainer, ModelLibrary from jwst.assign_wcs import AssignWcsStep @@ -13,7 +14,6 @@ from jwst.exp_to_source import multislit_to_container from jwst.extract_2d import Extract2dStep from jwst.resample import ResampleSpecStep, ResampleStep -from jwst.resample.resample import compute_image_pixel_area from jwst.resample.resample_spec import ResampleSpecData, compute_spectral_pixel_scale @@ -29,7 +29,11 @@ def _set_photom_kwd(im): bb = ((xmin - 0.5, xmax - 0.5), (ymin - 0.5, ymax - 0.5)) im.meta.wcs.bounding_box = bb - mean_pixel_area = compute_image_pixel_area(im.meta.wcs) + mean_pixel_area = compute_wcs_pixel_area( + im.meta.wcs, + shape=im.data.shape, + ) + if mean_pixel_area: im.meta.photometry.pixelarea_steradians = mean_pixel_area im.meta.photometry.pixelarea_arcsecsq = ( @@ -174,6 +178,8 @@ def miri_rate_pair(miri_rate_zero_crossing): im1.close() im2.close() +from stcal.resample.utils import compute_wcs_pixel_area + @pytest.fixture def nircam_rate(): @@ -650,7 +656,7 @@ def test_weight_type(nircam_rate, tmp_cwd): result2 = ResampleStep.call(c, weight_type="exptime", blendheaders=False) assert_allclose(result2.data[100:105, 100:105], 6.667, rtol=1e-2) - expectation_value = 407. + expectation_value = 407.0 assert_allclose(result2.wht[100:105, 100:105], expectation_value, rtol=1e-2) # remove measurement time to force use of exposure time @@ -923,8 +929,10 @@ def test_custom_refwcs_resample_imaging(nircam_rate, output_shape2, match, # make sure pixel values are similar, accounting for scale factor # (assuming inputs are in surface brightness units) - iscale = np.sqrt(im.meta.photometry.pixelarea_steradians - / compute_image_pixel_area(im.meta.wcs)) + iscale = np.sqrt( + im.meta.photometry.pixelarea_steradians + / compute_wcs_pixel_area(im.meta.wcs, shape=im.data.shape) + ) input_mean = np.nanmean(im.data) output_mean_1 = np.nanmean(data1) output_mean_2 = np.nanmean(data2) @@ -980,8 +988,10 @@ def test_custom_refwcs_pixel_shape_imaging(nircam_rate, tmp_path): # make sure pixel values are similar, accounting for scale factor # (assuming inputs are in surface brightness units) - iscale = np.sqrt(im.meta.photometry.pixelarea_steradians - / compute_image_pixel_area(im.meta.wcs)) + iscale = np.sqrt( + im.meta.photometry.pixelarea_steradians + / compute_wcs_pixel_area(im.meta.wcs, shape=im.data.shape) + ) input_mean = np.nanmean(im.data) output_mean_1 = np.nanmean(data1) output_mean_2 = np.nanmean(data2)