From fb279e5e5aa5982a818109dae1d2b0e85faeb1b7 Mon Sep 17 00:00:00 2001 From: Adam Coogan Date: Tue, 7 Mar 2023 13:09:27 -0500 Subject: [PATCH 01/11] make codecov not cause ci error --- .github/workflows/coverage.yaml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.github/workflows/coverage.yaml b/.github/workflows/coverage.yaml index 02799b78..957bc888 100644 --- a/.github/workflows/coverage.yaml +++ b/.github/workflows/coverage.yaml @@ -64,4 +64,4 @@ jobs: uses: codecov/codecov-action@v3 with: files: ${{ github.workspace }}/coverage.xml - fail_ci_if_error: true + fail_ci_if_error: false From 83b126aa7f96c40e80a453aa2be1ef87e1b0aa51 Mon Sep 17 00:00:00 2001 From: Adam Coogan Date: Thu, 9 Mar 2023 20:19:32 -0500 Subject: [PATCH 02/11] same cosmo for lenstronomy and caustic in tests --- src/caustic/cosmology.py | 15 ++++++++++----- test/test_cosmology.py | 4 +++- test/test_epl.py | 16 ++++++++++------ test/test_external_shear.py | 8 +++----- test/test_lenses.py | 0 test/test_multiplane.py | 15 +++++---------- test/test_nfw.py | 26 ++++---------------------- test/test_point.py | 7 +++---- test/test_pseudo_jaffe.py | 7 +++---- test/test_sie.py | 7 +++---- test/test_sis.py | 13 ++++++------- test/utils.py | 10 ++++++++++ 12 files changed, 60 insertions(+), 68 deletions(-) delete mode 100644 test/test_lenses.py diff --git a/src/caustic/cosmology.py b/src/caustic/cosmology.py index 1b8c99ef..bfe176d5 100644 --- a/src/caustic/cosmology.py +++ b/src/caustic/cosmology.py @@ -54,26 +54,31 @@ def comoving_dist(self, z: Tensor, x: Optional[dict[str, Any]] = None) -> Tensor ... def comoving_dist_z1z2( - self, z1: Tensor, z2: Tensor, x: Optional[dict[str, Any]] = None) -> Tensor: + self, z1: Tensor, z2: Tensor, x: Optional[dict[str, Any]] = None + ) -> Tensor: return self.comoving_dist(z2, x) - self.comoving_dist(z1, x) def angular_diameter_dist( - self, z: Tensor, x: Optional[dict[str, Any]] = None) -> Tensor: + self, z: Tensor, x: Optional[dict[str, Any]] = None + ) -> Tensor: return self.comoving_dist(z, x) / (1 + z) def angular_diameter_dist_z1z2( - self, z1: Tensor, z2: Tensor, x: Optional[dict[str, Any]] = None) -> Tensor: + self, z1: Tensor, z2: Tensor, x: Optional[dict[str, Any]] = None + ) -> Tensor: return self.comoving_dist_z1z2(z1, z2, x) / (1 + z2) def time_delay_dist( - self, z_l: Tensor, z_s: Tensor, x: Optional[dict[str, Any]] = None) -> Tensor: + self, z_l: Tensor, z_s: Tensor, x: Optional[dict[str, Any]] = None + ) -> Tensor: d_l = self.angular_diameter_dist(z_l, x) d_s = self.angular_diameter_dist(z_s, x) d_ls = self.angular_diameter_dist_z1z2(z_l, z_s, x) return (1 + z_l) * d_l * d_s / d_ls def Sigma_cr( - self, z_l: Tensor, z_s: Tensor, x: Optional[dict[str, Any]] = None) -> Tensor: + self, z_l: Tensor, z_s: Tensor, x: Optional[dict[str, Any]] = None + ) -> Tensor: d_l = self.angular_diameter_dist(z_l, x) d_s = self.angular_diameter_dist(z_s, x) d_ls = self.angular_diameter_dist_z1z2(z_l, z_s, x) diff --git a/test/test_cosmology.py b/test/test_cosmology.py index 59fdf085..f4485ff6 100644 --- a/test/test_cosmology.py +++ b/test/test_cosmology.py @@ -5,7 +5,9 @@ from astropy.cosmology import Cosmology as Cosmology_AP from astropy.cosmology import FlatLambdaCDM as AstropyFlatLambdaCDM -from caustic.cosmology import Cosmology, FlatLambdaCDM as CausticFlatLambdaCDM, Om0_default, h0_default +from caustic.cosmology import Cosmology +from caustic.cosmology import FlatLambdaCDM as CausticFlatLambdaCDM +from caustic.cosmology import Om0_default, h0_default def 
get_cosmologies() -> List[Tuple[Cosmology, Cosmology_AP]]: diff --git a/test/test_epl.py b/test/test_epl.py index d0b2958e..43add0f4 100644 --- a/test/test_epl.py +++ b/test/test_epl.py @@ -3,19 +3,23 @@ import lenstronomy.Util.param_util as param_util import torch from lenstronomy.LensModel.lens_model import LensModel -from utils import Psi_test_helper, alpha_test_helper, kappa_test_helper +from utils import ( + Psi_test_helper, + alpha_test_helper, + get_default_cosmologies, + kappa_test_helper, +) -from caustic.cosmology import FlatLambdaCDM from caustic.lenses import EPL def test_lenstronomy(): # Models - cosmology = FlatLambdaCDM("cosmo") + cosmology, cosmology_ap = get_default_cosmologies() lens = EPL("epl", cosmology) # There is also an EPL_NUMBA class lenstronomy, but it shouldn't matter much lens_model_list = ["EPL"] - lens_ls = LensModel(lens_model_list=lens_model_list) + lens_ls = LensModel(lens_model_list=lens_model_list, cosmo=cosmology_ap) # Parameters z_s = torch.tensor(1.0) @@ -44,10 +48,10 @@ def test_special_case_sie(): """ Checks that the deflection field matches an SIE for `t=1`. """ - cosmology = FlatLambdaCDM("cosmo") + cosmology, cosmology_ap = get_default_cosmologies() lens = EPL("epl", cosmology) lens_model_list = ["SIE"] - lens_ls = LensModel(lens_model_list=lens_model_list) + lens_ls = LensModel(lens_model_list=lens_model_list, cosmo=cosmology_ap) # Parameters z_s = torch.tensor(1.9) diff --git a/test/test_external_shear.py b/test/test_external_shear.py index 8c0e4079..b8126595 100644 --- a/test/test_external_shear.py +++ b/test/test_external_shear.py @@ -1,8 +1,7 @@ import torch from lenstronomy.LensModel.lens_model import LensModel -from utils import lens_test_helper +from utils import get_default_cosmologies, lens_test_helper -from caustic.cosmology import FlatLambdaCDM from caustic.lenses import ExternalShear @@ -11,11 +10,10 @@ def test(): rtol = 1e-5 # Models - cosmology = FlatLambdaCDM("cosmo") + cosmology, cosmology_ap = get_default_cosmologies() lens = ExternalShear("shear", cosmology) lens_model_list = ["SHEAR"] - lens_ls = LensModel(lens_model_list=lens_model_list) - print(lens) + lens_ls = LensModel(lens_model_list=lens_model_list, cosmo=cosmology_ap) # Parameters z_s = torch.tensor(2.0) diff --git a/test/test_lenses.py b/test/test_lenses.py deleted file mode 100644 index e69de29b..00000000 diff --git a/test/test_multiplane.py b/test/test_multiplane.py index a9c659e8..94118f59 100644 --- a/test/test_multiplane.py +++ b/test/test_multiplane.py @@ -2,11 +2,9 @@ import lenstronomy.Util.param_util as param_util import torch -from astropy.cosmology import FlatLambdaCDM as FlatLambdaCDM_ap from lenstronomy.LensModel.lens_model import LensModel -from utils import lens_test_helper +from utils import get_default_cosmologies, lens_test_helper -from caustic.cosmology import FlatLambdaCDM from caustic.lenses import SIE, MultiplaneLens @@ -15,10 +13,8 @@ def test(): atol = 5e-3 # Setup - - z_s = torch.tensor(1.5, dtype=torch.float32) - cosmology = FlatLambdaCDM("cosmo") - cosmology.to(dtype=torch.float32) + z_s = torch.tensor(1.5) + cosmology, cosmology_ap = get_default_cosmologies() # Parameters xs = [ @@ -26,7 +22,7 @@ def test(): [0.7, 0.0, 0.5, 0.9999, -pi / 6, 0.7], [1.1, 0.4, 0.3, 0.9999, pi / 4, 0.9], ] - x = torch.tensor([p for _xs in xs for p in _xs], dtype=torch.float32) + x = torch.tensor([p for _xs in xs for p in _xs]) lens = MultiplaneLens( "multiplane", cosmology, [SIE(f"sie-{i}", cosmology) for i in range(len(xs))] @@ -47,12 +43,11 @@ def test(): ) 
# Use same cosmology - cosmo_ap = FlatLambdaCDM_ap(cosmology.h0.value, cosmology.Om0.value, Tcmb0=0) lens_ls = LensModel( lens_model_list=["SIE" for _ in range(len(xs))], z_source=z_s.item(), lens_redshift_list=[_xs[0] for _xs in xs], - cosmo=cosmo_ap, + cosmo=cosmology_ap, multi_plane=True, ) diff --git a/test/test_nfw.py b/test/test_nfw.py index 662d4a88..2cd537ff 100644 --- a/test/test_nfw.py +++ b/test/test_nfw.py @@ -1,35 +1,20 @@ -# from math import pi - -# import lenstronomy.Util.param_util as param_util import torch -from astropy.cosmology import FlatLambdaCDM as FlatLambdaCDM_AP -from astropy.cosmology import default_cosmology - -# next three imports to get Rs_angle and alpha_Rs in arcsec for lenstronomy from lenstronomy.Cosmo.lens_cosmo import LensCosmo from lenstronomy.LensModel.lens_model import LensModel -from utils import lens_test_helper +from utils import get_default_cosmologies, lens_test_helper -from caustic.cosmology import FlatLambdaCDM as CausticFlatLambdaCDM from caustic.lenses import NFW -h0_default = float(default_cosmology.get().h) -Om0_default = float(default_cosmology.get().Om0) -Ob0_default = float(default_cosmology.get().Ob0) - def test(): atol = 1e-5 rtol = 3e-2 # Models - cosmology = CausticFlatLambdaCDM("cosmo") + cosmology, cosmology_ap = get_default_cosmologies() z_l = torch.tensor(0.1) lens = NFW("nfw", cosmology, z_l=z_l) - lens_model_list = ["NFW"] - lens_ls = LensModel(lens_model_list=lens_model_list) - - print(lens) + lens_ls = LensModel(lens_model_list=["NFW"], cosmo=cosmology_ap) # Parameters z_s = torch.tensor(0.5) @@ -41,11 +26,8 @@ def test(): x = torch.tensor([thx0, thy0, m, c]) # Lenstronomy - cosmo = FlatLambdaCDM_AP(H0=h0_default * 100, Om0=Om0_default, Ob0=Ob0_default) - lens_cosmo = LensCosmo(z_lens=z_l.item(), z_source=z_s.item(), cosmo=cosmo) + lens_cosmo = LensCosmo(z_lens=z_l.item(), z_source=z_s.item(), cosmo=cosmology_ap) Rs_angle, alpha_Rs = lens_cosmo.nfw_physical2angle(M=m, c=c) - - # lenstronomy params ['Rs', 'alpha_Rs', 'center_x', 'center_y'] kwargs_ls = [ {"Rs": Rs_angle, "alpha_Rs": alpha_Rs, "center_x": thx0, "center_y": thy0} ] diff --git a/test/test_point.py b/test/test_point.py index a4255efb..6ffe7704 100644 --- a/test/test_point.py +++ b/test/test_point.py @@ -1,8 +1,7 @@ import torch from lenstronomy.LensModel.lens_model import LensModel -from utils import lens_test_helper +from utils import get_default_cosmologies, lens_test_helper -from caustic.cosmology import FlatLambdaCDM from caustic.lenses import Point @@ -11,10 +10,10 @@ def test(): rtol = 1e-5 # Models - cosmology = FlatLambdaCDM("cosmo") + cosmology, cosmology_ap = get_default_cosmologies() lens = Point("point", cosmology, z_l=torch.tensor(0.9)) lens_model_list = ["POINT_MASS"] - lens_ls = LensModel(lens_model_list=lens_model_list) + lens_ls = LensModel(lens_model_list=lens_model_list, cosmo=cosmology_ap) # Parameters z_s = torch.tensor(1.2) diff --git a/test/test_pseudo_jaffe.py b/test/test_pseudo_jaffe.py index 26051874..4225f37a 100644 --- a/test/test_pseudo_jaffe.py +++ b/test/test_pseudo_jaffe.py @@ -2,9 +2,8 @@ import torch from lenstronomy.LensModel.lens_model import LensModel -from utils import lens_test_helper +from utils import get_default_cosmologies, lens_test_helper -from caustic.cosmology import FlatLambdaCDM from caustic.lenses import PseudoJaffe @@ -13,10 +12,10 @@ def test(): rtol = 1e-5 # Models - cosmology = FlatLambdaCDM("cosmo") + cosmology, cosmology_ap = get_default_cosmologies() lens = PseudoJaffe("pj", cosmology) lens_model_list = 
["PJAFFE"] - lens_ls = LensModel(lens_model_list=lens_model_list) + lens_ls = LensModel(lens_model_list=lens_model_list, cosmo=cosmology_ap) # Parameters, computing kappa_0 with a helper function z_s = torch.tensor(2.1) diff --git a/test/test_sie.py b/test/test_sie.py index 5d9e2ee8..4b63a7a5 100644 --- a/test/test_sie.py +++ b/test/test_sie.py @@ -3,9 +3,8 @@ import lenstronomy.Util.param_util as param_util import torch from lenstronomy.LensModel.lens_model import LensModel -from utils import lens_test_helper +from utils import get_default_cosmologies, lens_test_helper -from caustic.cosmology import FlatLambdaCDM from caustic.lenses import SIE @@ -14,10 +13,10 @@ def test(): rtol = 1e-5 # Models - cosmology = FlatLambdaCDM("cosmo") + cosmology, cosmology_ap = get_default_cosmologies() lens = SIE("sie", cosmology) lens_model_list = ["SIE"] - lens_ls = LensModel(lens_model_list=lens_model_list) + lens_ls = LensModel(lens_model_list=lens_model_list, cosmo=cosmology_ap) # Parameters z_s = torch.tensor(1.2) diff --git a/test/test_sis.py b/test/test_sis.py index e46ea36c..0e6ff9cc 100644 --- a/test/test_sis.py +++ b/test/test_sis.py @@ -1,8 +1,7 @@ import torch from lenstronomy.LensModel.lens_model import LensModel -from utils import lens_test_helper +from utils import get_default_cosmologies, lens_test_helper -from caustic.cosmology import FlatLambdaCDM from caustic.lenses import SIS @@ -11,16 +10,16 @@ def test(): rtol = 1e-5 # Models - cosmology = FlatLambdaCDM("cosmo", None) - lens = SIS("sis", cosmology, z_l=torch.tensor(0.5)) + cosmology, cosmology_ap = get_default_cosmologies() + lens = SIS("sis", cosmology) lens_model_list = ["SIS"] - lens_ls = LensModel(lens_model_list=lens_model_list) + lens_ls = LensModel(lens_model_list=lens_model_list, cosmo=cosmology_ap) # Parameters z_s = torch.tensor(1.2) - x = torch.tensor([-0.342, 0.51, 1.4, 0.7]) + x = torch.tensor([0.5, -0.342, 0.51, 1.4]) kwargs_ls = [ - {"center_x": x[0].item(), "center_y": x[1].item(), "theta_E": x[2].item()} + {"center_x": x[1].item(), "center_y": x[2].item(), "theta_E": x[3].item()} ] lens_test_helper(lens, lens_ls, z_s, x, kwargs_ls, rtol, atol) diff --git a/test/utils.py b/test/utils.py index 3a84313d..7cda28c7 100644 --- a/test/utils.py +++ b/test/utils.py @@ -1,14 +1,24 @@ from typing import Any, Dict, List, Union import numpy as np +from astropy.cosmology import FlatLambdaCDM as FlatLambdaCDM_AP from lenstronomy.Data.pixel_grid import PixelGrid from lenstronomy.LensModel.lens_model import LensModel +from caustic.cosmology import FlatLambdaCDM from caustic.lenses import ThinLens from caustic.lenses.base import ThickLens from caustic.utils import get_meshgrid +def get_default_cosmologies(): + cosmology = FlatLambdaCDM("cosmo") + cosmology_ap = FlatLambdaCDM_AP( + 100 * cosmology.h0.value, cosmology.Om0.value, Tcmb0=0 + ) + return cosmology, cosmology_ap + + def setup_grids(res=0.05, n_pix=100): # Caustic setup thx, thy = get_meshgrid(res, n_pix, n_pix) From 54ab65bd2e887448397c6eb512fab4e0e65d52d9 Mon Sep 17 00:00:00 2001 From: Adam Coogan Date: Fri, 17 Mar 2023 10:44:57 -0700 Subject: [PATCH 03/11] update time delay calculation --- src/caustic/lenses/base.py | 19 ++++++++++--------- 1 file changed, 10 insertions(+), 9 deletions(-) diff --git a/src/caustic/lenses/base.py b/src/caustic/lenses/base.py index e37c9899..b7a3eb09 100644 --- a/src/caustic/lenses/base.py +++ b/src/caustic/lenses/base.py @@ -130,19 +130,20 @@ def raytrace( ax, ay = self.alpha(thx, thy, z_s, x) return thx - ax, thy - ay + def fermat_potential( 
+ self, thx: Tensor, thy: Tensor, z_s: Tensor, x: Optional[dict[str, Any]] = None + ) -> Tensor: + ax, ay = self.alpha(thx, thy, z_s, x) + Psi = self.Psi(thx, thy, z_s, x) + return 0.5 * (ax**2 + ay**2) - Psi + def time_delay( self, thx: Tensor, thy: Tensor, z_s: Tensor, x: Optional[dict[str, Any]] = None ): z_l = self.unpack(x)[0] - - d_l = self.cosmology.angular_diameter_dist(z_l, x) - d_s = self.cosmology.angular_diameter_dist(z_s, x) - d_ls = self.cosmology.angular_diameter_dist_z1z2(z_l, z_s, x) - ax, ay = self.alpha(thx, thy, z_s, x) - Psi = self.Psi(thx, thy, z_s, x) - factor = (1 + z_l) / c_Mpc_s * d_s * d_l / d_ls - fp = 0.5 * d_ls**2 / d_s**2 * (ax**2 + ay**2) - Psi - return factor * fp * arcsec_to_rad**2 + d_td = self.cosmology.time_delay_dist(z_l, z_s, x) + phi = self.fermat_potential(thx, thy, z_s, x) + return d_td * phi * arcsec_to_rad**2 / c_Mpc_s def _lensing_jacobian_fft_method( self, thx: Tensor, thy: Tensor, z_s: Tensor, x: Optional[dict[str, Any]] = None From 557b91f99e3fd1465a6c49e90a5cac90bab70953 Mon Sep 17 00:00:00 2001 From: Adam Coogan Date: Fri, 17 Mar 2023 10:45:05 -0700 Subject: [PATCH 04/11] start incorporating alphagen --- src/caustic/lenses/alphagen/__init__.py | 3 + src/caustic/lenses/alphagen/models.py | 193 +++++++++++++++++++ src/caustic/lenses/alphagen/normalizer.py | 219 ++++++++++++++++++++++ src/caustic/lenses/alphagen/stblocks.py | 91 +++++++++ 4 files changed, 506 insertions(+) create mode 100644 src/caustic/lenses/alphagen/__init__.py create mode 100644 src/caustic/lenses/alphagen/models.py create mode 100644 src/caustic/lenses/alphagen/normalizer.py create mode 100644 src/caustic/lenses/alphagen/stblocks.py diff --git a/src/caustic/lenses/alphagen/__init__.py b/src/caustic/lenses/alphagen/__init__.py new file mode 100644 index 00000000..9a15f89f --- /dev/null +++ b/src/caustic/lenses/alphagen/__init__.py @@ -0,0 +1,3 @@ +from .models import * +from .normalizer import * +from .stblocks import * diff --git a/src/caustic/lenses/alphagen/models.py b/src/caustic/lenses/alphagen/models.py new file mode 100644 index 00000000..5a6284b7 --- /dev/null +++ b/src/caustic/lenses/alphagen/models.py @@ -0,0 +1,193 @@ +import torch +from torch import nn + +from .stblocks import ISAB, PMA, SAB + + +class AlphaSetTransformerISAB(nn.Module): + def __init__( + self, + in_dim, + hidden_dims_enc, + hidden_dims_dec, + out_dim, + num_heads, + num_inds, + num_out=1, + ln=False, + ): + super().__init__() + + enc_layers = [ISAB(in_dim, hidden_dims_enc[0], num_heads, num_inds, ln=ln)] + for i in range(len(hidden_dims_enc) - 1): + enc_layers.append( + ISAB( + hidden_dims_enc[i], + hidden_dims_enc[i + 1], + num_heads, + num_inds, + ln=ln, + ) + ) + + self.enc = nn.Sequential(*enc_layers) + + dec_layers: list[nn.Module] = [ + PMA(hidden_dims_enc[-1], num_heads, num_out, ln=ln) + ] + hidden_dims_dec.insert(0, hidden_dims_enc[-1]) + for i in range(len(hidden_dims_dec) - 1): + dec_layers.append( + SAB(hidden_dims_dec[i], hidden_dims_dec[i + 1], num_heads, ln=ln) + ) + + dec_layers.append(nn.Linear(hidden_dims_dec[-1], out_dim)) + self.dec = nn.Sequential(*dec_layers) + + def forward(self, x): + enc_out = self.enc(x) + dec_out = self.dec(enc_out).squeeze(dim=1) + + return dec_out + + +class AlphaSetTransformerSAB(nn.Module): + def __init__( + self, + in_dim, + hidden_dims_enc, + hidden_dims_dec, + out_dim, + num_heads, + num_out=1, + ln=False, + ): + super().__init__() + + enc_layers = [SAB(in_dim, hidden_dims_enc[0], num_heads, ln=ln)] + for i in range(len(hidden_dims_enc) - 1): 
+ enc_layers.append( + SAB(hidden_dims_enc[i], hidden_dims_enc[i + 1], num_heads, ln=ln) + ) + + self.enc = nn.Sequential(*enc_layers) + + dec_layers: list[nn.Module] = [ + PMA(hidden_dims_enc[-1], num_heads, num_out, ln=ln) + ] + hidden_dims_dec.insert(0, hidden_dims_enc[-1]) + for i in range(len(hidden_dims_dec) - 1): + dec_layers.append( + SAB(hidden_dims_dec[i], hidden_dims_dec[i + 1], num_heads, ln=ln) + ) + + dec_layers.append(nn.Linear(hidden_dims_dec[-1], out_dim)) + self.dec = nn.Sequential(*dec_layers) + + def forward(self, x): + enc_out = self.enc(x) + dec_out = self.dec(enc_out).squeeze(dim=1) + + return dec_out + + +class ConvNet2D(nn.Module): + def __init__( + self, in_ch, chs, out_ch, activation, final_activation, wrapper_func=None + ): + super().__init__() + + if not wrapper_func: + wrapper_func = lambda x: x + + chs.append(out_ch) + + layers = [ + nn.Conv2d(in_channels=in_ch, out_channels=chs[0], kernel_size=1, padding=0) + ] + + for i in range(len(chs) - 1): + layers.append(getattr(nn, activation)()) + layers.append( + wrapper_func( + nn.Conv2d( + in_channels=chs[i], + out_channels=chs[i + 1], + kernel_size=1, + padding=0, + ) + ) + ) + assert layers[-1].bias is not None + layers[-1].bias.data.fill_(0.0) + + if final_activation is not None: + layers.append(getattr(nn, final_activation)()) + + self.net = nn.Sequential(*layers) + + def forward(self, x): + """ + :param x: input Tensor, shape (b, c, npix, npix) + """ + return self.net(x) + + +class JointModel_ST_CNN(nn.Module): + def __init__( + self, + in_dim_SM, + hidden_dims_enc_SM, + hidden_dims_dec_SM, + embedding_size, + num_heads_SM, + layernorm_SM, + in_dim_MM, + hidden_dims_MM, + out_dim_MM, + npix, + normalizer, + activation_MM="ReLU", + final_activation_MM=None, + num_inds_SM=None, + ): + super().__init__() + + self.npix = npix + self.embedding_size = embedding_size + self.normalizer = normalizer + + if num_inds_SM is not None: + self.ST = AlphaSetTransformerISAB( + in_dim_SM, + hidden_dims_enc_SM, + hidden_dims_dec_SM, + embedding_size, + num_heads_SM, + num_inds_SM, + ln=layernorm_SM, + ) + else: + self.ST = AlphaSetTransformerSAB( + in_dim_SM, + hidden_dims_enc_SM, + hidden_dims_dec_SM, + embedding_size, + num_heads_SM, + ln=layernorm_SM, + ) + + self.CNN = ConvNet2D( + in_dim_MM, hidden_dims_MM, out_dim_MM, activation_MM, final_activation_MM + ) + + def forward(self, h, x): + embedding = self.ST(h) + embedding_ = ( + embedding.reshape(-1, self.embedding_size, 1, 1) + .repeat_interleave(self.npix, dim=-2) + .repeat_interleave(self.npix, dim=-1) + ) + cnn_input = torch.cat([x, embedding_], dim=1) + out = self.CNN(cnn_input) + return out diff --git a/src/caustic/lenses/alphagen/normalizer.py b/src/caustic/lenses/alphagen/normalizer.py new file mode 100644 index 00000000..4c159c5d --- /dev/null +++ b/src/caustic/lenses/alphagen/normalizer.py @@ -0,0 +1,219 @@ +import ast + +import h5py +import numpy as np +import torch +from scipy.interpolate import splrep + +__all__ = ("Normalizer",) + + +class Normalizer: + def __init__(self, normalize, mdef=None, segment="FG", stats_path=None): + """ + mdef == "PJAFFE" -> hset_params = ['center_x', 'center_y', 'Ra', 'Rs', 'sigma0', 'z'] + mdef == "NFW" -> hset_params = ["alpha_Rs", "Rs", "center_x", "center_y", "z"] + """ + self.normalize = normalize + self.mdef = mdef + self.segment = segment + if self.normalize == "from_stats": + try: + stats_file = h5py.File(stats_path, mode="r") + self.cone_fov = ast.literal_eval( + stats_file["base"].attrs["dataset_descriptor"] + 
)["cone_fov"] + + self.z = torch.tensor(stats_file["base"]["z"][:], dtype=torch.float) # type: ignore + self.z_norm = torch.tensor(stats_file["base"]["z_norm"][:], dtype=torch.float).reshape(1, -1, 1, 1) # type: ignore + + self.z_norm_fn = splrep( + np.insert(self.z.squeeze().detach().cpu().numpy(), 0, 0.0), + np.insert(self.z_norm.squeeze().detach().cpu().numpy(), 0, 0.0), + ) + + self.hset_bounds = torch.tensor(stats_file["base"]["hset_bounds"][:], dtype=torch.float) # type: ignore + if self.mdef == "PJAFFE": + self.log_slice = slice(2, -1) + self.hset_bounds[:, :2] = torch.tile( + torch.tensor( + [-self.cone_fov / 2, self.cone_fov / 2], dtype=torch.float + ), + (2, 1), + ).T # bounds for halo positions are known + self.hset_bounds[:, self.log_slice] = torch.log10( + self.hset_bounds[:, 2:-1] + ) # take log of non-position dimensions except z + elif self.mdef == "NFW": + self.log_slice = slice(2) + self.hset_bounds[:, 2:4] = torch.tile( + torch.tensor( + [-self.cone_fov / 2, self.cone_fov / 2], dtype=torch.float + ), + (2, 1), + ).T # bounds for halo positions are known + self.hset_bounds[:, self.log_slice] = torch.log10( + self.hset_bounds[:, :2] + ) # take log of non-position dimensions except z + + self.reverse_alpha_norm = self.reverse_alpha_from_stats + + if self.segment == "FG": + self.forward_norm = self.forward_from_stats_FG + self.reverse_norm = self.reverse_from_stats_FG + self.forward_x_norm = self.forward_x_FG + self.reverse_x_norm = self.reverse_x_FG + + elif self.segment == "BG" or self.segment == "FULL": + self.a_LP_bound = stats_file["base"]["a_LP_bound"][()] # type: ignore + self.beta_LP_bound = stats_file["base"]["beta_LP_bound"][()] # type: ignore + + self.forward_norm = self.forward_from_stats_BG + self.reverse_norm = self.reverse_from_stats_BG + self.forward_x_norm = self.forward_x_BG + self.reverse_x_norm = self.reverse_x_BG + except FileNotFoundError: + raise FileNotFoundError(f"stats file {stats_path} does not exist") + + elif self.normalize == None: + self.forward_norm = self.forward_null + self.reverse_norm = self.reverse_null + self.reverse_alpha_norm = self.reverse_alpha_null + self.forward_x_norm = self.forward_x_null + self.reverse_x_norm = self.reverse_x_null + + def forward(self, *args, **kwargs): + return self.forward_norm(*args, **kwargs) + + def reverse(self, *args, **kwargs): + return self.reverse_norm(*args, **kwargs) + + def reverse_alpha(self, a, plane_ids=None): + return self.reverse_alpha_norm(a, plane_ids) + + def forward_x(self, x): + return self.forward_x_norm(x) + + def reverse_x(self, x): + return self.reverse_x_norm(x) + + def reverse_alpha_from_stats(self, a, plane_ids): + return a * self.z_norm[:, plane_ids].to(a.device) + + def forward_x_FG(self, x): + x_scaled = self._min_max_scale( + x, bounds=(-self.cone_fov / 2, self.cone_fov / 2) + ) + return x_scaled + + def reverse_x_FG(self, x_scaled): + x = self._min_max_unscale( + x_scaled, bounds=(-self.cone_fov / 2, self.cone_fov / 2) + ) + return x + + def forward_x_BG(self, a_LP, beta_LP): + a_LP_scaled = self._min_max_scale(a_LP, bounds=(-self.a_LP_bound, self.a_LP_bound)) # type: ignore + beta_LP_scaled = self._min_max_scale(beta_LP, bounds=(-self.beta_LP_bound, self.beta_LP_bound)) # type: ignore + + return a_LP_scaled, beta_LP_scaled + + def reverse_x_BG(self, a_LP_scaled, beta_LP_scaled): + a_LP = self._min_max_unscale(a_LP_scaled, bounds=(-self.a_LP_bound, self.a_LP_bound)) # type: ignore + beta_LP = self._min_max_unscale(beta_LP_scaled, bounds=(-self.beta_LP_bound, 
self.beta_LP_bound)) # type: ignore + + return a_LP, beta_LP + + def forward_from_stats_FG(self, hset, a=None): + hset_scaled = torch.clone(hset) + hset_scaled[..., self.log_slice] = torch.log10(hset_scaled[..., self.log_slice]) + hset_scaled[..., :-1] = self._min_max_scale( + hset_scaled[..., :-1], bounds=self.hset_bounds[:, :-1] + ) + if a is not None: + a_scaled = self._safe_divide_z_norm(a, self.z_norm) + else: + a_scaled = None + return hset_scaled, a_scaled + + def reverse_from_stats_FG(self, hset_scaled, a_scaled=None): + hset = torch.clone(hset_scaled) + hset[..., :-1] = self._min_max_unscale( + hset[..., :-1], bounds=self.hset_bounds[:, :-1] + ) + hset[..., self.log_slice] = 10 ** hset[..., self.log_slice] + if a_scaled is not None: + a = self.z_norm * a_scaled + else: + a = None + return hset, a + + def forward_from_stats_BG(self, hset, a_LP, beta_LP, a=None): + hset_scaled = torch.clone(hset) + hset_scaled[..., self.log_slice] = torch.log10(hset_scaled[..., self.log_slice]) + hset_scaled[..., :-1] = self._min_max_scale( + hset_scaled[..., :-1], bounds=self.hset_bounds[:, :-1] + ) + + a_LP_scaled = self._min_max_scale(a_LP, bounds=(-self.a_LP_bound, self.a_LP_bound)) # type: ignore + beta_LP_scaled = self._min_max_scale(beta_LP, bounds=(-self.beta_LP_bound, self.beta_LP_bound)) # type: ignore + + if a is not None: + a_scaled = self._safe_divide_z_norm(a, self.z_norm) + else: + a_scaled = None + return hset_scaled, a_LP_scaled, beta_LP_scaled, a_scaled + + def reverse_from_stats_BG( + self, hset_scaled, a_LP_scaled, beta_LP_scaled, a_scaled=None + ): + hset = torch.clone(hset_scaled) + hset[..., :-1] = self._min_max_unscale( + hset[..., :-1], bounds=self.hset_bounds[:, :-1] + ) + hset[..., self.log_slice] = 10 ** hset[..., self.log_slice] + + a_LP = self._min_max_unscale(a_LP_scaled, bounds=(-self.a_LP_bound, self.a_LP_bound)) # type: ignore + beta_LP = self._min_max_unscale(beta_LP_scaled, bounds=(-self.beta_LP_bound, self.beta_LP_bound)) # type: ignore + + if a_scaled is not None: + a = self.z_norm * a_scaled + else: + a = None + return hset, a_LP, beta_LP, a + + def forward_null(self, *args): + return args + + def reverse_null(self, *args): + return args + + def reverse_alpha_null(self, a, *args): + return a + + def forward_x_null(self, x): + return x + + def reverse_x_null(self, x): + return x + + @staticmethod + def _min_max_scale(arr, bounds, vrange=(-1, 1)): + return vrange[0] + (arr - bounds[0]) * (vrange[1] - vrange[0]) / ( + bounds[1] - bounds[0] + ) + + @staticmethod + def _min_max_unscale(arr, bounds, vrange=(-1, 1)): + return (arr - vrange[0]) * (bounds[1] - bounds[0]) / ( + vrange[1] - vrange[0] + ) + bounds[0] + + @staticmethod + def _safe_divide_z_norm(num, denom): + out = torch.zeros_like(num, dtype=torch.float) + out[:, torch.squeeze(denom) != 0, ...] = ( + num[:, torch.squeeze(denom) != 0, ...] + / denom[:, torch.squeeze(denom) != 0, ...] 
+ ) + return out diff --git a/src/caustic/lenses/alphagen/stblocks.py b/src/caustic/lenses/alphagen/stblocks.py new file mode 100644 index 00000000..df945f70 --- /dev/null +++ b/src/caustic/lenses/alphagen/stblocks.py @@ -0,0 +1,91 @@ +import math + +import torch +import torch.nn as nn +import torch.nn.functional as F + + +class MAB(nn.Module): + """ + Multi-Head Attention Block + """ + + def __init__(self, dim_Q, dim_K, dim_V, num_heads, ln=False): + super().__init__() + self.dim_Q = dim_Q + self.dim_V = dim_V + self.num_heads = num_heads + self.fc_q = nn.Linear(dim_Q, dim_V) + self.fc_k = nn.Linear(dim_K, dim_V) + self.fc_v = nn.Linear(dim_K, dim_V) + if ln: + self.ln0 = nn.LayerNorm(dim_V) + self.ln1 = nn.LayerNorm(dim_V) + self.fc_o = nn.Linear(dim_V, dim_V) + + def forward(self, Q, K): + bigQ = self.fc_q(Q) + bigK = self.fc_k(K) + bigV = self.fc_v(K) + + dim_split = self.dim_V // self.num_heads + Q_ = torch.cat(bigQ.split(dim_split, -1), 0) + K_ = torch.cat(bigK.split(dim_split, -1), 0) + V_ = torch.cat(bigV.split(dim_split, -1), 0) + + A = torch.softmax( + Q_.bmm(K_.transpose(1, 2)) / math.sqrt(self.dim_Q), -1 + ) # softmax on last dim of QK^T product + + O = torch.cat((Q_ + A.bmm(V_)).split(Q.size(0), 0), 2) + O = O if getattr(self, "ln0", None) is None else self.ln0(O) + O = O + F.relu(self.fc_o(O)) + MAB_out = O if getattr(self, "ln1", None) is None else self.ln1(O) + return MAB_out + + +class SAB(nn.Module): + """ + Set Attention Block + + SAB(X) = MAB(X,X) + """ + + def __init__(self, in_dim, out_dim, num_heads, ln=False): + super().__init__() + self.mab = MAB(in_dim, in_dim, out_dim, num_heads, ln=ln) + + def forward(self, X): + return self.mab(X, X) + + +class ISAB(nn.Module): + """ + Induced Set Attention Block + """ + + def __init__(self, in_dim, out_dim, num_heads, num_inds, ln=False): + super().__init__() + self.I = nn.Parameter(torch.Tensor(1, num_inds, out_dim)) + nn.init.xavier_uniform_(self.I) + self.mab0 = MAB(out_dim, in_dim, out_dim, num_heads, ln=ln) + self.mab1 = MAB(in_dim, out_dim, out_dim, num_heads, ln=ln) + + def forward(self, X): + H = self.mab0(self.I.repeat(X.size(0), 1, 1), X) + return self.mab1(X, H) + + +class PMA(nn.Module): + """ + Pooling by Multi-Head Attention + """ + + def __init__(self, dim, num_heads, num_seeds, ln=False): + super().__init__() + self.S = nn.Parameter(torch.Tensor(1, num_seeds, dim)) + nn.init.xavier_uniform_(self.S) + self.mab = MAB(dim, dim, dim, num_heads, ln=ln) + + def forward(self, X): + return self.mab(self.S.repeat(X.size(0), 1, 1), X) From b3b5f53a59aaa0481e6c1a3ea6c9c25bd7ab503c Mon Sep 17 00:00:00 2001 From: Adam Coogan Date: Fri, 14 Apr 2023 10:03:19 -0400 Subject: [PATCH 05/11] update readme --- README.md | 9 ++++++--- 1 file changed, 6 insertions(+), 3 deletions(-) diff --git a/README.md b/README.md index 8389c21f..42b94f06 100644 --- a/README.md +++ b/README.md @@ -31,6 +31,9 @@ pip install -e ".[dev]" ``` This creates an editable install and installs the dev dependencies. -Please use `isort` and `black` to format your code. Open up issues for bugs/missing -features. Use pull requests for additions to the code. Write tests that can be run -by [`pytest`](https://docs.pytest.org/). +Some guidelines: +- Please use `isort` and `black` to format your code. +- Use `CamelCase` for class names and `snake_case` for variable and method names. +- Open up issues for bugs/missing features. +- Use pull requests for additions to the code. +- Write tests that can be run by [`pytest`](https://docs.pytest.org/). 
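A brief usage sketch of the alphagen set transformer introduced in PATCH 04 above (the module is later relocated in PATCH 06 and dropped in PATCH 11). Only the classes defined in models.py are used; the hyperparameters and the 6-feature halo rows (the PJAFFE parametrization listed in normalizer.py) are illustrative assumptions, not values used anywhere in the repository:

import torch

from caustic.lenses.alphagen import AlphaSetTransformerISAB  # import path as of PATCH 04

# Encoder over a set of halos; each halo is a 6-vector, e.g. the PJAFFE
# parameters (center_x, center_y, Ra, Rs, sigma0, z) from normalizer.py.
model = AlphaSetTransformerISAB(
    in_dim=6,
    hidden_dims_enc=[64, 64],   # two ISAB encoder blocks
    hidden_dims_dec=[64],       # one SAB decoder block after PMA pooling
    out_dim=32,                 # embedding size, as consumed by JointModel_ST_CNN
    num_heads=4,                # must divide the hidden dimensions
    num_inds=16,                # number of inducing points per ISAB
)

halo_sets = torch.randn(8, 100, 6)  # (batch, set size, features per halo)
embedding = model(halo_sets)        # PMA pools the set -> shape (8, 32)
print(embedding.shape)

The PMA block reduces a variable-length halo set to a fixed-size embedding, which JointModel_ST_CNN then tiles over the npix x npix grid and concatenates with the coordinate channels before the CNN.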
From 31d76cb8758135cd06c0ca3e035eeac495a9a234 Mon Sep 17 00:00:00 2001 From: AlexandreAdam Date: Fri, 14 Jul 2023 10:45:16 -0400 Subject: [PATCH 06/11] Moved alphagen --- .../lenses/alphagen/__init__.py | 0 .../lenses/alphagen/models.py | 0 .../lenses/alphagen/normalizer.py | 0 .../lenses/alphagen/stblocks.py | 0 src/caustic/cosmology.py | 141 --------------- src/caustic/lenses/base.py | 171 ------------------ 6 files changed, 312 deletions(-) rename {src/caustic => caustic}/lenses/alphagen/__init__.py (100%) rename {src/caustic => caustic}/lenses/alphagen/models.py (100%) rename {src/caustic => caustic}/lenses/alphagen/normalizer.py (100%) rename {src/caustic => caustic}/lenses/alphagen/stblocks.py (100%) delete mode 100644 src/caustic/cosmology.py delete mode 100644 src/caustic/lenses/base.py diff --git a/src/caustic/lenses/alphagen/__init__.py b/caustic/lenses/alphagen/__init__.py similarity index 100% rename from src/caustic/lenses/alphagen/__init__.py rename to caustic/lenses/alphagen/__init__.py diff --git a/src/caustic/lenses/alphagen/models.py b/caustic/lenses/alphagen/models.py similarity index 100% rename from src/caustic/lenses/alphagen/models.py rename to caustic/lenses/alphagen/models.py diff --git a/src/caustic/lenses/alphagen/normalizer.py b/caustic/lenses/alphagen/normalizer.py similarity index 100% rename from src/caustic/lenses/alphagen/normalizer.py rename to caustic/lenses/alphagen/normalizer.py diff --git a/src/caustic/lenses/alphagen/stblocks.py b/caustic/lenses/alphagen/stblocks.py similarity index 100% rename from src/caustic/lenses/alphagen/stblocks.py rename to caustic/lenses/alphagen/stblocks.py diff --git a/src/caustic/cosmology.py b/src/caustic/cosmology.py deleted file mode 100644 index bfe176d5..00000000 --- a/src/caustic/cosmology.py +++ /dev/null @@ -1,141 +0,0 @@ -from abc import abstractmethod -from math import pi -from typing import Any, Optional - -import torch -from astropy.cosmology import default_cosmology -from scipy.special import hyp2f1 -from torch import Tensor -from torchinterp1d import interp1d - -from .constants import G_over_c2, c_Mpc_s, km_to_Mpc -from .parametrized import Parametrized - -__all__ = ( - "h0_default", - "rho_cr_0_default", - "Om0_default", - "Cosmology", - "FlatLambdaCDM", -) - -h0_default = float(default_cosmology.get().h) -rho_cr_0_default = float( - default_cosmology.get().critical_density(0).to("solMass/Mpc^3").value -) -Om0_default = float(default_cosmology.get().Om0) - -# Set up interpolator to speed up comoving distance calculations in Lambda-CDM -# cosmologies. Construct with float64 precision. -_comoving_dist_helper_x_grid = 10 ** torch.linspace(-3, 1, 500, dtype=torch.float64) -_comoving_dist_helper_y_grid = torch.as_tensor( - _comoving_dist_helper_x_grid - * hyp2f1(1 / 3, 1 / 2, 4 / 3, -(_comoving_dist_helper_x_grid**3)), - dtype=torch.float64, -) - - -class Cosmology(Parametrized): - """ - Units: - - Distance: Mpc - - Mass: solMass - """ - - def __init__(self, name: str): - super().__init__(name) - - @abstractmethod - def rho_cr(self, z: Tensor, x: Optional[dict[str, Any]] = None) -> Tensor: - ... - - @abstractmethod - def comoving_dist(self, z: Tensor, x: Optional[dict[str, Any]] = None) -> Tensor: - ... 
- - def comoving_dist_z1z2( - self, z1: Tensor, z2: Tensor, x: Optional[dict[str, Any]] = None - ) -> Tensor: - return self.comoving_dist(z2, x) - self.comoving_dist(z1, x) - - def angular_diameter_dist( - self, z: Tensor, x: Optional[dict[str, Any]] = None - ) -> Tensor: - return self.comoving_dist(z, x) / (1 + z) - - def angular_diameter_dist_z1z2( - self, z1: Tensor, z2: Tensor, x: Optional[dict[str, Any]] = None - ) -> Tensor: - return self.comoving_dist_z1z2(z1, z2, x) / (1 + z2) - - def time_delay_dist( - self, z_l: Tensor, z_s: Tensor, x: Optional[dict[str, Any]] = None - ) -> Tensor: - d_l = self.angular_diameter_dist(z_l, x) - d_s = self.angular_diameter_dist(z_s, x) - d_ls = self.angular_diameter_dist_z1z2(z_l, z_s, x) - return (1 + z_l) * d_l * d_s / d_ls - - def Sigma_cr( - self, z_l: Tensor, z_s: Tensor, x: Optional[dict[str, Any]] = None - ) -> Tensor: - d_l = self.angular_diameter_dist(z_l, x) - d_s = self.angular_diameter_dist(z_s, x) - d_ls = self.angular_diameter_dist_z1z2(z_l, z_s, x) - return d_s / d_l / d_ls / (4 * pi * G_over_c2) - - -class FlatLambdaCDM(Cosmology): - """ - - Flat LCDM cosmology with no radiation. - """ - - def __init__( - self, - name: str, - h0: Optional[Tensor] = torch.tensor(h0_default), - rho_cr_0: Optional[Tensor] = torch.tensor(rho_cr_0_default), - Om0: Optional[Tensor] = torch.tensor(Om0_default), - ): - super().__init__(name) - - self.add_param("h0", h0) - self.add_param("rho_cr_0", rho_cr_0) - self.add_param("Om0", Om0) - - self._comoving_dist_helper_x_grid = _comoving_dist_helper_x_grid.to( - dtype=torch.float32 - ) - self._comoving_dist_helper_y_grid = _comoving_dist_helper_y_grid.to( - dtype=torch.float32 - ) - - def dist_hubble(self, h0): - return c_Mpc_s / (100 * km_to_Mpc) / h0 - - def rho_cr(self, z: Tensor, x: Optional[dict[str, Any]] = None) -> torch.Tensor: - _, rho_cr_0, Om0 = self.unpack(x) - Ode0 = 1 - Om0 - return rho_cr_0 * (Om0 * (1 + z) ** 3 + Ode0) - - def _comoving_dist_helper(self, x: Tensor) -> Tensor: - return interp1d( - self._comoving_dist_helper_x_grid, - self._comoving_dist_helper_y_grid, - torch.atleast_1d(x), - ).reshape(x.shape) - - def comoving_dist(self, z: Tensor, x: Optional[dict[str, Any]] = None) -> Tensor: - h0, _, Om0 = self.unpack(x) - - Ode0 = 1 - Om0 - ratio = (Om0 / Ode0) ** (1 / 3) - return ( - self.dist_hubble(h0) - * ( - self._comoving_dist_helper((1 + z) * ratio) - - self._comoving_dist_helper(ratio) - ) - / (Om0 ** (1 / 3) * Ode0 ** (1 / 6)) - ) diff --git a/src/caustic/lenses/base.py b/src/caustic/lenses/base.py deleted file mode 100644 index b7a3eb09..00000000 --- a/src/caustic/lenses/base.py +++ /dev/null @@ -1,171 +0,0 @@ -from abc import abstractmethod -from typing import Any, Optional - -import torch -from torch import Tensor - -from ..constants import arcsec_to_rad, c_Mpc_s -from ..cosmology import Cosmology -from ..parametrized import Parametrized -from .utils import get_magnification - -__all__ = ("ThinLens", "ThickLens") - - -class ThickLens(Parametrized): - """ - Base class for lenses that can't be treated in the thin lens approximation. - """ - - def __init__(self, name: str, cosmology: Cosmology): - super().__init__(name) - self.cosmology = cosmology - - @abstractmethod - def alpha( - self, thx: Tensor, thy: Tensor, z_s: Tensor, x: Optional[dict[str, Any]] = None - ) -> tuple[Tensor, Tensor]: - """ - Reduced deflection angle [arcsec] - """ - ... 
- - def raytrace( - self, thx: Tensor, thy: Tensor, z_s: Tensor, x: Optional[dict[str, Any]] = None - ) -> tuple[Tensor, Tensor]: - ax, ay = self.alpha(thx, thy, z_s, x) - return thx - ax, thy - ay - - @abstractmethod - def Sigma( - self, thx: Tensor, thy: Tensor, z_s: Tensor, x: Optional[dict[str, Any]] = None - ) -> Tensor: - """ - Projected mass density. - - Returns: - [solMass / Mpc^2] - """ - ... - - @abstractmethod - def time_delay( - self, thx: Tensor, thy: Tensor, z_s: Tensor, x: Optional[dict[str, Any]] = None - ) -> Tensor: - ... - - def magnification(self, thx: Tensor, thy: Tensor, z_s: Tensor, x) -> Tensor: - return get_magnification(self.raytrace, thx, thy, z_s, x) - - -class ThinLens(Parametrized): - """ - Base class for lenses that can be treated in the thin lens approximation. - """ - - def __init__(self, name: str, cosmology: Cosmology, z_l: Optional[Tensor] = None): - super().__init__(name) - self.cosmology = cosmology - self.add_param("z_l", z_l) - - @abstractmethod - def alpha( - self, thx: Tensor, thy: Tensor, z_s: Tensor, x: Optional[dict[str, Any]] = None - ) -> tuple[Tensor, Tensor]: - """ - Reduced deflection angle [arcsec] - """ - ... - - def alpha_hat( - self, thx: Tensor, thy: Tensor, z_s: Tensor, x: Optional[dict[str, Any]] = None - ) -> tuple[Tensor, Tensor]: - """ - Physical deflection angle immediately after passing through this lens' - plane [arcsec]. - """ - z_l = self.unpack(x)[0] - - d_s = self.cosmology.angular_diameter_dist(z_s, x) - d_ls = self.cosmology.angular_diameter_dist_z1z2(z_l, z_s, x) - alpha_x, alpha_y = self.alpha(thx, thy, z_s, x) - return (d_s / d_ls) * alpha_x, (d_s / d_ls) * alpha_y - - @abstractmethod - def kappa( - self, thx: Tensor, thy: Tensor, z_s: Tensor, x: Optional[dict[str, Any]] = None - ) -> Tensor: - """ - Convergence [1] - """ - ... - - @abstractmethod - def Psi( - self, thx: Tensor, thy: Tensor, z_s: Tensor, x: Optional[dict[str, Any]] = None - ) -> Tensor: - """ - Potential [arcsec^2] - """ - ... - - def Sigma( - self, thx: Tensor, thy: Tensor, z_s: Tensor, x: Optional[dict[str, Any]] = None - ) -> Tensor: - """ - Surface mass density. - - Returns: - [solMass / Mpc^2] - """ - # Superclass params come before subclass ones - z_l = self.unpack(x)[0] - - Sigma_cr = self.cosmology.Sigma_cr(z_l, z_s, x) - return self.kappa(thx, thy, z_s, x) * Sigma_cr - - def raytrace( - self, thx: Tensor, thy: Tensor, z_s: Tensor, x: Optional[dict[str, Any]] = None - ) -> tuple[Tensor, Tensor]: - ax, ay = self.alpha(thx, thy, z_s, x) - return thx - ax, thy - ay - - def fermat_potential( - self, thx: Tensor, thy: Tensor, z_s: Tensor, x: Optional[dict[str, Any]] = None - ) -> Tensor: - ax, ay = self.alpha(thx, thy, z_s, x) - Psi = self.Psi(thx, thy, z_s, x) - return 0.5 * (ax**2 + ay**2) - Psi - - def time_delay( - self, thx: Tensor, thy: Tensor, z_s: Tensor, x: Optional[dict[str, Any]] = None - ): - z_l = self.unpack(x)[0] - d_td = self.cosmology.time_delay_dist(z_l, z_s, x) - phi = self.fermat_potential(thx, thy, z_s, x) - return d_td * phi * arcsec_to_rad**2 / c_Mpc_s - - def _lensing_jacobian_fft_method( - self, thx: Tensor, thy: Tensor, z_s: Tensor, x: Optional[dict[str, Any]] = None - ) -> Tensor: - psi = self.Psi(thx, thy, z_s, x) - # quick dirty work to get kx and ky. Assumes thx and thy come from meshgrid... 
TODO Might want to get k differently
-        n = thx.shape[-1]
-        d = torch.abs(thx[0, 0] - thx[0, 1])
-        k = torch.fft.fftfreq(2 * n, d=d)
-        kx, ky = torch.meshgrid([k, k], indexing="xy")
-        # Now we compute second derivatives in Fourier space, then inverse Fourier transform and unpad
-        pad = 2 * n
-        psi_tilde = torch.fft.fft(psi, (pad, pad))
-        psi_xx = torch.abs(torch.fft.ifft2(-(kx**2) * psi_tilde))[..., :n, :n]
-        psi_yy = torch.abs(torch.fft.ifft2(-(ky**2) * psi_tilde))[..., :n, :n]
-        psi_xy = torch.abs(torch.fft.ifft2(-kx * ky * psi_tilde))[..., :n, :n]
-        j1 = torch.stack(
-            [1 - psi_xx, -psi_xy], dim=-1
-        ) # Equation 2.33 from Meneghetti lensing lectures
-        j2 = torch.stack([-psi_xy, 1 - psi_yy], dim=-1)
-        jacobian = torch.stack([j1, j2], dim=-1)
-        return jacobian
-
-    def magnification(self, thx: Tensor, thy: Tensor, z_s: Tensor, x) -> Tensor:
-        return get_magnification(self.raytrace, thx, thy, z_s, x)

From 01a19596c027081c5578617bb5df61da0b7f809e Mon Sep 17 00:00:00 2001
From: Connor Stone
Date: Fri, 20 Oct 2023 15:10:18 -0400
Subject: [PATCH 07/11] add docs status badge

---
 README.md | 1 +
 1 file changed, 1 insertion(+)

diff --git a/README.md b/README.md
index 1b42356b..12526234 100644
--- a/README.md
+++ b/README.md
@@ -1,4 +1,5 @@
 [![tests](https://github.com/Ciela-Institute/caustic/actions/workflows/python-app.yml/badge.svg?branch=main)](https://github.com/Ciela-Institute/caustic/actions)
+[![Docs](https://github.com/Ciela-Institute/caustic/actions/workflows/documentation.yaml/badge.svg)](https://github.com/Ciela-Institute/caustic/actions/workflows/documentation.yaml)
 [![PyPI version](https://badge.fury.io/py/caustic.svg)](https://pypi.org/project/caustic/)
 [![coverage](https://img.shields.io/codecov/c/github/Ciela-Institute/caustic)](https://app.codecov.io/gh/Ciela-Institute/caustic)
 # caustic

From fe06e1e7dd27db6232726abad6bbdbcaeaf48d17 Mon Sep 17 00:00:00 2001
From: Connor Stone
Date: Fri, 20 Oct 2023 17:44:47 -0400
Subject: [PATCH 08/11] forward raytrace now tests that points map back to
 correct source position to within 1e-4 arcsec

---
 caustic/lenses/base.py | 10 ++++++----
 1 file changed, 6 insertions(+), 4 deletions(-)

diff --git a/caustic/lenses/base.py b/caustic/lenses/base.py
index 5bb15139..a0d2ddea 100644
--- a/caustic/lenses/base.py
+++ b/caustic/lenses/base.py
@@ -428,7 +428,7 @@ def raytrace(

    @unpack(3)
    def forward_raytrace(
-        self, bx: Tensor, by: Tensor, z_s: Tensor, *args, params: Optional["Packed"] = None, epsilon = 1e-2, n_init = 50, fov = 5., **kwargs
+        self, bx: Tensor, by: Tensor, z_s: Tensor, *args, params: Optional["Packed"] = None, epsilon = 1e-4, n_init = 100, fov = 5., **kwargs
    ) -> tuple[Tensor, Tensor]:
        """
        Perform a forward ray-tracing operation which maps from the source plane to the image plane.

        Args:
            bx (Tensor): Tensor of x coordinate in the source plane (scalar).
            by (Tensor): Tensor of y coordinate in the source plane (scalar).
            z_s (Tensor): Tensor of source redshifts.
            params (Packed, optional): Dynamic parameter container for the lens model. Defaults to None.
-            epsilon (Tensor): maximum distance between two images (arcsec) before they are considered the same image.
+            epsilon (Tensor): maximum distance between two images (arcsec) before they are considered the same image. Also used as a constraint to reject failed optimizations; after fitting in the image plane, points are raytraced back to the source plane and must fall within a radius of epsilon of the input bx, by to be accepted.
            n_init (int): number of random initialization points used to try and find image plane points.
fov (float): the field of view in which the initial random samples are taken. @@ -463,8 +463,10 @@ def forward_raytrace( f_args = (z_s, params) ) - # Clip points that didn't converge - x = x[c < 1e-2*epsilon**2] + # Clip points that didn't converge to the source point + bx_fit, by_fit = self.raytrace(x[...,0], x[...,1], z_s, params) + R = torch.sqrt((bx_fit - bx)**2 + (by_fit - by)**2) + x = x[R < epsilon] # Cluster results into n-images res = [] From 778489d28119a258b7287d67059b4c059336c706 Mon Sep 17 00:00:00 2001 From: AlexandreAdam Date: Fri, 3 Nov 2023 12:33:09 -0400 Subject: [PATCH 09/11] Added and tested check for valid names --- caustic/parametrized.py | 10 ++++++++++ test/test_jacobian_lens_equation.py | 6 +++--- test/test_multiplane.py | 2 +- test/test_parametrized.py | 22 +++++++++++++++++----- 4 files changed, 31 insertions(+), 9 deletions(-) diff --git a/caustic/parametrized.py b/caustic/parametrized.py index 4c91e1f5..369858dc 100644 --- a/caustic/parametrized.py +++ b/caustic/parametrized.py @@ -6,6 +6,7 @@ import torch import re +import keyword from torch import Tensor from .packed import Packed @@ -14,6 +15,13 @@ __all__ = ("Parametrized","unpack") + +def check_valid_name(name): + if keyword.iskeyword(name) or not bool(re.match("^[a-zA-Z_][a-zA-Z0-9_]*$", name)): + raise NameError(f"The string {name} contain illegal characters (like space or '-'). "\ + "Please use snake case or another valid python variable anming style.") + + class Parametrized: """ Represents a class with Param and Parametrized attributes, typically used to construct parts of a simulator @@ -40,6 +48,7 @@ class Parametrized: def __init__(self, name: str = None): if name is None: name = self._default_name() + check_valid_name(name) if not isinstance(name, str): raise ValueError(f"name must be a string (received {name})") self._name = name @@ -87,6 +96,7 @@ def name(self) -> str: @name.setter def name(self, new_name: str): + check_valid_name(new_name) old_name = self.name for parent in self._parents.values(): del parent._childs[old_name] diff --git a/test/test_jacobian_lens_equation.py b/test/test_jacobian_lens_equation.py index bfdf6474..88066581 100644 --- a/test/test_jacobian_lens_equation.py +++ b/test/test_jacobian_lens_equation.py @@ -41,7 +41,7 @@ def test_multiplane_jacobian(): x = torch.tensor([p for _xs in xs for p in _xs], dtype=torch.float32) lens = Multiplane( - name="multiplane", cosmology=cosmology, lenses=[SIE(name=f"sie-{i}", cosmology=cosmology) for i in range(len(xs))] + name="multiplane", cosmology=cosmology, lenses=[SIE(name=f"sie_{i}", cosmology=cosmology) for i in range(len(xs))] ) thx, thy = get_meshgrid(0.1, 10, 10) @@ -66,7 +66,7 @@ def test_multiplane_jacobian_autograd_vs_finitediff(): x = torch.tensor([p for _xs in xs for p in _xs], dtype=torch.float32) lens = Multiplane( - name="multiplane", cosmology=cosmology, lenses=[SIE(name=f"sie-{i}", cosmology=cosmology) for i in range(len(xs))] + name="multiplane", cosmology=cosmology, lenses=[SIE(name=f"sie_{i}", cosmology=cosmology) for i in range(len(xs))] ) thx, thy = get_meshgrid(0.01, 10, 10) @@ -96,7 +96,7 @@ def test_multiplane_effective_convergence(): x = torch.tensor([p for _xs in xs for p in _xs], dtype=torch.float32) lens = Multiplane( - name="multiplane", cosmology=cosmology, lenses=[SIE(name=f"sie-{i}", cosmology=cosmology) for i in range(len(xs))] + name="multiplane", cosmology=cosmology, lenses=[SIE(name=f"sie_{i}", cosmology=cosmology) for i in range(len(xs))] ) thx, thy = get_meshgrid(0.1, 10, 10) diff --git 
a/test/test_multiplane.py b/test/test_multiplane.py
index 9cdacc3f..130bf00b 100644
--- a/test/test_multiplane.py
+++ b/test/test_multiplane.py
@@ -30,7 +30,7 @@ def test():
     x = torch.tensor([p for _xs in xs for p in _xs], dtype=torch.float32)

     lens = Multiplane(
-        name="multiplane", cosmology=cosmology, lenses=[SIE(name=f"sie-{i}", cosmology=cosmology) for i in range(len(xs))]
+        name="multiplane", cosmology=cosmology, lenses=[SIE(name=f"sie_{i}", cosmology=cosmology) for i in range(len(xs))]
     )
     #lens.effective_reduced_deflection_angle = lens.raytrace

diff --git a/test/test_parametrized.py b/test/test_parametrized.py
index 19629d4d..e6dc0c86 100644
--- a/test/test_parametrized.py
+++ b/test/test_parametrized.py
@@ -1,5 +1,6 @@
 import torch
 from torch import vmap
+import pytest
 import numpy as np
 from caustic.sims import Simulator
 from caustic.parameter import Parameter
@@ -103,11 +104,22 @@ def __init__(self):
     assert sim.name == "Test"

     # Check that DAG in SIM is being updated
-    sim.lens.name = "Test Lens"
-    assert sim.lens.name == "Test Lens"
-    assert "Test Lens" in sim.params.dynamic.keys()
-    assert "Test Lens" in sim.cosmo._parents.keys()
-
+    sim.lens.name = "test_lens"
+    assert sim.lens.name == "test_lens"
+    assert "test_lens" in sim.params.dynamic.keys()
+    assert "test_lens" in sim.cosmo._parents.keys()
+
+
+def test_parametrized_name_setter_bad_names():
+    # Make sure bad names are caught by our added method. Bad names are names which cannot be used as class attributes.
+    good_names = ["variable", "_variable", "var_iable2"]
+    for name in good_names:
+        module = Sersic(name=name)
+    bad_names = ["for", "2variable", "variable!", "var-iable", "var iable", "def"]
+    for name in bad_names:
+        print(name)
+        with pytest.raises(NameError):
+            module = Sersic(name=name)

 def test_parametrized_name_collision():
     # Case 1: Name collision in children of simulator

From d6ff524e4ab9c658c7c0b03bbb564e8d69d14100 Mon Sep 17 00:00:00 2001
From: AlexandreAdam
Date: Fri, 3 Nov 2023 15:28:28 -0400
Subject: [PATCH 10/11] Fixed typo

---
 caustic/parametrized.py | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/caustic/parametrized.py b/caustic/parametrized.py
index 369858dc..d0940c6a 100644
--- a/caustic/parametrized.py
+++ b/caustic/parametrized.py
@@ -18,8 +18,8 @@
 def check_valid_name(name):
     if keyword.iskeyword(name) or not bool(re.match("^[a-zA-Z_][a-zA-Z0-9_]*$", name)):
-        raise NameError(f"The string {name} contain illegal characters (like space or '-'). 
"\ + "Please use snake case or another valid python variable naming style.") class Parametrized: From 22ad60e4d7d9c27298a34e886ef615a02544f4c2 Mon Sep 17 00:00:00 2001 From: Connor Stone Date: Fri, 3 Nov 2023 15:49:44 -0400 Subject: [PATCH 11/11] remove alphagen, no tests --- caustic/lenses/alphagen/__init__.py | 3 - caustic/lenses/alphagen/models.py | 193 ----------------------- caustic/lenses/alphagen/normalizer.py | 219 -------------------------- caustic/lenses/alphagen/stblocks.py | 91 ----------- 4 files changed, 506 deletions(-) delete mode 100644 caustic/lenses/alphagen/__init__.py delete mode 100644 caustic/lenses/alphagen/models.py delete mode 100644 caustic/lenses/alphagen/normalizer.py delete mode 100644 caustic/lenses/alphagen/stblocks.py diff --git a/caustic/lenses/alphagen/__init__.py b/caustic/lenses/alphagen/__init__.py deleted file mode 100644 index 9a15f89f..00000000 --- a/caustic/lenses/alphagen/__init__.py +++ /dev/null @@ -1,3 +0,0 @@ -from .models import * -from .normalizer import * -from .stblocks import * diff --git a/caustic/lenses/alphagen/models.py b/caustic/lenses/alphagen/models.py deleted file mode 100644 index 5a6284b7..00000000 --- a/caustic/lenses/alphagen/models.py +++ /dev/null @@ -1,193 +0,0 @@ -import torch -from torch import nn - -from .stblocks import ISAB, PMA, SAB - - -class AlphaSetTransformerISAB(nn.Module): - def __init__( - self, - in_dim, - hidden_dims_enc, - hidden_dims_dec, - out_dim, - num_heads, - num_inds, - num_out=1, - ln=False, - ): - super().__init__() - - enc_layers = [ISAB(in_dim, hidden_dims_enc[0], num_heads, num_inds, ln=ln)] - for i in range(len(hidden_dims_enc) - 1): - enc_layers.append( - ISAB( - hidden_dims_enc[i], - hidden_dims_enc[i + 1], - num_heads, - num_inds, - ln=ln, - ) - ) - - self.enc = nn.Sequential(*enc_layers) - - dec_layers: list[nn.Module] = [ - PMA(hidden_dims_enc[-1], num_heads, num_out, ln=ln) - ] - hidden_dims_dec.insert(0, hidden_dims_enc[-1]) - for i in range(len(hidden_dims_dec) - 1): - dec_layers.append( - SAB(hidden_dims_dec[i], hidden_dims_dec[i + 1], num_heads, ln=ln) - ) - - dec_layers.append(nn.Linear(hidden_dims_dec[-1], out_dim)) - self.dec = nn.Sequential(*dec_layers) - - def forward(self, x): - enc_out = self.enc(x) - dec_out = self.dec(enc_out).squeeze(dim=1) - - return dec_out - - -class AlphaSetTransformerSAB(nn.Module): - def __init__( - self, - in_dim, - hidden_dims_enc, - hidden_dims_dec, - out_dim, - num_heads, - num_out=1, - ln=False, - ): - super().__init__() - - enc_layers = [SAB(in_dim, hidden_dims_enc[0], num_heads, ln=ln)] - for i in range(len(hidden_dims_enc) - 1): - enc_layers.append( - SAB(hidden_dims_enc[i], hidden_dims_enc[i + 1], num_heads, ln=ln) - ) - - self.enc = nn.Sequential(*enc_layers) - - dec_layers: list[nn.Module] = [ - PMA(hidden_dims_enc[-1], num_heads, num_out, ln=ln) - ] - hidden_dims_dec.insert(0, hidden_dims_enc[-1]) - for i in range(len(hidden_dims_dec) - 1): - dec_layers.append( - SAB(hidden_dims_dec[i], hidden_dims_dec[i + 1], num_heads, ln=ln) - ) - - dec_layers.append(nn.Linear(hidden_dims_dec[-1], out_dim)) - self.dec = nn.Sequential(*dec_layers) - - def forward(self, x): - enc_out = self.enc(x) - dec_out = self.dec(enc_out).squeeze(dim=1) - - return dec_out - - -class ConvNet2D(nn.Module): - def __init__( - self, in_ch, chs, out_ch, activation, final_activation, wrapper_func=None - ): - super().__init__() - - if not wrapper_func: - wrapper_func = lambda x: x - - chs.append(out_ch) - - layers = [ - nn.Conv2d(in_channels=in_ch, out_channels=chs[0], 
kernel_size=1, padding=0) - ] - - for i in range(len(chs) - 1): - layers.append(getattr(nn, activation)()) - layers.append( - wrapper_func( - nn.Conv2d( - in_channels=chs[i], - out_channels=chs[i + 1], - kernel_size=1, - padding=0, - ) - ) - ) - assert layers[-1].bias is not None - layers[-1].bias.data.fill_(0.0) - - if final_activation is not None: - layers.append(getattr(nn, final_activation)()) - - self.net = nn.Sequential(*layers) - - def forward(self, x): - """ - :param x: input Tensor, shape (b, c, npix, npix) - """ - return self.net(x) - - -class JointModel_ST_CNN(nn.Module): - def __init__( - self, - in_dim_SM, - hidden_dims_enc_SM, - hidden_dims_dec_SM, - embedding_size, - num_heads_SM, - layernorm_SM, - in_dim_MM, - hidden_dims_MM, - out_dim_MM, - npix, - normalizer, - activation_MM="ReLU", - final_activation_MM=None, - num_inds_SM=None, - ): - super().__init__() - - self.npix = npix - self.embedding_size = embedding_size - self.normalizer = normalizer - - if num_inds_SM is not None: - self.ST = AlphaSetTransformerISAB( - in_dim_SM, - hidden_dims_enc_SM, - hidden_dims_dec_SM, - embedding_size, - num_heads_SM, - num_inds_SM, - ln=layernorm_SM, - ) - else: - self.ST = AlphaSetTransformerSAB( - in_dim_SM, - hidden_dims_enc_SM, - hidden_dims_dec_SM, - embedding_size, - num_heads_SM, - ln=layernorm_SM, - ) - - self.CNN = ConvNet2D( - in_dim_MM, hidden_dims_MM, out_dim_MM, activation_MM, final_activation_MM - ) - - def forward(self, h, x): - embedding = self.ST(h) - embedding_ = ( - embedding.reshape(-1, self.embedding_size, 1, 1) - .repeat_interleave(self.npix, dim=-2) - .repeat_interleave(self.npix, dim=-1) - ) - cnn_input = torch.cat([x, embedding_], dim=1) - out = self.CNN(cnn_input) - return out diff --git a/caustic/lenses/alphagen/normalizer.py b/caustic/lenses/alphagen/normalizer.py deleted file mode 100644 index 4c159c5d..00000000 --- a/caustic/lenses/alphagen/normalizer.py +++ /dev/null @@ -1,219 +0,0 @@ -import ast - -import h5py -import numpy as np -import torch -from scipy.interpolate import splrep - -__all__ = ("Normalizer",) - - -class Normalizer: - def __init__(self, normalize, mdef=None, segment="FG", stats_path=None): - """ - mdef == "PJAFFE" -> hset_params = ['center_x', 'center_y', 'Ra', 'Rs', 'sigma0', 'z'] - mdef == "NFW" -> hset_params = ["alpha_Rs", "Rs", "center_x", "center_y", "z"] - """ - self.normalize = normalize - self.mdef = mdef - self.segment = segment - if self.normalize == "from_stats": - try: - stats_file = h5py.File(stats_path, mode="r") - self.cone_fov = ast.literal_eval( - stats_file["base"].attrs["dataset_descriptor"] - )["cone_fov"] - - self.z = torch.tensor(stats_file["base"]["z"][:], dtype=torch.float) # type: ignore - self.z_norm = torch.tensor(stats_file["base"]["z_norm"][:], dtype=torch.float).reshape(1, -1, 1, 1) # type: ignore - - self.z_norm_fn = splrep( - np.insert(self.z.squeeze().detach().cpu().numpy(), 0, 0.0), - np.insert(self.z_norm.squeeze().detach().cpu().numpy(), 0, 0.0), - ) - - self.hset_bounds = torch.tensor(stats_file["base"]["hset_bounds"][:], dtype=torch.float) # type: ignore - if self.mdef == "PJAFFE": - self.log_slice = slice(2, -1) - self.hset_bounds[:, :2] = torch.tile( - torch.tensor( - [-self.cone_fov / 2, self.cone_fov / 2], dtype=torch.float - ), - (2, 1), - ).T # bounds for halo positions are known - self.hset_bounds[:, self.log_slice] = torch.log10( - self.hset_bounds[:, 2:-1] - ) # take log of non-position dimensions except z - elif self.mdef == "NFW": - self.log_slice = slice(2) - self.hset_bounds[:, 2:4] = 
torch.tile( - torch.tensor( - [-self.cone_fov / 2, self.cone_fov / 2], dtype=torch.float - ), - (2, 1), - ).T # bounds for halo positions are known - self.hset_bounds[:, self.log_slice] = torch.log10( - self.hset_bounds[:, :2] - ) # take log of non-position dimensions except z - - self.reverse_alpha_norm = self.reverse_alpha_from_stats - - if self.segment == "FG": - self.forward_norm = self.forward_from_stats_FG - self.reverse_norm = self.reverse_from_stats_FG - self.forward_x_norm = self.forward_x_FG - self.reverse_x_norm = self.reverse_x_FG - - elif self.segment == "BG" or self.segment == "FULL": - self.a_LP_bound = stats_file["base"]["a_LP_bound"][()] # type: ignore - self.beta_LP_bound = stats_file["base"]["beta_LP_bound"][()] # type: ignore - - self.forward_norm = self.forward_from_stats_BG - self.reverse_norm = self.reverse_from_stats_BG - self.forward_x_norm = self.forward_x_BG - self.reverse_x_norm = self.reverse_x_BG - except FileNotFoundError: - raise FileNotFoundError(f"stats file {stats_path} does not exist") - - elif self.normalize == None: - self.forward_norm = self.forward_null - self.reverse_norm = self.reverse_null - self.reverse_alpha_norm = self.reverse_alpha_null - self.forward_x_norm = self.forward_x_null - self.reverse_x_norm = self.reverse_x_null - - def forward(self, *args, **kwargs): - return self.forward_norm(*args, **kwargs) - - def reverse(self, *args, **kwargs): - return self.reverse_norm(*args, **kwargs) - - def reverse_alpha(self, a, plane_ids=None): - return self.reverse_alpha_norm(a, plane_ids) - - def forward_x(self, x): - return self.forward_x_norm(x) - - def reverse_x(self, x): - return self.reverse_x_norm(x) - - def reverse_alpha_from_stats(self, a, plane_ids): - return a * self.z_norm[:, plane_ids].to(a.device) - - def forward_x_FG(self, x): - x_scaled = self._min_max_scale( - x, bounds=(-self.cone_fov / 2, self.cone_fov / 2) - ) - return x_scaled - - def reverse_x_FG(self, x_scaled): - x = self._min_max_unscale( - x_scaled, bounds=(-self.cone_fov / 2, self.cone_fov / 2) - ) - return x - - def forward_x_BG(self, a_LP, beta_LP): - a_LP_scaled = self._min_max_scale(a_LP, bounds=(-self.a_LP_bound, self.a_LP_bound)) # type: ignore - beta_LP_scaled = self._min_max_scale(beta_LP, bounds=(-self.beta_LP_bound, self.beta_LP_bound)) # type: ignore - - return a_LP_scaled, beta_LP_scaled - - def reverse_x_BG(self, a_LP_scaled, beta_LP_scaled): - a_LP = self._min_max_unscale(a_LP_scaled, bounds=(-self.a_LP_bound, self.a_LP_bound)) # type: ignore - beta_LP = self._min_max_unscale(beta_LP_scaled, bounds=(-self.beta_LP_bound, self.beta_LP_bound)) # type: ignore - - return a_LP, beta_LP - - def forward_from_stats_FG(self, hset, a=None): - hset_scaled = torch.clone(hset) - hset_scaled[..., self.log_slice] = torch.log10(hset_scaled[..., self.log_slice]) - hset_scaled[..., :-1] = self._min_max_scale( - hset_scaled[..., :-1], bounds=self.hset_bounds[:, :-1] - ) - if a is not None: - a_scaled = self._safe_divide_z_norm(a, self.z_norm) - else: - a_scaled = None - return hset_scaled, a_scaled - - def reverse_from_stats_FG(self, hset_scaled, a_scaled=None): - hset = torch.clone(hset_scaled) - hset[..., :-1] = self._min_max_unscale( - hset[..., :-1], bounds=self.hset_bounds[:, :-1] - ) - hset[..., self.log_slice] = 10 ** hset[..., self.log_slice] - if a_scaled is not None: - a = self.z_norm * a_scaled - else: - a = None - return hset, a - - def forward_from_stats_BG(self, hset, a_LP, beta_LP, a=None): - hset_scaled = torch.clone(hset) - hset_scaled[..., self.log_slice] = 
torch.log10(hset_scaled[..., self.log_slice]) - hset_scaled[..., :-1] = self._min_max_scale( - hset_scaled[..., :-1], bounds=self.hset_bounds[:, :-1] - ) - - a_LP_scaled = self._min_max_scale(a_LP, bounds=(-self.a_LP_bound, self.a_LP_bound)) # type: ignore - beta_LP_scaled = self._min_max_scale(beta_LP, bounds=(-self.beta_LP_bound, self.beta_LP_bound)) # type: ignore - - if a is not None: - a_scaled = self._safe_divide_z_norm(a, self.z_norm) - else: - a_scaled = None - return hset_scaled, a_LP_scaled, beta_LP_scaled, a_scaled - - def reverse_from_stats_BG( - self, hset_scaled, a_LP_scaled, beta_LP_scaled, a_scaled=None - ): - hset = torch.clone(hset_scaled) - hset[..., :-1] = self._min_max_unscale( - hset[..., :-1], bounds=self.hset_bounds[:, :-1] - ) - hset[..., self.log_slice] = 10 ** hset[..., self.log_slice] - - a_LP = self._min_max_unscale(a_LP_scaled, bounds=(-self.a_LP_bound, self.a_LP_bound)) # type: ignore - beta_LP = self._min_max_unscale(beta_LP_scaled, bounds=(-self.beta_LP_bound, self.beta_LP_bound)) # type: ignore - - if a_scaled is not None: - a = self.z_norm * a_scaled - else: - a = None - return hset, a_LP, beta_LP, a - - def forward_null(self, *args): - return args - - def reverse_null(self, *args): - return args - - def reverse_alpha_null(self, a, *args): - return a - - def forward_x_null(self, x): - return x - - def reverse_x_null(self, x): - return x - - @staticmethod - def _min_max_scale(arr, bounds, vrange=(-1, 1)): - return vrange[0] + (arr - bounds[0]) * (vrange[1] - vrange[0]) / ( - bounds[1] - bounds[0] - ) - - @staticmethod - def _min_max_unscale(arr, bounds, vrange=(-1, 1)): - return (arr - vrange[0]) * (bounds[1] - bounds[0]) / ( - vrange[1] - vrange[0] - ) + bounds[0] - - @staticmethod - def _safe_divide_z_norm(num, denom): - out = torch.zeros_like(num, dtype=torch.float) - out[:, torch.squeeze(denom) != 0, ...] = ( - num[:, torch.squeeze(denom) != 0, ...] - / denom[:, torch.squeeze(denom) != 0, ...] 
- ) - return out diff --git a/caustic/lenses/alphagen/stblocks.py b/caustic/lenses/alphagen/stblocks.py deleted file mode 100644 index df945f70..00000000 --- a/caustic/lenses/alphagen/stblocks.py +++ /dev/null @@ -1,91 +0,0 @@ -import math - -import torch -import torch.nn as nn -import torch.nn.functional as F - - -class MAB(nn.Module): - """ - Multi-Head Attention Block - """ - - def __init__(self, dim_Q, dim_K, dim_V, num_heads, ln=False): - super().__init__() - self.dim_Q = dim_Q - self.dim_V = dim_V - self.num_heads = num_heads - self.fc_q = nn.Linear(dim_Q, dim_V) - self.fc_k = nn.Linear(dim_K, dim_V) - self.fc_v = nn.Linear(dim_K, dim_V) - if ln: - self.ln0 = nn.LayerNorm(dim_V) - self.ln1 = nn.LayerNorm(dim_V) - self.fc_o = nn.Linear(dim_V, dim_V) - - def forward(self, Q, K): - bigQ = self.fc_q(Q) - bigK = self.fc_k(K) - bigV = self.fc_v(K) - - dim_split = self.dim_V // self.num_heads - Q_ = torch.cat(bigQ.split(dim_split, -1), 0) - K_ = torch.cat(bigK.split(dim_split, -1), 0) - V_ = torch.cat(bigV.split(dim_split, -1), 0) - - A = torch.softmax( - Q_.bmm(K_.transpose(1, 2)) / math.sqrt(self.dim_Q), -1 - ) # softmax on last dim of QK^T product - - O = torch.cat((Q_ + A.bmm(V_)).split(Q.size(0), 0), 2) - O = O if getattr(self, "ln0", None) is None else self.ln0(O) - O = O + F.relu(self.fc_o(O)) - MAB_out = O if getattr(self, "ln1", None) is None else self.ln1(O) - return MAB_out - - -class SAB(nn.Module): - """ - Set Attention Block - - SAB(X) = MAB(X,X) - """ - - def __init__(self, in_dim, out_dim, num_heads, ln=False): - super().__init__() - self.mab = MAB(in_dim, in_dim, out_dim, num_heads, ln=ln) - - def forward(self, X): - return self.mab(X, X) - - -class ISAB(nn.Module): - """ - Induced Set Attention Block - """ - - def __init__(self, in_dim, out_dim, num_heads, num_inds, ln=False): - super().__init__() - self.I = nn.Parameter(torch.Tensor(1, num_inds, out_dim)) - nn.init.xavier_uniform_(self.I) - self.mab0 = MAB(out_dim, in_dim, out_dim, num_heads, ln=ln) - self.mab1 = MAB(in_dim, out_dim, out_dim, num_heads, ln=ln) - - def forward(self, X): - H = self.mab0(self.I.repeat(X.size(0), 1, 1), X) - return self.mab1(X, H) - - -class PMA(nn.Module): - """ - Pooling by Multi-Head Attention - """ - - def __init__(self, dim, num_heads, num_seeds, ln=False): - super().__init__() - self.S = nn.Parameter(torch.Tensor(1, num_seeds, dim)) - nn.init.xavier_uniform_(self.S) - self.mab = MAB(dim, dim, dim, num_heads, ln=ln) - - def forward(self, X): - return self.mab(self.S.repeat(X.size(0), 1, 1), X)
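
Since patch 11/11 deletes the alphagen module wholesale, a few short sketches of what it implemented follow, for anyone who later wants to resurrect it with tests. First, the attention blocks: per the docstrings in the removed stblocks.py, SAB(X) = MAB(X, X) mixes information across set elements, and PMA pools a set down to a fixed number of seed vectors. This is a minimal sketch of how those blocks composed into a permutation-invariant set encoder, assuming the pre-removal import path and the class signatures shown in the diff above; the dimensions are illustrative, not values used anywhere in the codebase.

    import torch

    from caustic.lenses.alphagen.stblocks import PMA, SAB  # pre-removal path

    # Two SAB layers let every set element attend to every other element;
    # PMA then pools the whole set down to num_seeds summary vectors.
    enc = torch.nn.Sequential(
        SAB(in_dim=6, out_dim=32, num_heads=4),
        SAB(in_dim=32, out_dim=32, num_heads=4),
    )
    pool = PMA(dim=32, num_heads=4, num_seeds=1)

    x = torch.randn(8, 100, 6)  # batch of 8 sets, 100 elements, 6 features each
    summary = pool(enc(x))      # -> (8, 1, 32): one embedding per set
    assert summary.shape == (8, 1, 32)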
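
Second, the conditioning trick in the removed JointModel_ST_CNN.forward: the per-set embedding from the set transformer is broadcast to every pixel and concatenated onto the image channels before the 1x1-convolution CNN. A self-contained sketch of that reshape/repeat/concat step, with made-up shapes:

    import torch

    b, e, npix = 8, 16, 64
    embedding = torch.randn(b, e)      # one vector per set, from the set transformer
    x = torch.randn(b, 3, npix, npix)  # per-pixel input channels

    # Broadcast the embedding across the image plane, then stack it onto
    # the input channels, exactly as in the deleted forward() above.
    embedding_map = (
        embedding.reshape(b, e, 1, 1)
        .repeat_interleave(npix, dim=-2)
        .repeat_interleave(npix, dim=-1)
    )                                                   # -> (b, e, npix, npix)
    cnn_input = torch.cat([x, embedding_map], dim=1)    # -> (b, 3 + e, npix, npix)
    assert cnn_input.shape == (b, 3 + e, npix, npix)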
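
Finally, the removed Normalizer's scaling is a plain affine map between [bounds[0], bounds[1]] and [-1, 1]. Copying its two static methods verbatim, a round trip recovers the input — the invariant any replacement normalizer should preserve:

    import torch

    def min_max_scale(arr, bounds, vrange=(-1, 1)):
        # affine map from [bounds[0], bounds[1]] onto [vrange[0], vrange[1]]
        return vrange[0] + (arr - bounds[0]) * (vrange[1] - vrange[0]) / (
            bounds[1] - bounds[0]
        )

    def min_max_unscale(arr, bounds, vrange=(-1, 1)):
        # inverse of min_max_scale
        return (arr - vrange[0]) * (bounds[1] - bounds[0]) / (
            vrange[1] - vrange[0]
        ) + bounds[0]

    x = torch.tensor([0.0, 2.5, 5.0])
    y = min_max_scale(x, bounds=(0.0, 5.0))  # -> tensor([-1., 0., 1.])
    assert torch.allclose(min_max_unscale(y, bounds=(0.0, 5.0)), x)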