Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Replacement of sci_cube_skysub with DIKL #629

Merged
merged 1 commit into from
Mar 20, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
201 changes: 153 additions & 48 deletions vip_hci/preproc/skysubtraction.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,28 +17,39 @@
| `https://arxiv.org/abs/1706.10069
<https://arxiv.org/abs/1706.10069>`_

.. [REN23]
| Ren 2023
| **Karhunen-Loève data imputation in high-contrast imaging**
| *Astronomy & Astrophysics, Volume 679, p. 8*
| `https://arxiv.org/abs/2308.16912
<https://arxiv.org/abs/2308.16912>`_

"""

__author__ = 'Carlos Alberto Gomez Gonzalez'
__author__ = 'Sandrine Juillard'
__all__ = ['cube_subtract_sky_pca']

import numpy as np
from ..var import prepare_matrix


def cube_subtract_sky_pca(sci_cube, sky_cube, mask, ref_cube=None, ncomp=2,
def cube_subtract_sky_pca(sci_cube, sky_cube, masks, ref_cube=None, ncomp=2,
full_output=False):
"""PCA-based sky subtraction as explained in [GOM17]_ and [HUN18]_.
""" PCA-based sky subtraction as explained in [REN23]_.
(see also [GOM17]_ and [HUN18]_)

Parameters
----------
sci_cube : numpy ndarray
3d array of science frames.
sky_cube : numpy ndarray
3d array of sky frames.
mask : numpy ndarray
Mask indicating the region for the analysis. Can be created with the
function vip_hci.var.create_ringed_spider_mask.
masks : tuple of two numpy ndarray or one signle numpy ndarray
Mask indicating the boat and anchor regions for the analysis.
If two masks are provided, they will be assigned to mask_anchor and
mask_boat in that order.
If only one mask is provided, it will be used as the anchor, and the
boat images will not be masked (i.e., full frames used).
ref_cube : numpy ndarray or None, opt
Reference cube.
ncomp : int, opt
Expand All @@ -47,6 +58,12 @@ def cube_subtract_sky_pca(sci_cube, sky_cube, mask, ref_cube=None, ncomp=2,
Whether to also output pcs, reconstructed cube, residuals cube and
derotated residual cube.

Notes
----------
Masks can be created with the function `vip_hci.var.create_ringed_spider_mask`
or `get_annulus_segments` (see Usage Exemple below)


Returns
-------
sci_cube_skysub : numpy ndarray
Expand All @@ -56,10 +73,38 @@ def cube_subtract_sky_pca(sci_cube, sky_cube, mask, ref_cube=None, ncomp=2,
If full_output is set to True, returns (in the following order):
- sky-subtracted science cube,
- sky-subtracted reference cube (if any provided),
- principal components (non-masked),
- principal components (masked), and
- boat principal components,
- anchor principal components, and
- reconstructed cube.

Usage Exemple
-------

You can create the masks using `get_annulus_segments` from `vip_hci.var`.

.. code-block:: python

from vip_hci.var import get_annulus_segments

The function must be used as follows, where `ring_out`, `ring_in`, and
`coro` define the radius of the different annulus masks. They must have
the same shape as a frame of the science cube.


.. code-block:: python

ones = np.ones(cube[0].shape)
boat = get_annulus_segments(ones,coro,ring_out-coro, mode="mask")[0]
anchor = get_annulus_segments(ones,ring_in,ring_out-ring_in, mode="mask")[0]


Masks should be provided as 'mask_rdi' argument when using PCA.

.. code-block:: python

res = pca(cube, angles, ref, mask_rdi=(boat, anchor), ncomp=2)


"""
try:
from ..psfsub.svd import svd_wrapper
Expand All @@ -69,79 +114,139 @@ def cube_subtract_sky_pca(sci_cube, sky_cube, mask, ref_cube=None, ncomp=2,
if sci_cube.shape[1] != sky_cube.shape[1] or sci_cube.shape[2] != \
sky_cube.shape[2]:
raise TypeError('Science and Sky frames sizes do not match')

if ref_cube is not None:
if sci_cube.shape[1] != ref_cube.shape[1] or sci_cube.shape[2] != \
ref_cube.shape[2]:
raise TypeError('Science and Reference frames sizes do not match')
if type(masks) not in (list, tuple):
# If only one mask is provided, the second mask is generated
mask_anchor = masks
mask_boat = np.ones(masks.shape)
elif len(masks)!=2:
raise TypeError('Science and Reference frames sizes do not match')
else :
mask_anchor, mask_boat = masks

## -- Generate boat and anchor matrixes

# Masking the sky cube with anchor
sky_cube_masked = np.zeros_like(sky_cube)
ind_masked = np.where(mask_anchor == 0)
for i in range(sky_cube.shape[0]):
masked_image = np.copy(sky_cube[i])
masked_image[ind_masked] = 0
sky_cube_masked[i] = masked_image
sky_anchor = sky_cube_masked.reshape(sky_cube.shape[0],
sky_cube.shape[1]*sky_cube.shape[2])

# Getting the EVs from the sky cube
Msky = prepare_matrix(sky_cube, scaling=None, verbose=False)
sky_pcs = svd_wrapper(Msky, 'lapack', sky_cube.shape[0], False)
sky_pcs_cube = sky_pcs.reshape(sky_cube.shape[0], sky_cube.shape[1],
sky_cube.shape[2])

# Masking the science cube
sci_cube_masked = np.zeros_like(sci_cube)
ind_masked = np.where(mask == 0)
# Masking the science cube with anchor
sci_cube_anchor = np.zeros_like(sci_cube)
ind_masked = np.where(mask_anchor == 0)
for i in range(sci_cube.shape[0]):
masked_image = np.copy(sci_cube[i])
masked_image[ind_masked] = 0
sci_cube_masked[i] = masked_image
Msci_masked = prepare_matrix(sci_cube_masked, scaling=None, verbose=False)
sci_cube_anchor[i] = masked_image
Msci_masked_anchor = prepare_matrix(sci_cube_anchor, scaling=None, verbose=False)

# Masking the PCs learned from the skies
sky_pcs_cube_masked = np.zeros_like(sky_pcs_cube)
for i in range(sky_pcs_cube.shape[0]):
masked_image = np.copy(sky_pcs_cube[i])
# Masking the science cube with boat
sci_cube_boat = np.zeros_like(sci_cube)
ind_masked = np.where(mask_boat == 0)
for i in range(sci_cube.shape[0]):
masked_image = np.copy(sci_cube[i])
masked_image[ind_masked] = 0
sky_pcs_cube_masked[i] = masked_image

# Project the masked frames onto the sky PCs to get the coefficients
transf_sci = np.zeros((sky_cube.shape[0], Msci_masked.shape[0]))
for i in range(Msci_masked.shape[0]):
transf_sci[:, i] = np.inner(sky_pcs, Msci_masked[i].T)

Msky_pcs_masked = prepare_matrix(sky_pcs_cube_masked, scaling=None,
sci_cube_boat[i] = masked_image
Msci_masked = prepare_matrix(sci_cube_boat, scaling=None, verbose=False)

# Masking the sky cube with boat
sky_cube_boat = np.zeros_like(sky_cube)
ind_masked = np.where(mask_boat == 0)
for i in range(sky_cube.shape[0]):
masked_image = np.copy(sky_cube[i])
masked_image[ind_masked] = 0
sky_cube_boat[i] = masked_image
sky_boat = sky_cube_boat.reshape(sky_cube.shape[0],
sky_cube.shape[1]*sky_cube.shape[2])

## -- Generate eigenvectors of R(a)T R(a)

sky_kl = np.dot(sky_anchor, sky_anchor.T)
Msky_kl = prepare_matrix(sky_kl, scaling=None, verbose=False)
sky_pcs = svd_wrapper(Msky_kl, 'lapack', sky_kl.shape[0], False)
sky_pcs_kl = sky_pcs.reshape(sky_kl.shape[0], sky_kl.shape[1])

## -- Generate Kl and Dikl transform

sky_pc_anchor = np.dot(sky_pcs_kl,sky_anchor)
sky_pcs_anchor_cube = sky_pc_anchor.reshape(sky_cube.shape[0],
sky_cube.shape[1], sky_cube.shape[2])

sky_pcs_boat_cube = np.dot(sky_pcs_kl,sky_boat).reshape(sky_cube.shape[0],
sky_cube.shape[1], sky_cube.shape[2])

## -- Generate Kl projection to get coeff

transf_sci = np.zeros((sky_cube.shape[0], Msci_masked_anchor.shape[0]))
for i in range(Msci_masked_anchor.shape[0]):
transf_sci[:, i] = np.inner(sky_pc_anchor, Msci_masked_anchor[i].T)

Msky_pcs_anchor = prepare_matrix(sky_pcs_anchor_cube, scaling=None,
verbose=False)
mat_inv = np.linalg.inv(np.dot(Msky_pcs_masked, Msky_pcs_masked.T))

mat_inv = np.linalg.inv(np.dot(Msky_pcs_anchor, Msky_pcs_anchor.T))
transf_sci_scaled = np.dot(mat_inv, transf_sci)

# Obtaining the optimized sky and subtraction

## -- Subtraction Dikl projection using anchor coeff to sci cube

sci_cube_skysub = np.zeros_like(sci_cube)
sky_opt = sci_cube.copy()
for i in range(Msci_masked.shape[0]):
sky_opt[i] = np.array([np.sum(
transf_sci_scaled[j, i] * sky_pcs_cube[j] for j in range(ncomp))])
sci_cube_skysub[i] = sci_cube[i] - sky_opt[i]
transf_sci_scaled[j, i] * sky_pcs_boat_cube[j] for j in range(ncomp))])
sci_cube_skysub[i] = sci_cube_boat[i] - sky_opt[i]

# Processing the reference cube (if any)
## -- Processing the reference cube (if any)
if ref_cube is not None:
ref_cube_masked = np.zeros_like(ref_cube)
for i in range(ref_cube.shape[0]):
masked_image = np.copy(ref_cube[i])

# Masking the ref cube with anchor
ref_cube_anchor = np.zeros_like(sci_cube)
ind_masked = np.where(mask_anchor == 0)
for i in range(sci_cube.shape[0]):
masked_image = np.copy(sci_cube[i])
masked_image[ind_masked] = 0
ref_cube_masked[i] = masked_image
Mref_masked = prepare_matrix(ref_cube_masked, scaling=None,
verbose=False)
ref_cube_anchor[i] = masked_image
Mref_masked_anchor = prepare_matrix(ref_cube_anchor, scaling=None, verbose=False)

# Masking the ref cube with boat
ref_cube_boat = np.zeros_like(sci_cube)
ind_masked = np.where(mask_boat == 0)
for i in range(sci_cube.shape[0]):
masked_image = np.copy(sci_cube[i])
masked_image[ind_masked] = 0
ref_cube_boat[i] = masked_image
Mref_masked = prepare_matrix(ref_cube_boat, scaling=None, verbose=False)

transf_ref = np.zeros((sky_cube.shape[0], Mref_masked.shape[0]))
for i in range(Mref_masked.shape[0]):
transf_ref[:, i] = np.inner(sky_pcs, Mref_masked[i].T)
transf_ref[:, i] = np.inner(sky_pc_anchor, Mref_masked_anchor[i].T)

transf_ref_scaled = np.dot(mat_inv, transf_ref)

ref_cube_skysub = np.zeros_like(ref_cube)
for i in range(Mref_masked.shape[0]):
sky_opt = np.array([np.sum(transf_ref_scaled[j, i] * sky_pcs_cube[j]
sky_opt = np.array([np.sum(transf_ref_scaled[j, i] * sky_pcs_boat_cube[j]
for j in range(ncomp))])
ref_cube_skysub[i] = ref_cube[i] - sky_opt
ref_cube_skysub[i] = ref_cube_boat[i] - sky_opt

if full_output:
return (sci_cube_skysub, ref_cube_skysub, sky_pcs_cube,
sky_pcs_cube_masked, sky_opt)
return (sci_cube_skysub, ref_cube_skysub, sky_pcs_anchor_cube,
sky_pcs_boat_cube, sky_opt)
else:
return sci_cube_skysub, ref_cube_skysub
else:
if full_output:
return (sci_cube_skysub, sky_pcs_cube, sky_pcs_cube_masked, sky_opt)
return (sci_cube_skysub, sky_pcs_anchor_cube, sky_pcs_boat_cube,
sky_opt)
else:
return sci_cube_skysub
Loading
Loading