From ea19d892a9cc851704637dd51ef516445b5f2ad9 Mon Sep 17 00:00:00 2001 From: Don Setiawan Date: Fri, 8 Mar 2024 08:57:24 -0800 Subject: [PATCH] feat: Add functionality to build simulator from a YAML configuration file (#167) * feat: Add registry for various parametrized "kind" (#84) * feat: Add registry for various parametrized "kind" Added a registry for cosmology, lenses, light, and sims classes to be used as "kind" * style: pre-commit fixes * fix: Fix misspelling of pixelatedconvergence Co-authored-by: Cordero Core <127983572+uwcdc@users.noreply.github.com> --------- Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> Co-authored-by: Cordero Core <127983572+uwcdc@users.noreply.github.com> * feat: Add _meta_params class attribute * fix: Fix default _meta_params type * feat: Add _meta_params to cosmo, lenses, and source (#87) * feat: Add _meta_params to cosmology, lenses, and sources for use in yaml validation * feat: Add _meta_params to lens_source * fix: Remove _meta_params from multiplane.py Co-authored-by: Don Setiawan * fix: Remove _meta_params from singleplane.py * fix: Update _meta_params to reflect z_s --------- Co-authored-by: Don Setiawan * feat: Add pydantic dynamic models from classes (#88) * feat: Add pydantic dynamic models from classes * revert: Remove the use of _meta_params, and use annotated instead * feat: Updated dynamic creation to use annotated from class * fix: Fix bugs in creating field defs * test: Add tests for models/utils * fix: Fix typehints for Parametrized | * chore(deps): Add pydantic 2 as dependency * refactor: Add build_simulator function to caustics init * feat: Add way to evaluate string and dict in pre field inputs * fix: Fix return type to be Any * fix: Apply suggestions from code review Co-authored-by: Cordero Core <127983572+uwcdc@users.noreply.github.com> * fix: Apply suggestions from code review Co-authored-by: Cordero Core <127983572+uwcdc@users.noreply.github.com> --------- Co-authored-by: Cordero Core <127983572+uwcdc@users.noreply.github.com> * feat: Add ValueError when dict doesn't include both 'func' and 'keys' * test: Add integration test for yaml config (#89) * test: Add integration test for yaml config * refactor: Remove unecessary attr and specify union for build * test: Add test for models registry * chore(deps): Add pytest-mock for mocking * test: Separate models test utils and create a complex yaml * fix: Move SinglePlane to single lenses and handle case * test: Add test for models api * fix: Add single plane model to list of lenses * refactor: Renamed config_json to config_dict for clarity * test: Fix tempfile test for windows * test: Fix temp file creation and reading * test: Extract temp yaml creation to a func * test: Fix to use path of tempfile * test: Fix where name is extracted * test: Add build_simulator test with state * test: Ignore cleanup for temp state dict... let OS clean it up * fix: Fix bug with arbitrary dict --------- Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> Co-authored-by: Cordero Core <127983572+uwcdc@users.noreply.github.com> --- pyproject.toml | 3 +- requirements.txt | 1 + src/caustics/__init__.py | 2 + src/caustics/cosmology/FlatLambdaCDM.py | 18 +- src/caustics/cosmology/base.py | 6 +- src/caustics/lenses/base.py | 18 +- src/caustics/lenses/epl.py | 45 ++- src/caustics/lenses/external_shear.py | 36 ++- src/caustics/lenses/mass_sheet.py | 29 +- src/caustics/lenses/multiplane.py | 7 +- src/caustics/lenses/nfw.py | 36 ++- src/caustics/lenses/pixelated_convergence.py | 54 ++-- src/caustics/lenses/point.py | 33 ++- src/caustics/lenses/pseudo_jaffe.py | 44 ++- src/caustics/lenses/sie.py | 37 ++- src/caustics/lenses/singleplane.py | 13 +- src/caustics/lenses/sis.py | 31 ++- src/caustics/lenses/tnfw.py | 58 +++- src/caustics/light/base.py | 4 +- src/caustics/light/pixelated.py | 35 ++- src/caustics/light/sersic.py | 53 +++- src/caustics/models/__init__.py | 0 src/caustics/models/api.py | 35 +++ src/caustics/models/base_models.py | 90 ++++++ src/caustics/models/registry.py | 126 +++++++++ src/caustics/models/utils.py | 277 +++++++++++++++++++ src/caustics/parametrized.py | 9 + src/caustics/sims/lens_source.py | 52 ++-- src/caustics/sims/simulator.py | 4 +- src/caustics/utils.py | 55 +++- tests/conftest.py | 30 ++ tests/models/test_mod_api.py | 228 +++++++++++++++ tests/models/test_mod_registry.py | 97 +++++++ tests/models/test_mod_utils.py | 118 ++++++++ tests/test_epl.py | 24 +- tests/test_external_shear.py | 24 +- tests/test_masssheet.py | 24 +- tests/test_multiplane.py | 54 +++- tests/test_nfw.py | 53 +++- tests/test_point.py | 29 +- tests/test_pseudo_jaffe.py | 25 +- tests/test_sersic.py | 18 +- tests/test_sie.py | 24 +- tests/test_simulator_runs.py | 132 +++++++-- tests/test_sis.py | 29 +- tests/test_tnfw.py | 32 ++- tests/utils/__init__.py | 4 + tests/utils/models.py | 178 ++++++++++++ 48 files changed, 2053 insertions(+), 281 deletions(-) create mode 100644 src/caustics/models/__init__.py create mode 100644 src/caustics/models/api.py create mode 100644 src/caustics/models/base_models.py create mode 100644 src/caustics/models/registry.py create mode 100644 src/caustics/models/utils.py create mode 100644 tests/models/test_mod_api.py create mode 100644 tests/models/test_mod_registry.py create mode 100644 tests/models/test_mod_utils.py create mode 100644 tests/utils/models.py diff --git a/pyproject.toml b/pyproject.toml index dd5f0142..99616f22 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -47,7 +47,8 @@ dev = [ "lenstronomy==1.11.1", "pytest>=8.0,<9", "pytest-cov>=4.1,<5", - "pre-commit>=3.6,<4" + "pytest-mock>=3.12,<4", + "pre-commit>=3.6,<4", ] [tool.hatch.metadata.hooks.requirements_txt] diff --git a/requirements.txt b/requirements.txt index 3fdc2830..62964719 100644 --- a/requirements.txt +++ b/requirements.txt @@ -2,6 +2,7 @@ astropy>=5.2.1,<6.0.0 graphviz==0.20.1 h5py>=3.8.0 numpy>=1.23.5 +pydantic>=2.6.1,<3 safetensors>=0.4.1 scipy>=1.8.0 torch>=2.0.0 diff --git a/src/caustics/__init__.py b/src/caustics/__init__.py index e43a8a04..1e1e7788 100644 --- a/src/caustics/__init__.py +++ b/src/caustics/__init__.py @@ -28,6 +28,7 @@ from . import utils from .sims import Lens_Source, Simulator from .tests import test +from .models.api import build_simulator __version__ = VERSION __author__ = "Ciela" @@ -62,4 +63,5 @@ "Lens_Source", "Simulator", "test", + "build_simulator", ] diff --git a/src/caustics/cosmology/FlatLambdaCDM.py b/src/caustics/cosmology/FlatLambdaCDM.py index 3a68b929..9aaaa671 100644 --- a/src/caustics/cosmology/FlatLambdaCDM.py +++ b/src/caustics/cosmology/FlatLambdaCDM.py @@ -1,5 +1,5 @@ # mypy: disable-error-code="operator" -from typing import Optional +from typing import Optional, Annotated import torch from torch import Tensor @@ -11,9 +11,7 @@ from ..parametrized import unpack from ..packed import Packed from ..constants import c_Mpc_s, km_to_Mpc -from .base import ( - Cosmology, -) +from .base import Cosmology, NameType _h0_default = float(default_cosmology.get().h) _critical_density_0_default = float( @@ -43,10 +41,14 @@ class FlatLambdaCDM(Cosmology): def __init__( self, - h0: Optional[Tensor] = h0_default, - critical_density_0: Optional[Tensor] = critical_density_0_default, - Om0: Optional[Tensor] = Om0_default, - name: Optional[str] = None, + h0: Annotated[Optional[Tensor], "Hubble constant over 100", True] = h0_default, + critical_density_0: Annotated[ + Optional[Tensor], "Critical density at z=0", True + ] = critical_density_0_default, + Om0: Annotated[ + Optional[Tensor], "Matter density parameter at z=0", True + ] = Om0_default, + name: NameType = None, ): """ Initialize a new instance of the FlatLambdaCDM class. diff --git a/src/caustics/cosmology/base.py b/src/caustics/cosmology/base.py index 1b2a7ed2..152814ad 100644 --- a/src/caustics/cosmology/base.py +++ b/src/caustics/cosmology/base.py @@ -1,7 +1,7 @@ # mypy: disable-error-code="operator" from abc import abstractmethod from math import pi -from typing import Optional +from typing import Optional, Annotated from torch import Tensor @@ -9,6 +9,8 @@ from ..parametrized import Parametrized, unpack from ..packed import Packed +NameType = Annotated[Optional[str], "Name of the cosmology"] + class Cosmology(Parametrized): """ @@ -31,7 +33,7 @@ class Cosmology(Parametrized): Name of the cosmological model. """ - def __init__(self, name: Optional[str] = None): + def __init__(self, name: NameType = None): """ Initialize the Cosmology. diff --git a/src/caustics/lenses/base.py b/src/caustics/lenses/base.py index abc547a7..ffe56724 100644 --- a/src/caustics/lenses/base.py +++ b/src/caustics/lenses/base.py @@ -1,6 +1,6 @@ # mypy: disable-error-code="call-overload" from abc import abstractmethod -from typing import Optional, Union +from typing import Optional, Union, Annotated, List from functools import partial import warnings @@ -16,13 +16,21 @@ __all__ = ("ThinLens", "ThickLens") +CosmologyType = Annotated[ + Cosmology, + "Cosmology object that encapsulates cosmological parameters and distances", +] +NameType = Annotated[Optional[str], "Name of the lens model"] +ZLType = Annotated[Optional[Union[Tensor, float]], "The redshift of the lens", True] +LensesType = Annotated[List["ThinLens"], "A list of ThinLens objects"] + class Lens(Parametrized): """ Base class for all lenses """ - def __init__(self, cosmology: Cosmology, name: Optional[str] = None): + def __init__(self, cosmology: CosmologyType, name: NameType = None): """ Initializes a new instance of the Lens class. @@ -715,9 +723,9 @@ class ThinLens(Lens): def __init__( self, - cosmology: Cosmology, - z_l: Optional[Union[Tensor, float]] = None, - name: Optional[str] = None, + cosmology: CosmologyType, + z_l: ZLType = None, + name: NameType = None, ): super().__init__(cosmology=cosmology, name=name) self.add_param("z_l", z_l) diff --git a/src/caustics/lenses/epl.py b/src/caustics/lenses/epl.py index 35ae8abc..f1829026 100644 --- a/src/caustics/lenses/epl.py +++ b/src/caustics/lenses/epl.py @@ -1,12 +1,11 @@ -# mypy: disable-error-code="operator" -from typing import Optional, Union +# mypy: disable-error-code="operator,dict-item" +from typing import Optional, Union, Annotated import torch from torch import Tensor -from ..cosmology import Cosmology from ..utils import derotate, translate_rotate -from .base import ThinLens +from .base import ThinLens, CosmologyType, NameType, ZLType from ..parametrized import unpack from ..packed import Packed @@ -92,17 +91,33 @@ class EPL(ThinLens): def __init__( self, - cosmology: Cosmology, - z_l: Optional[Union[Tensor, float]] = None, - x0: Optional[Union[Tensor, float]] = None, - y0: Optional[Union[Tensor, float]] = None, - q: Optional[Union[Tensor, float]] = None, - phi: Optional[Union[Tensor, float]] = None, - b: Optional[Union[Tensor, float]] = None, - t: Optional[Union[Tensor, float]] = None, - s: float = 0.0, - n_iter: int = 18, - name: Optional[str] = None, + cosmology: CosmologyType, + z_l: ZLType = None, + x0: Annotated[ + Optional[Union[Tensor, float]], "X coordinate of the lens center", True + ] = None, + y0: Annotated[ + Optional[Union[Tensor, float]], "Y coordinate of the lens center", True + ] = None, + q: Annotated[ + Optional[Union[Tensor, float]], "Axis ratio of the lens", True + ] = None, + phi: Annotated[ + Optional[Union[Tensor, float]], "Position angle of the lens", True + ] = None, + b: Annotated[ + Optional[Union[Tensor, float]], "Scale length of the lens", True + ] = None, + t: Annotated[ + Optional[Union[Tensor, float]], + "Power law slope (`gamma-1`) of the lens", + True, + ] = None, + s: Annotated[ + float, "Softening length for the elliptical power-law profile" + ] = 0.0, + n_iter: Annotated[int, "Number of iterations for the iterative solver"] = 18, + name: NameType = None, ): """ Initialize an EPL lens model. diff --git a/src/caustics/lenses/external_shear.py b/src/caustics/lenses/external_shear.py index 1ffc33b3..f10995bb 100644 --- a/src/caustics/lenses/external_shear.py +++ b/src/caustics/lenses/external_shear.py @@ -1,10 +1,10 @@ -from typing import Optional, Union +# mypy: disable-error-code="dict-item" +from typing import Optional, Union, Annotated from torch import Tensor -from ..cosmology import Cosmology from ..utils import translate_rotate -from .base import ThinLens +from .base import ThinLens, CosmologyType, NameType, ZLType from ..parametrized import unpack from ..packed import Packed @@ -53,14 +53,28 @@ class ExternalShear(ThinLens): def __init__( self, - cosmology: Cosmology, - z_l: Optional[Union[Tensor, float]] = None, - x0: Optional[Union[Tensor, float]] = None, - y0: Optional[Union[Tensor, float]] = None, - gamma_1: Optional[Union[Tensor, float]] = None, - gamma_2: Optional[Union[Tensor, float]] = None, - s: float = 0.0, - name: Optional[str] = None, + cosmology: CosmologyType, + z_l: ZLType = None, + x0: Annotated[ + Optional[Union[Tensor, float]], + "x-coordinate of the shear center in the lens plane", + True, + ] = None, + y0: Annotated[ + Optional[Union[Tensor, float]], + "y-coordinate of the shear center in the lens plane", + True, + ] = None, + gamma_1: Annotated[ + Optional[Union[Tensor, float]], "Shear component in the x-direction", True + ] = None, + gamma_2: Annotated[ + Optional[Union[Tensor, float]], "Shear component in the y-direction", True + ] = None, + s: Annotated[ + float, "Softening length for the elliptical power-law profile" + ] = 0.0, + name: NameType = None, ): super().__init__(cosmology, z_l, name=name) diff --git a/src/caustics/lenses/mass_sheet.py b/src/caustics/lenses/mass_sheet.py index 76b46aa1..97f8f4e5 100644 --- a/src/caustics/lenses/mass_sheet.py +++ b/src/caustics/lenses/mass_sheet.py @@ -1,12 +1,11 @@ -# mypy: disable-error-code="operator" -from typing import Optional, Union +# mypy: disable-error-code="operator,dict-item" +from typing import Optional, Union, Annotated import torch from torch import Tensor -from ..cosmology import Cosmology from ..utils import translate_rotate -from .base import ThinLens +from .base import ThinLens, CosmologyType, NameType, ZLType from ..parametrized import unpack from ..packed import Packed @@ -64,12 +63,22 @@ class MassSheet(ThinLens): def __init__( self, - cosmology: Cosmology, - z_l: Optional[Union[Tensor, float]] = None, - x0: Optional[Union[Tensor, float]] = None, - y0: Optional[Union[Tensor, float]] = None, - surface_density: Optional[Union[Tensor, float]] = None, - name: Optional[str] = None, + cosmology: CosmologyType, + z_l: ZLType = None, + x0: Annotated[ + Optional[Union[Tensor, float]], + "x-coordinate of the shear center in the lens plane", + True, + ] = None, + y0: Annotated[ + Optional[Union[Tensor, float]], + "y-coordinate of the shear center in the lens plane", + True, + ] = None, + surface_density: Annotated[ + Optional[Union[Tensor, float]], "Surface density", True + ] = None, + name: NameType = None, ): super().__init__(cosmology, z_l, name=name) diff --git a/src/caustics/lenses/multiplane.py b/src/caustics/lenses/multiplane.py index f001e1a9..a658c7bb 100644 --- a/src/caustics/lenses/multiplane.py +++ b/src/caustics/lenses/multiplane.py @@ -5,8 +5,7 @@ from torch import Tensor from ..constants import arcsec_to_rad, rad_to_arcsec, c_Mpc_s -from ..cosmology import Cosmology -from .base import ThickLens, ThinLens +from .base import ThickLens, NameType, CosmologyType, LensesType from ..parametrized import unpack from ..packed import Packed @@ -19,7 +18,7 @@ class Multiplane(ThickLens): Attributes ---------- - lenses (list[ThinLens]) + lenses list of ThinLens List of thin lenses. Parameters @@ -33,7 +32,7 @@ class Multiplane(ThickLens): """ def __init__( - self, cosmology: Cosmology, lenses: list[ThinLens], name: Optional[str] = None + self, cosmology: CosmologyType, lenses: LensesType, name: NameType = None ): super().__init__(cosmology, name=name) self.lenses = lenses diff --git a/src/caustics/lenses/nfw.py b/src/caustics/lenses/nfw.py index 9dcbe8ee..70f47911 100644 --- a/src/caustics/lenses/nfw.py +++ b/src/caustics/lenses/nfw.py @@ -1,14 +1,13 @@ -# mypy: disable-error-code="operator,union-attr" +# mypy: disable-error-code="operator,union-attr,dict-item" from math import pi -from typing import Optional, Union +from typing import Optional, Union, Annotated, Literal import torch from torch import Tensor from ..constants import G_over_c2, arcsec_to_rad, rad_to_arcsec -from ..cosmology import Cosmology from ..utils import translate_rotate -from .base import ThinLens +from .base import ThinLens, NameType, CosmologyType, ZLType from ..parametrized import unpack from ..packed import Packed @@ -100,15 +99,26 @@ class NFW(ThinLens): def __init__( self, - cosmology: Cosmology, - z_l: Optional[Union[Tensor, float]] = None, - x0: Optional[Union[Tensor, float]] = None, - y0: Optional[Union[Tensor, float]] = None, - m: Optional[Union[Tensor, float]] = None, - c: Optional[Union[Tensor, float]] = None, - s: float = 0.0, - use_case="batchable", - name: Optional[str] = None, + cosmology: CosmologyType, + z_l: ZLType = None, + x0: Annotated[ + Optional[Union[Tensor, float]], "X coordinate of the lens center", True + ] = None, + y0: Annotated[ + Optional[Union[Tensor, float]], "Y coordinate of the lens center", True + ] = None, + m: Annotated[Optional[Union[Tensor, float]], "Mass of the lens", True] = None, + c: Annotated[ + Optional[Union[Tensor, float]], "Concentration parameter of the lens", True + ] = None, + s: Annotated[ + float, + "Softening parameter to avoid singularities at the center of the lens", + ] = 0.0, + use_case: Annotated[ + Literal["batchable", "differentiable"], "the NFW/TNFW profile" + ] = "batchable", + name: NameType = None, ): """ Initialize an instance of the NFW lens class. diff --git a/src/caustics/lenses/pixelated_convergence.py b/src/caustics/lenses/pixelated_convergence.py index 3c99d9ea..e23abc58 100644 --- a/src/caustics/lenses/pixelated_convergence.py +++ b/src/caustics/lenses/pixelated_convergence.py @@ -1,6 +1,6 @@ -# mypy: disable-error-code="index" +# mypy: disable-error-code="index,dict-item" from math import pi -from typing import Optional +from typing import Optional, Annotated, Union, Literal import torch import torch.nn.functional as F @@ -8,9 +8,8 @@ from torch import Tensor import numpy as np -from ..cosmology import Cosmology from ..utils import get_meshgrid, interp2d, safe_divide, safe_log -from .base import ThinLens +from .base import ThinLens, CosmologyType, NameType, ZLType from ..parametrized import unpack from ..packed import Packed @@ -26,18 +25,41 @@ class PixelatedConvergence(ThinLens): def __init__( self, - pixelscale: float, - n_pix: int, - cosmology: Cosmology, - z_l: Optional[Tensor] = None, - x0: Optional[Tensor] = torch.tensor(0.0), - y0: Optional[Tensor] = torch.tensor(0.0), - convergence_map: Optional[Tensor] = None, - shape: Optional[tuple[int, ...]] = None, - convolution_mode: str = "fft", - use_next_fast_len: bool = True, - padding: str = "zero", - name: Optional[str] = None, + pixelscale: Annotated[float, "pixelscale"], + n_pix: Annotated[int, "The number of pixels on each side of the grid"], + cosmology: CosmologyType, + z_l: ZLType = None, + x0: Annotated[ + Optional[Union[Tensor, float]], + "The x-coordinate of the center of the grid", + True, + ] = torch.tensor(0.0), + y0: Annotated[ + Optional[Union[Tensor, float]], + "The y-coordinate of the center of the grid", + True, + ] = torch.tensor(0.0), + convergence_map: Annotated[ + Optional[Tensor], + "A 2D tensor representing the convergence map", + True, + ] = None, + shape: Annotated[ + Optional[tuple[int, ...]], "The shape of the convergence map" + ] = None, + convolution_mode: Annotated[ + Literal["fft", "conv2d"], + "The convolution mode for calculating deflection angles and lensing potential", + ] = "fft", + use_next_fast_len: Annotated[ + bool, + "If True, adds additional padding to speed up the FFT by calling `scipy.fft.next_fast_len`", + ] = True, + padding: Annotated[ + Literal["zero", "circular", "reflect", "tile"], + "Specifies the type of padding", + ] = "zero", + name: NameType = None, ): """Strong lensing with user provided kappa map diff --git a/src/caustics/lenses/point.py b/src/caustics/lenses/point.py index c3d61743..ba423479 100644 --- a/src/caustics/lenses/point.py +++ b/src/caustics/lenses/point.py @@ -1,12 +1,11 @@ -# mypy: disable-error-code="operator" -from typing import Optional, Union +# mypy: disable-error-code="operator,dict-item" +from typing import Optional, Union, Annotated import torch from torch import Tensor -from ..cosmology import Cosmology from ..utils import translate_rotate -from .base import ThinLens +from .base import ThinLens, CosmologyType, NameType, ZLType from ..parametrized import unpack from ..packed import Packed @@ -60,13 +59,25 @@ class Point(ThinLens): def __init__( self, - cosmology: Cosmology, - z_l: Optional[Union[Tensor, float]] = None, - x0: Optional[Union[Tensor, float]] = None, - y0: Optional[Union[Tensor, float]] = None, - th_ein: Optional[Union[Tensor, float]] = None, - s: float = 0.0, - name: Optional[str] = None, + cosmology: CosmologyType, + z_l: ZLType = None, + x0: Annotated[ + Optional[Union[Tensor, float]], + "X coordinate of the center of the lens", + True, + ] = None, + y0: Annotated[ + Optional[Union[Tensor, float]], + "Y coordinate of the center of the lens", + True, + ] = None, + th_ein: Annotated[ + Optional[Union[Tensor, float]], "Einstein radius of the lens", True + ] = None, + s: Annotated[ + float, "Softening parameter to prevent numerical instabilities" + ] = 0.0, + name: NameType = None, ): """ Initialize the Point class. diff --git a/src/caustics/lenses/pseudo_jaffe.py b/src/caustics/lenses/pseudo_jaffe.py index 69da2351..68fa823e 100644 --- a/src/caustics/lenses/pseudo_jaffe.py +++ b/src/caustics/lenses/pseudo_jaffe.py @@ -1,14 +1,13 @@ -# mypy: disable-error-code="operator" +# mypy: disable-error-code="operator,dict-item" from math import pi -from typing import Optional, Union +from typing import Optional, Union, Annotated import torch from torch import Tensor -from ..cosmology import Cosmology from ..constants import arcsec_to_rad, G_over_c2 from ..utils import translate_rotate -from .base import ThinLens +from .base import ThinLens, CosmologyType, NameType, ZLType from ..parametrized import unpack from ..packed import Packed @@ -74,15 +73,34 @@ class PseudoJaffe(ThinLens): def __init__( self, - cosmology: Cosmology, - z_l: Optional[Union[Tensor, float]] = None, - x0: Optional[Union[Tensor, float]] = None, - y0: Optional[Union[Tensor, float]] = None, - mass: Optional[Union[Tensor, float]] = None, - core_radius: Optional[Union[Tensor, float]] = None, - scale_radius: Optional[Union[Tensor, float]] = None, - s: float = 0.0, - name: Optional[str] = None, + cosmology: CosmologyType, + z_l: ZLType = None, + x0: Annotated[ + Optional[Union[Tensor, float]], + "X coordinate of the center of the lens", + True, + ] = None, + y0: Annotated[ + Optional[Union[Tensor, float]], + "Y coordinate of the center of the lens", + True, + ] = None, + mass: Annotated[ + Optional[Union[Tensor, float]], "Total mass of the lens", True, "Msol" + ] = None, + core_radius: Annotated[ + Optional[Union[Tensor, float]], "Core radius of the lens", True, "arcsec" + ] = None, + scale_radius: Annotated[ + Optional[Union[Tensor, float]], + "Scaling radius of the lens", + True, + "arcsec", + ] = None, + s: Annotated[ + float, "Softening parameter to prevent numerical instabilities" + ] = 0.0, + name: NameType = None, ): """ Initialize the PseudoJaffe class. diff --git a/src/caustics/lenses/sie.py b/src/caustics/lenses/sie.py index c5fab214..3deab483 100644 --- a/src/caustics/lenses/sie.py +++ b/src/caustics/lenses/sie.py @@ -1,11 +1,10 @@ -# mypy: disable-error-code="operator,union-attr" -from typing import Optional, Union +# mypy: disable-error-code="operator,union-attr,dict-item" +from typing import Optional, Union, Annotated from torch import Tensor -from ..cosmology import Cosmology from ..utils import derotate, translate_rotate -from .base import ThinLens +from .base import ThinLens, CosmologyType, NameType, ZLType from ..parametrized import unpack from ..packed import Packed @@ -72,15 +71,27 @@ class SIE(ThinLens): def __init__( self, - cosmology: Cosmology, - z_l: Optional[Union[Tensor, float]] = None, - x0: Optional[Union[Tensor, float]] = None, - y0: Optional[Union[Tensor, float]] = None, - q: Optional[Union[Tensor, float]] = None, # TODO change to true axis ratio - phi: Optional[Union[Tensor, float]] = None, - b: Optional[Union[Tensor, float]] = None, - s: float = 0.0, - name: Optional[str] = None, + cosmology: CosmologyType, + z_l: ZLType = None, + x0: Annotated[ + Optional[Union[Tensor, float]], "The x-coordinate of the lens center", True + ] = None, + y0: Annotated[ + Optional[Union[Tensor, float]], "The y-coordinate of the lens center", True + ] = None, + q: Annotated[ + Optional[Union[Tensor, float]], "The axis ratio of the lens", True + ] = None, # TODO change to true axis ratio + phi: Annotated[ + Optional[Union[Tensor, float]], + "The orientation angle of the lens (position angle)", + True, + ] = None, + b: Annotated[ + Optional[Union[Tensor, float]], "The Einstein radius of the lens", True + ] = None, + s: Annotated[float, "The core radius of the lens"] = 0.0, + name: NameType = None, ): """ Initialize the SIE lens model. diff --git a/src/caustics/lenses/singleplane.py b/src/caustics/lenses/singleplane.py index 15b4dddb..2f6e2ea4 100644 --- a/src/caustics/lenses/singleplane.py +++ b/src/caustics/lenses/singleplane.py @@ -3,8 +3,7 @@ import torch from torch import Tensor -from ..cosmology import Cosmology -from .base import ThinLens +from .base import ThinLens, CosmologyType, NameType, LensesType, ZLType from ..parametrized import unpack from ..packed import Packed @@ -31,15 +30,15 @@ class SinglePlane(ThinLens): def __init__( self, - cosmology: Cosmology, - lenses: list[ThinLens], - name: Optional[str] = None, - **kwargs, + cosmology: CosmologyType, + lenses: LensesType, + name: NameType = None, + z_l: ZLType = None, ): """ Initialize the SinglePlane lens model. """ - super().__init__(cosmology, name=name, **kwargs) + super().__init__(cosmology, z_l=z_l, name=name) self.lenses = lenses for lens in lenses: self.add_parametrized(lens) diff --git a/src/caustics/lenses/sis.py b/src/caustics/lenses/sis.py index d413b9cb..9ce4e8ee 100644 --- a/src/caustics/lenses/sis.py +++ b/src/caustics/lenses/sis.py @@ -1,11 +1,10 @@ -# mypy: disable-error-code="operator" -from typing import Optional, Union +# mypy: disable-error-code="operator,dict-item" +from typing import Optional, Union, Annotated from torch import Tensor -from ..cosmology import Cosmology from ..utils import translate_rotate -from .base import ThinLens +from .base import ThinLens, CosmologyType, NameType, ZLType from ..parametrized import unpack from ..packed import Packed @@ -38,9 +37,7 @@ class SIS(ThinLens): y0: Optional[Union[Tensor, float]] The y-coordinate of the lens center. - *Unit: arcsec* - - th_ein (Optional[Union[Tensor, float]]) + th_ein: Optional[Union[Tensor, float]] The Einstein radius of the lens. *Unit: arcsec* @@ -60,13 +57,19 @@ class SIS(ThinLens): def __init__( self, - cosmology: Cosmology, - z_l: Optional[Union[Tensor, float]] = None, - x0: Optional[Union[Tensor, float]] = None, - y0: Optional[Union[Tensor, float]] = None, - th_ein: Optional[Union[Tensor, float]] = None, - s: float = 0.0, - name: Optional[str] = None, + cosmology: CosmologyType, + z_l: ZLType = None, + x0: Annotated[ + Optional[Union[Tensor, float]], "The x-coordinate of the lens center", True + ] = None, + y0: Annotated[ + Optional[Union[Tensor, float]], "The y-coordinate of the lens center", True + ] = None, + th_ein: Annotated[ + Optional[Union[Tensor, float]], "The Einstein radius of the lens", True + ] = None, + s: Annotated[float, "A smoothing factor"] = 0.0, + name: NameType = None, ): """ Initialize the SIS lens model. diff --git a/src/caustics/lenses/tnfw.py b/src/caustics/lenses/tnfw.py index 04a302d6..1bf6119e 100644 --- a/src/caustics/lenses/tnfw.py +++ b/src/caustics/lenses/tnfw.py @@ -1,14 +1,13 @@ -# mypy: disable-error-code="operator,union-attr" +# mypy: disable-error-code="operator,union-attr,dict-item" from math import pi -from typing import Optional, Union +from typing import Optional, Union, Literal, Annotated import torch from torch import Tensor from ..constants import G_over_c2, arcsec_to_rad, rad_to_arcsec -from ..cosmology import Cosmology from ..utils import translate_rotate -from .base import ThinLens +from .base import ThinLens, CosmologyType, NameType, ZLType from ..parametrized import unpack from ..packed import Packed @@ -111,17 +110,46 @@ class TNFW(ThinLens): def __init__( self, - cosmology: Cosmology, - z_l: Optional[Union[Tensor, float]] = None, - x0: Optional[Union[Tensor, float]] = None, - y0: Optional[Union[Tensor, float]] = None, - mass: Optional[Union[Tensor, float]] = None, - scale_radius: Optional[Union[Tensor, float]] = None, - tau: Optional[Union[Tensor, float]] = None, - s: float = 0.0, - interpret_m_total_mass: bool = True, - use_case="batchable", - name: Optional[str] = None, + cosmology: CosmologyType, + z_l: ZLType = None, + x0: Annotated[ + Optional[Union[Tensor, float]], + "Center of lens position on x-axis", + True, + "arcsec", + ] = None, + y0: Annotated[ + Optional[Union[Tensor, float]], + "Center of lens position on y-axis", + True, + "arcsec", + ] = None, + mass: Annotated[ + Optional[Union[Tensor, float]], "Mass of the lens", True, "Msol" + ] = None, + scale_radius: Annotated[ + Optional[Union[Tensor, float]], + "Scale radius of the TNFW lens", + True, + "arcsec", + ] = None, + tau: Annotated[ + Optional[Union[Tensor, float]], + "Truncation scale. Ratio of truncation radius to scale radius", + True, + "rt/rs", + ] = None, + s: Annotated[ + float, + "Softening parameter to avoid singularities at the center of the lens", + ] = 0.0, + interpret_m_total_mass: Annotated[ + bool, "Indicates how to interpret the mass variable 'm'" + ] = True, + use_case: Annotated[ + Literal["batchable", "differentiable"], "the NFW/TNFW profile" + ] = "batchable", + name: NameType = None, ): """ Initialize an instance of the TNFW lens class. diff --git a/src/caustics/light/base.py b/src/caustics/light/base.py index 505a6bf0..3d636ca2 100644 --- a/src/caustics/light/base.py +++ b/src/caustics/light/base.py @@ -1,5 +1,5 @@ from abc import abstractmethod -from typing import Optional +from typing import Optional, Annotated from torch import Tensor @@ -8,6 +8,8 @@ __all__ = ("Source",) +NameType = Annotated[Optional[str], "Name of the source"] + class Source(Parametrized): """ diff --git a/src/caustics/light/pixelated.py b/src/caustics/light/pixelated.py index 52e3e001..0ff4280c 100644 --- a/src/caustics/light/pixelated.py +++ b/src/caustics/light/pixelated.py @@ -1,10 +1,10 @@ # mypy: disable-error-code="union-attr" -from typing import Optional, Union +from typing import Optional, Union, Annotated from torch import Tensor from ..utils import interp2d -from .base import Source +from .base import Source, NameType from ..parametrized import unpack from ..packed import Packed @@ -51,12 +51,31 @@ class Pixelated(Source): def __init__( self, - image: Optional[Tensor] = None, - x0: Optional[Union[Tensor, float]] = None, - y0: Optional[Union[Tensor, float]] = None, - pixelscale: Optional[Union[Tensor, float]] = None, - shape: Optional[tuple[int, ...]] = None, - name: Optional[str] = None, + image: Annotated[ + Optional[Tensor], + "The source image from which brightness values will be interpolated.", + True, + ] = None, + x0: Annotated[ + Optional[Union[Tensor, float]], + "The x-coordinate of the source image's center.", + True, + ] = None, + y0: Annotated[ + Optional[Union[Tensor, float]], + "The y-coordinate of the source image's center.", + True, + ] = None, + pixelscale: Annotated[ + Optional[Union[Tensor, float]], + "The pixelscale of the source image in the lens plane", + True, + "arcsec/pixel", + ] = None, + shape: Annotated[ + Optional[tuple[int, ...]], "The shape of the source image." + ] = None, + name: NameType = None, ): """ Constructs the `Pixelated` object with the given parameters. diff --git a/src/caustics/light/sersic.py b/src/caustics/light/sersic.py index d73e8200..45fccb71 100644 --- a/src/caustics/light/sersic.py +++ b/src/caustics/light/sersic.py @@ -1,10 +1,10 @@ # mypy: disable-error-code="operator,union-attr" -from typing import Optional, Union +from typing import Optional, Union, Annotated from torch import Tensor from ..utils import to_elliptical, translate_rotate -from .base import Source +from .base import Source, NameType from ..parametrized import unpack from ..packed import Packed @@ -71,16 +71,45 @@ class Sersic(Source): def __init__( self, - x0: Optional[Union[Tensor, float]] = None, - y0: Optional[Union[Tensor, float]] = None, - q: Optional[Union[Tensor, float]] = None, - phi: Optional[Union[Tensor, float]] = None, - n: Optional[Union[Tensor, float]] = None, - Re: Optional[Union[Tensor, float]] = None, - Ie: Optional[Union[Tensor, float]] = None, - s: float = 0.0, - use_lenstronomy_k=False, - name: Optional[str] = None, + x0: Annotated[ + Optional[Union[Tensor, float]], + "The x-coordinate of the Sersic source's center", + True, + ] = None, + y0: Annotated[ + Optional[Union[Tensor, float]], + "The y-coordinate of the Sersic source's center", + True, + ] = None, + q: Annotated[ + Optional[Union[Tensor, float]], "The axis ratio of the Sersic source", True + ] = None, + phi: Annotated[ + Optional[Union[Tensor, float]], + "The orientation of the Sersic source (position angle)", + True, + ] = None, + n: Annotated[ + Optional[Union[Tensor, float]], + "The Sersic index, which describes the degree of concentration of the source", + True, + ] = None, + Re: Annotated[ + Optional[Union[Tensor, float]], + "The scale length of the Sersic source", + True, + ] = None, + Ie: Annotated[ + Optional[Union[Tensor, float]], + "The intensity at the effective radius", + True, + ] = None, + s: Annotated[float, "A small constant for numerical stability"] = 0.0, + use_lenstronomy_k: Annotated[ + bool, + "A flag indicating whether to use lenstronomy to compute the value of k.", + ] = False, + name: NameType = None, ): """ Constructs the `Sersic` object with the given parameters. diff --git a/src/caustics/models/__init__.py b/src/caustics/models/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/src/caustics/models/api.py b/src/caustics/models/api.py new file mode 100644 index 00000000..671cb07d --- /dev/null +++ b/src/caustics/models/api.py @@ -0,0 +1,35 @@ +# mypy: disable-error-code="import-untyped" +import yaml +from pathlib import Path +from typing import Union + +from ..sims.simulator import Simulator +from ..io import from_file +from .utils import setup_simulator_models, create_model, Field +from .base_models import StateConfig + + +def build_simulator(config_path: Union[str, Path]) -> Simulator: + """ + Build a simulator from the configuration + """ + simulators = setup_simulator_models() + Config = create_model( + "Config", __base__=StateConfig, simulator=(simulators, Field(...)) + ) + + # Load the yaml config + yaml_bytes = from_file(config_path) + config_dict = yaml.safe_load(yaml_bytes) + # Create config model + config = Config(**config_dict) + + # Get the simulator + sim = config.simulator.model_obj() + + # Load state if available + simulator_state = config.state + if simulator_state is not None: + sim.load_state_dict(simulator_state.load.path) + + return sim diff --git a/src/caustics/models/base_models.py b/src/caustics/models/base_models.py new file mode 100644 index 00000000..3a973b1f --- /dev/null +++ b/src/caustics/models/base_models.py @@ -0,0 +1,90 @@ +from typing import Optional, Any, Dict +from pydantic import BaseModel, Field, ConfigDict +from ..parametrized import Parametrized + + +class Parameters(BaseModel): + model_config = ConfigDict(arbitrary_types_allowed=True) + + +class InitKwargs(Parameters): + model_config = ConfigDict(arbitrary_types_allowed=True) + + +class Base(BaseModel): + name: str = Field(..., description="Name of the object") + kind: str = Field(..., description="Kind of the object") + params: Optional[Parameters] = Field(None, description="Parameters of the object") + init_kwargs: Optional[InitKwargs] = Field( + None, description="Initiation keyword arguments for object creation" + ) + + # internal + _cls: Any + + def __init__(self, **data): + super().__init__(**data) + + def _get_init_kwargs_dump(self, init_kwargs: InitKwargs) -> Dict[str, Any]: + """ + Get the model dump of the class parameters, + if the field is a model then get the model object. + + Parameters + ---------- + init_kwargs : ClassParams + The class parameters to dump + + Returns + ------- + dict + The model dump of the class parameters + """ + model_dict = {} + for f in init_kwargs.model_fields_set: + model = getattr(init_kwargs, f) + if isinstance(model, Base): + model_dict[f] = model.model_obj() + elif isinstance(model, list): + model_dict[f] = [m.model_obj() for m in model] + else: + model_dict[f] = getattr(init_kwargs, f) + return model_dict + + @classmethod + def _set_class(cls, parametrized_cls: Parametrized) -> type["Base"]: + """ + Set the class of the object. + + Parameters + ---------- + cls : Parametrized + The class to set. + """ + cls._cls = parametrized_cls + return cls + + def model_obj(self) -> Any: + if not self._cls: + raise ValueError( + "The class is not set. Please set the class before calling this method." + ) + init_kwargs = ( + self._get_init_kwargs_dump(self.init_kwargs) if self.init_kwargs else {} + ) # Capture None case + params = self.params.model_dump() if self.params else {} # Capture None case + return self._cls(name=self.name, **init_kwargs, **params) + + +class FileInput(BaseModel): + path: str = Field(..., description="The path to the file") + + +class StateDict(BaseModel): + load: FileInput + + +class StateConfig(BaseModel): + state: Optional[StateDict] = Field( + None, description="State safetensor for the simulator" + ) diff --git a/src/caustics/models/registry.py b/src/caustics/models/registry.py new file mode 100644 index 00000000..51311483 --- /dev/null +++ b/src/caustics/models/registry.py @@ -0,0 +1,126 @@ +from functools import lru_cache +from collections import ChainMap +from typing import MutableMapping, Iterator, Optional + +from caustics.parametrized import Parametrized +from caustics.utils import _import_func_or_class + + +class _KindRegistry(MutableMapping[str, "Parametrized | str"]): + cosmology = { + "FlatLambdaCDM": "caustics.cosmology.FlatLambdaCDM.FlatLambdaCDM", + } + single_lenses = { + "EPL": "caustics.lenses.epl.EPL", + "ExternalShear": "caustics.lenses.external_shear.ExternalShear", + "PixelatedConvergence": "caustics.lenses.pixelated_convergence.PixelatedConvergence", + "NFW": "caustics.lenses.nfw.NFW", + "Point": "caustics.lenses.point.Point", + "PseudoJaffe": "caustics.lenses.pseudo_jaffe.PseudoJaffe", + "SIE": "caustics.lenses.sie.SIE", + "SIS": "caustics.lenses.sis.SIS", + "TNFW": "caustics.lenses.tnfw.TNFW", + "MassSheet": "caustics.lenses.mass_sheet.MassSheet", + "SinglePlane": "caustics.lenses.singleplane.SinglePlane", + } + multi_lenses = { + "Multiplane": "caustics.lenses.multiplane.Multiplane", + } + light = { + "Pixelated": "caustics.light.pixelated.Pixelated", + "Sersic": "caustics.light.sersic.Sersic", + } + simulators = {"Lens_Source": "caustics.sims.lens_source.Lens_Source"} + + known_kinds = { + **cosmology, + **single_lenses, + **multi_lenses, + **light, + **simulators, + } + + def __init__(self) -> None: + self._m: ChainMap[str, "Parametrized | str"] = ChainMap({}, self.known_kinds) # type: ignore + + def __getitem__(self, item: str) -> Parametrized: + kind_mod: "str | Parametrized | None" = self._m.get(item, None) + if kind_mod is None: + raise KeyError(f"{item} not in registry") + if isinstance(kind_mod, str): + cls = _import_func_or_class(kind_mod) + else: + cls = kind_mod # type: ignore + return cls # type: ignore + + def __setitem__(self, item: str, value: "Parametrized | str") -> None: + if not ( + (isinstance(value, type) and issubclass(value, Parametrized)) + or isinstance(value, str) + ): + raise ValueError( + f"expected Parametrized subclass, got: {type(value).__name__!r}" + ) + self._m[item] = value + + def __delitem__(self, __v: str) -> None: + raise NotImplementedError("removal is unsupported") + + def __len__(self) -> int: + return len(set(self._m)) + + def __iter__(self) -> Iterator[str]: + return iter(set(self._m)) + + +_registry = _KindRegistry() + + +def available_kinds() -> list[str]: + """ + Return a list of classes that are available in the registry. + """ + return list(_registry) + + +def register_kind( + name: str, + cls: "Parametrized | str", + *, + clobber: bool = False, +) -> None: + """register a UPath implementation with a protocol + + Parameters + ---------- + name : str + Protocol name to associate with the class + cls : Parametrized or str + The caustics parametrized subclass or a str representing the + full path to the class like package.module.class. + clobber: + Whether to overwrite a protocol with the same name; if False, + will raise instead. + """ + if not clobber and name in _registry: + raise ValueError(f"{name!r} is already in registry and clobber is False!") + _registry[name] = cls + + +@lru_cache +def get_kind( + name: str, +) -> Optional[Parametrized]: + """Get a class from the registry by name. + + Parameters + ---------- + kind : str + The name of the kind to get. + + Returns + ------- + cls : Parametrized + The class associated with the given name. + """ + return _registry[name] diff --git a/src/caustics/models/utils.py b/src/caustics/models/utils.py new file mode 100644 index 00000000..75b49a68 --- /dev/null +++ b/src/caustics/models/utils.py @@ -0,0 +1,277 @@ +# mypy: disable-error-code="union-attr, valid-type, has-type, assignment, arg-type, dict-item, return-value, misc" +import typing +from typing import List, Literal, Dict, Annotated, Union, Any, Tuple +import inspect +from pydantic import Field, create_model, field_validator, ValidationInfo +import torch + +from ..parametrized import Parametrized +from .base_models import Base, Parameters, InitKwargs +from .registry import get_kind, _registry +from ..parametrized import ClassParam +from ..utils import _import_func_or_class, _eval_expression + +PARAMS = "params" +INIT_KWARGS = "init_kwargs" + + +def _get_kwargs_field_definitions( + parametrized_class: Parametrized, dependant_models: Dict[str, Any] = {} +) -> Dict[str, Dict[str, Any]]: + """ + Get the field definitions for the parameters and init_kwargs of a Parametrized class + + Parameters + ---------- + parametrized_class : Parametrized + The Parametrized class to get the field definitions for. + dependant_models : Dict[str, Any], optional + The dependent models to use, by default {} + See: https://docs.pydantic.dev/latest/concepts/unions/#nested-discriminated-unions + + Returns + ------- + dict + The resulting field definitions dictionary + """ + cls_signature = inspect.signature(parametrized_class) + kwargs_field_definitions: Dict[str, Dict[str, Any]] = {PARAMS: {}, INIT_KWARGS: {}} + for k, v in cls_signature.parameters.items(): + if k != "name": + anno = v.annotation + dtype = anno.__origin__ + cls_param = ClassParam(*anno.__metadata__) + if cls_param.isParam: + kwargs_field_definitions[PARAMS][k] = ( + dtype, + Field(default=v.default, description=cls_param.description), + ) + # Below is to handle cases for init kwargs + elif k in dependant_models: + dependant_model = dependant_models[k] + if isinstance(dependant_model, list): + # For the multi lens case + # dependent model is wrapped in a list + dependant_model = dependant_model[0] + kwargs_field_definitions[INIT_KWARGS][k] = ( + List[dependant_model], + Field([], description=cls_param.description), + ) + else: + kwargs_field_definitions[INIT_KWARGS][k] = ( + dependant_model, + Field(..., description=cls_param.description), + ) + elif v.default == inspect._empty: + kwargs_field_definitions[INIT_KWARGS][k] = ( + dtype, + Field(..., description=cls_param.description), + ) + else: + kwargs_field_definitions[INIT_KWARGS][k] = ( + dtype, + Field(v.default, description=cls_param.description), + ) + return kwargs_field_definitions + + +def create_pydantic_model( + cls: "Parametrized | str", dependant_models: Dict[str, type] = {} +) -> Base: + """ + Create a pydantic model from a Parametrized class. + + Parameters + ---------- + cls : Parametrized | str + The Parametrized class to create the model from. + dependant_models : Dict[str, type], optional + The dependent models to use, by default {} + See: https://docs.pydantic.dev/latest/concepts/unions/#nested-discriminated-unions + + Returns + ------- + Base + The pydantic model of the Parametrized class. + """ + if isinstance(cls, str): + parametrized_class = get_kind(cls) # type: ignore + + # Get the field definitions for parameters and init_kwargs + kwargs_field_definitions = _get_kwargs_field_definitions( + parametrized_class, dependant_models + ) + + # Create the model field definitions + field_definitions = { + "kind": (Literal[parametrized_class.__name__], Field(parametrized_class.__name__)), # type: ignore + } + + if kwargs_field_definitions[PARAMS]: + + def _param_field_tensor_check(cls, v): + """Checks the ``params`` fields input + and converts to tensor if necessary""" + if not isinstance(v, torch.Tensor): + if isinstance(v, str): + v = _eval_expression(v) + v = torch.as_tensor(v) + return v + + # Setup the pydantic models for the parameters and init_kwargs + ParamsModel = create_model( + f"{parametrized_class.__name__}_Params", + __base__=Parameters, + __validators__={ + # Convert to tensor before passing to the model for additional validation + "field_tensor_check": field_validator( + "*", mode="before", check_fields=True + )(_param_field_tensor_check) + }, + **kwargs_field_definitions[PARAMS], + ) + field_definitions["params"] = ( + ParamsModel, + Field(ParamsModel(), description="Parameters of the object"), + ) + + if kwargs_field_definitions[INIT_KWARGS]: + + def _init_kwargs_field_check(cls, v, info: ValidationInfo): + """Checks the ``init_kwargs`` fields input""" + field_name = info.field_name + field = cls.model_fields[field_name] + anno_args = typing.get_args(field.annotation) + if len(anno_args) == 2 and anno_args[1] == type(None): + # This means that the anno is optional + expected_type = next( + filter(lambda x: x is not None, typing.get_args(field.annotation)) + ) + if not isinstance(v, expected_type): + if isinstance(v, dict): + if all(k in ["func", "kwargs"] for k in v.keys()): + # Special case for the init_kwargs + # this is to allow for creating tensor with some + # caustics utils function, such as + # `caustics.utils.gaussian` + func = _import_func_or_class(v["func"]) + v = func(**v["kwargs"]) # type: ignore + else: + raise ValueError( + f"Dictionary with keys 'func' and 'kwargs' expected, got: {v.keys()}" + ) + elif expected_type == torch.Tensor: + # Try to cast to tensor if expected type is tensor + v = torch.as_tensor(v) + else: + # Try to cast to the expected type + v = expected_type(v) + return v + + InitKwargsModel = create_model( + f"{parametrized_class.__name__}_Init_Kwargs", + __base__=InitKwargs, + **kwargs_field_definitions[INIT_KWARGS], + __validators__={ + "field_check": field_validator("*", mode="before", check_fields=True)( + _init_kwargs_field_check + ) + }, + ) + field_definitions["init_kwargs"] = ( + InitKwargsModel, + Field({}, description="Initiation keyword arguments of the object"), + ) + + # Create the model + model = create_model( + parametrized_class.__name__, __base__=Base, **field_definitions + ) + # Set the imported parametrized class to the model + # this will be accessible as `model._cls` + model = model._set_class(parametrized_class) + return model + + +def setup_pydantic_models() -> Tuple[type[Annotated], type[Annotated]]: + """ + Setup the pydantic models for the light sources and lenses. + + Returns + ------- + light_sources : type[Annotated] + The annotated union of the light source pydantic models + lenses : type[Annotated] + The annotated union of the lens pydantic models + """ + # Cosmology + cosmology_models = [create_pydantic_model(cosmo) for cosmo in _registry.cosmology] + cosmology = Annotated[Union[tuple(cosmology_models)], Field(discriminator="kind")] + # Light + light_models = [create_pydantic_model(light) for light in _registry.light] + light_sources = Annotated[Union[tuple(light_models)], Field(discriminator="kind")] + # Single Lens + lens_dependant_models = {"cosmology": cosmology} + single_lens_models = [ + create_pydantic_model(lens, dependant_models=lens_dependant_models) + for lens in _registry.single_lenses + if lens != "SinglePlane" # make exception for single plane + ] + single_lenses = Annotated[ + Union[tuple(single_lens_models)], Field(discriminator="kind") + ] + # Single plane + # this is a special case since single plane + # is a multi lens system + # but this is an option for multi lens + single_plane_model = create_pydantic_model( + "SinglePlane", + dependant_models={"lenses": [single_lenses], **lens_dependant_models}, + ) + single_lenses_and_plane = Annotated[ + Union[tuple([single_plane_model, *single_lens_models])], + Field(discriminator="kind"), + ] + # Multi Lens + multi_lens_models = [ + create_pydantic_model( + lens, + dependant_models={ + "lenses": [single_lenses_and_plane], + **lens_dependant_models, + }, + ) + for lens in _registry.multi_lenses + ] + lenses = Annotated[ + Union[tuple([single_plane_model, *single_lens_models, *multi_lens_models])], + Field(discriminator="kind"), + ] + return light_sources, lenses + + +def setup_simulator_models() -> type[Annotated]: + """ + Setup the pydantic models for the simulators + + Returns + ------- + type[Annotated] + The annotated union of the simulator pydantic models + """ + light_sources, lenses = setup_pydantic_models() + # Hard code the dependants for now + # there's currently only one simulator + # in the system. + dependents = { + "Lens_Source": { + "source": light_sources, + "lens_light": light_sources, + "lens": lenses, + } + } + simulators_models = [ + create_pydantic_model(sim, dependant_models=dependents.get(sim)) + for sim in _registry.simulators + ] + return Annotated[Union[tuple(simulators_models)], Field(discriminator="kind")] diff --git a/src/caustics/parametrized.py b/src/caustics/parametrized.py index a830ca4b..e7ac56eb 100644 --- a/src/caustics/parametrized.py +++ b/src/caustics/parametrized.py @@ -2,6 +2,8 @@ from collections import OrderedDict from math import prod from typing import Optional, Union, List +from dataclasses import dataclass + import functools import itertools as it import inspect @@ -20,6 +22,13 @@ __all__ = ("Parametrized", "unpack") +@dataclass +class ClassParam: + description: str + isParam: bool = False + unit: Optional[str] = None + + def check_valid_name(name): if keyword.iskeyword(name) or not bool(re.match("^[a-zA-Z_][a-zA-Z0-9_]*$", name)): raise NameError( diff --git a/src/caustics/sims/lens_source.py b/src/caustics/sims/lens_source.py index b9cafb88..2844ea53 100644 --- a/src/caustics/sims/lens_source.py +++ b/src/caustics/sims/lens_source.py @@ -2,15 +2,18 @@ from scipy.fft import next_fast_len from torch.nn.functional import avg_pool2d, conv2d -from typing import Optional +from typing import Optional, Annotated, Literal, Union import torch +from torch import Tensor -from .simulator import Simulator +from .simulator import Simulator, NameType from ..utils import ( get_meshgrid, gaussian_quadrature_grid, gaussian_quadrature_integrator, ) +from ..lenses.base import Lens +from ..light.base import Source __all__ = ("Lens_Source",) @@ -42,17 +45,17 @@ class Lens_Source(Simulator): Attributes ---------- - lens + lens: Lens caustics lens mass model object - source + source: Source caustics light object which defines the background source pixelscale: float pixelscale of the sampling grid. pixels_x: int number of pixels on the x-axis for the sampling grid - lens_light: (optional) + lens_light: Source, optional caustics light object which defines the lensing object's light - psf: (optional) + psf: Tensor, optional An image to convolve with the scene. Note that if ``upsample_factor > 1`` the psf must also be at the higher resolution. pixels_y: Optional[int] number of pixels on the y-axis for the sampling grid. If left as ``None`` then this will simply be equal to ``gridx`` @@ -76,18 +79,31 @@ class Lens_Source(Simulator): def __init__( self, - lens, - source, - pixelscale: float, - pixels_x: int, - lens_light=None, - psf=None, - pixels_y: Optional[int] = None, - upsample_factor: int = 1, - psf_pad=True, - psf_mode="fft", - z_s=None, - name: str = "sim", + lens: Annotated[Lens, "caustics lens mass model object"], + source: Annotated[ + Source, "caustics light object which defines the background source" + ], + pixelscale: Annotated[float, "pixelscale of the sampling grid"], + pixels_x: Annotated[ + int, "number of pixels on the x-axis for the sampling grid" + ], + lens_light: Annotated[ + Optional[Source], + "caustics light object which defines the lensing object's light", + ] = None, + psf: Annotated[Optional[Tensor], "An image to convolve with the scene"] = None, + pixels_y: Annotated[ + Optional[int], "number of pixels on the y-axis for the sampling grid" + ] = None, + upsample_factor: Annotated[int, "Amount of upsampling to model the image"] = 1, + psf_pad: Annotated[bool, "Flag to apply padding to psf"] = True, + psf_mode: Annotated[ + Literal["fft", "conv2d"], "Mode for convolving psf" + ] = "fft", + z_s: Annotated[ + Optional[Union[Tensor, float]], "Redshift of the source", True + ] = None, + name: NameType = "sim", ): super().__init__(name) diff --git a/src/caustics/sims/simulator.py b/src/caustics/sims/simulator.py index b12e1457..0dfcc258 100644 --- a/src/caustics/sims/simulator.py +++ b/src/caustics/sims/simulator.py @@ -1,4 +1,4 @@ -from typing import Dict +from typing import Dict, Annotated, Optional from torch import Tensor from ..parametrized import Parametrized @@ -7,6 +7,8 @@ __all__ = ("Simulator",) +NameType = Annotated[Optional[str], "Name of the simulator"] + class Simulator(Parametrized): """A caustics simulator using Parametrized framework. diff --git a/src/caustics/utils.py b/src/caustics/utils.py index 2ade0c59..a1bf7fb5 100644 --- a/src/caustics/utils.py +++ b/src/caustics/utils.py @@ -1,6 +1,7 @@ # mypy: disable-error-code="misc" from math import pi -from typing import Callable, Optional, Tuple, Union +from typing import Callable, Optional, Tuple, Union, Any +from importlib import import_module from functools import partial, lru_cache import torch @@ -10,6 +11,58 @@ from scipy.special import roots_legendre +def _import_func_or_class(module_path: str) -> Any: + """ + Import a function or class from a module path + + Parameters + ---------- + module_path : str + The module path to import from + + Returns + ------- + Callable + The imported function or class + """ + module_name, name = module_path.rsplit(".", 1) + mod = import_module(module_name) + return getattr(mod, name) # type: ignore + + +def _eval_expression(input_string: str) -> Union[int, float]: + """ + Evaluates a string expression to create an integer or float + + Parameters + ---------- + input_string : str + The string expression to evaluate + + Returns + ------- + Union[int, float] + The result of the evaluation + + Raises + ------ + NameError + If a disallowed constant is used + """ + # Allowed modules to use string evaluation + allowed_names = {"pi": pi} + # Compile the input string + code = compile(input_string, "", "eval") + # Check for disallowed names + for name in code.co_names: + if name not in allowed_names: + # Throw an error if a disallowed name is used + raise NameError(f"Use of {name} not allowed") + # Evaluate the input string without using builtins + # for security + return eval(code, {"__builtins__": {}}, allowed_names) + + def flip_axis_ratio(q, phi): """ Makes the value of 'q' positive, then swaps x and y axes if 'q' is larger than 1. diff --git a/tests/conftest.py b/tests/conftest.py index 75882806..5b1267fe 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -2,12 +2,42 @@ import os import torch import pytest +import typing # Add the helpers directory to the path so we can import the helpers sys.path.append(os.path.join(os.path.dirname(__file__), "utils")) +from caustics.models.utils import setup_pydantic_models + CUDA_AVAILABLE = torch.cuda.is_available() +LIGHT_ANNOTATED, LENSES_ANNOTATED = setup_pydantic_models() + + +def _get_models(annotated): + typehint = typing.get_args(annotated)[0] + pydantic_models = typing.get_args(typehint) + if isinstance(pydantic_models, tuple): + pydantic_models = {m.__name__: m for m in pydantic_models} + else: + pydantic_models = {pydantic_models.__name__: pydantic_models} + return pydantic_models + + +@pytest.fixture +def light_models(): + return _get_models(LIGHT_ANNOTATED) + + +@pytest.fixture +def lens_models(): + return _get_models(LENSES_ANNOTATED) + + +@pytest.fixture(params=["yaml", "no_yaml"]) +def sim_source(request): + return request.param + @pytest.fixture( params=[ diff --git a/tests/models/test_mod_api.py b/tests/models/test_mod_api.py new file mode 100644 index 00000000..a045dabf --- /dev/null +++ b/tests/models/test_mod_api.py @@ -0,0 +1,228 @@ +from tempfile import NamedTemporaryFile +import os +import yaml + +import pytest +import torch +from pydantic import create_model + +import caustics +from caustics.models.utils import setup_simulator_models +from caustics.models.base_models import StateConfig, Field +from utils.models import setup_complex_multiplane_yaml +import textwrap + + +@pytest.fixture +def ConfigModel(): + simulators = setup_simulator_models() + return create_model( + "Config", __base__=StateConfig, simulator=(simulators, Field(...)) + ) + + +@pytest.fixture +def x_input(): + return torch.tensor([ + # z_s z_l x0 y0 q phi b x0 y0 q phi n Re + 1.5, 0.5, -0.2, 0.0, 0.4, 1.5708, 1.7, 0.0, 0.0, 0.5, -0.985, 1.3, 1.0, + # Ie x0 y0 q phi n Re Ie + 5.0, -0.2, 0.0, 0.8, 0.0, 1., 1.0, 10.0 + ]) # fmt: skip + + +@pytest.fixture +def sim_yaml(): + return textwrap.dedent( + """\ + cosmology: &cosmo + name: cosmo + kind: FlatLambdaCDM + + lens: &lens + name: lens + kind: SIE + init_kwargs: + cosmology: *cosmo + + src: &src + name: source + kind: Sersic + + lnslt: &lnslt + name: lenslight + kind: Sersic + + simulator: + name: minisim + kind: Lens_Source + init_kwargs: + # Single lense + lens: *lens + source: *src + lens_light: *lnslt + pixelscale: 0.05 + pixels_x: 100 + """ + ) + + +def _write_temp_yaml(yaml_str: str): + # Create temp file + f = NamedTemporaryFile("w", delete=False) + f.write(yaml_str) + f.flush() + f.close() + + return f.name + + +@pytest.fixture +def sim_yaml_file(sim_yaml): + temp_file = _write_temp_yaml(sim_yaml) + + yield temp_file + + if os.path.exists(temp_file): + os.unlink(temp_file) + + +@pytest.fixture +def simple_config_dict(sim_yaml): + return yaml.safe_load(sim_yaml) + + +@pytest.fixture +def sim_obj(): + cosmology = caustics.FlatLambdaCDM() + sie = caustics.SIE(cosmology=cosmology, name="lens") + src = caustics.Sersic(name="source") + lnslt = caustics.Sersic(name="lenslight") + return caustics.Lens_Source( + lens=sie, source=src, lens_light=lnslt, pixelscale=0.05, pixels_x=100 + ) + + +def test_build_simulator(sim_yaml_file, sim_obj, x_input): + sim = caustics.build_simulator(sim_yaml_file) + + result = sim(x_input, quad_level=3) + expected_result = sim_obj(x_input, quad_level=3) + assert sim.get_graph(True, True) + assert isinstance(result, torch.Tensor) + assert torch.allclose(result, expected_result) + + +def test_complex_build_simulator(): + yaml_str = setup_complex_multiplane_yaml() + x = torch.tensor( + [ + # z_s x0 y0 q phi n Re + 1.5, + 0.0, + 0.0, + 0.5, + -0.985, + 1.3, + 1.0, + # Ie x0 y0 q phi n Re Ie + 5.0, + -0.2, + 0.0, + 0.8, + 0.0, + 1.0, + 1.0, + 10.0, + ] + ) + # Create temp file + temp_file = _write_temp_yaml(yaml_str) + + # Open the temp file and build the simulator + sim = caustics.build_simulator(temp_file) + image = sim(x, quad_level=3) + assert isinstance(image, torch.Tensor) + + # Remove the temp file + if os.path.exists(temp_file): + os.unlink(temp_file) + + +def test_build_simulator_w_state(sim_yaml_file, sim_obj, x_input): + sim = caustics.build_simulator(sim_yaml_file) + params = dict(zip(sim.x_order, x_input)) + + # Set the parameters from x input + # using set attribute to the module objects + # this makes the the params to be static + for k, v in params.items(): + n, p = k.split(".") + if n == sim.name: + setattr(sim, p, v) + continue + key = sim._module_key_map[n] + mod = getattr(sim, key) + setattr(mod, p, v) + + state_dict = sim.state_dict() + + # Save the state + state_path = None + with NamedTemporaryFile("wb", suffix=".st", delete=False) as f: + state_path = f.name + state_dict.save(state_path) + + # Add the path to state to the sim yaml + with open(sim_yaml_file, "a") as f: + f.write( + textwrap.dedent( + f""" + state: + load: + path: {state_path} + """ + ) + ) + + # Load the state + # First remove the original sim + del sim + newsim = caustics.build_simulator(sim_yaml_file) + result = newsim(quad_level=3) + expected_result = sim_obj(x_input, quad_level=3) + assert newsim.get_graph(True, True) + assert isinstance(result, torch.Tensor) + assert torch.allclose(result, expected_result) + + +@pytest.mark.parametrize( + "psf", + [ + { + "func": "caustics.utils.gaussian", + "kwargs": { + "pixelscale": 0.05, + "nx": 11, + "ny": 12, + "sigma": 0.2, + "upsample": 2, + }, + }, + {"function": "caustics.utils.gaussian", "sigma": 0.2}, + [[2.0], [2.0]], + ], +) +@pytest.mark.parametrize("pixels_y", ["50", 50.3]) # will get casted to int +def test_init_kwargs_validate(ConfigModel, simple_config_dict, psf, pixels_y): + # Add psf + test_config_dict = {**simple_config_dict} + test_config_dict["simulator"]["init_kwargs"]["psf"] = psf + test_config_dict["simulator"]["init_kwargs"]["pixels_y"] = pixels_y + if isinstance(psf, dict) and "func" not in psf: + with pytest.raises(ValueError): + ConfigModel(**test_config_dict) + else: + # Test that the init_kwargs are validated + config = ConfigModel(**test_config_dict) + assert config.simulator.model_obj() diff --git a/tests/models/test_mod_registry.py b/tests/models/test_mod_registry.py new file mode 100644 index 00000000..2e10c0a0 --- /dev/null +++ b/tests/models/test_mod_registry.py @@ -0,0 +1,97 @@ +import pytest + +import caustics +from caustics.models.registry import ( + _KindRegistry, + available_kinds, + register_kind, + get_kind, + _registry, +) +from caustics.parameter import Parameter +from caustics.parametrized import Parametrized + + +class TestKindRegistry: + expected_attrs = [ + "cosmology", + "single_lenses", + "multi_lenses", + "light", + "simulators", + "known_kinds", + "_m", + ] + + def test_constructor(self): + registry = _KindRegistry() + + for attr in self.expected_attrs: + assert hasattr(registry, attr) + + @pytest.mark.parametrize("kind", ["NonExistingClass", "SIE", caustics.Sersic]) + def test_getitem(self, kind, mocker): + registry = _KindRegistry() + + if kind == "NonExistingClass": + with pytest.raises(KeyError): + registry[kind] + elif isinstance(kind, str): + cls = registry[kind] + assert cls == getattr(caustics, kind) + else: + test_key = "TestSersic" + registry.known_kinds[test_key] = kind + cls = registry[test_key] + assert cls == kind + + @pytest.mark.parametrize("kind", [Parameter, caustics.Sersic, "caustics.SIE"]) + def test_setitem(self, kind): + registry = _KindRegistry() + key = "TestSersic" + if isinstance(kind, str): + registry[key] = kind + assert key in registry._m + elif issubclass(kind, Parametrized): + registry[key] = kind + assert registry[key] == kind + else: + with pytest.raises(ValueError): + registry[key] = kind + + def test_delitem(self): + registry = _KindRegistry() + with pytest.raises(NotImplementedError): + del registry["Sersic"] + + def test_len(self): + registry = _KindRegistry() + assert len(registry) == len(set(registry._m)) + + def test_iter(self): + registry = _KindRegistry() + assert set(registry) == set(registry._m) + + +def test_available_kinds(): + assert available_kinds() == list(_registry) + + +def test_register_kind(): + key = "TestSersic2" + value = caustics.Sersic + register_kind(key, value) + assert key in _registry._m + assert _registry[key] == value + + with pytest.raises(ValueError): + register_kind("SIE", "caustics.SIE") + + +def test_get_kind(): + kind = "Sersic" + cls = get_kind(kind) + assert cls == caustics.Sersic + kind = "NonExistingClass" + with pytest.raises(KeyError): + cls = get_kind(kind) diff --git a/tests/models/test_mod_utils.py b/tests/models/test_mod_utils.py new file mode 100644 index 00000000..edd88a7d --- /dev/null +++ b/tests/models/test_mod_utils.py @@ -0,0 +1,118 @@ +import pytest +import inspect +import typing +from typing import Annotated, Dict +from caustics.models.registry import _registry, get_kind +from caustics.models.utils import ( + create_pydantic_model, + setup_pydantic_models, + setup_simulator_models, + _get_kwargs_field_definitions, + PARAMS, + INIT_KWARGS, +) +from caustics.models.base_models import Base +from caustics.parametrized import ClassParam + + +@pytest.fixture(params=_registry.known_kinds) +def kind(request): + return request.param + + +@pytest.fixture +def parametrized_class(kind): + return get_kind(kind) + + +def test_create_pydantic_model(kind): + model = create_pydantic_model(kind) + kind_cls = get_kind(kind) + expected_fields = {"kind", "name", "params", "init_kwargs"} + + assert model.__base__ == Base + assert model.__name__ == kind + assert model._cls == kind_cls + assert set(model.model_fields.keys()) == expected_fields + + +def test__get_kwargs_field_definitions(parametrized_class): + kwargs_fd = _get_kwargs_field_definitions(parametrized_class) + + cls_signature = inspect.signature(parametrized_class) + class_metadata = { + k: { + "dtype": v.annotation.__origin__, + "default": v.default, + "class_param": ClassParam(*v.annotation.__metadata__), + } + for k, v in cls_signature.parameters.items() + } + + for k, v in class_metadata.items(): + if k != "name": + if v["class_param"].isParam: + assert k in kwargs_fd[PARAMS] + assert isinstance(kwargs_fd[PARAMS][k], tuple) + assert kwargs_fd[PARAMS][k][0] == v["dtype"] + field_info = kwargs_fd[PARAMS][k][1] + else: + assert k in kwargs_fd[INIT_KWARGS] + assert isinstance(kwargs_fd[INIT_KWARGS][k], tuple) + assert kwargs_fd[INIT_KWARGS][k][0] == v["dtype"] + field_info = kwargs_fd[INIT_KWARGS][k][1] + + if v["default"] == inspect._empty: + # Skip empty defaults + continue + assert field_info.default == v["default"] + + +def _check_nested_discriminated_union( + input_anno: type[Annotated], class_paths: Dict[str, str] +): + # Check to see if the model selection is Annotated type + assert typing.get_origin(input_anno) == Annotated + # Check to see if the discriminator is "kind" + assert input_anno.__metadata__[0].discriminator == "kind" + + if typing.get_origin(input_anno.__origin__) == typing.Union: + models = input_anno.__origin__.__args__ + else: + # For single models + models = [input_anno.__origin__] + + # Check to see if the models are in the registry + assert len(models) == len(class_paths) + # Go through each model and check that it's pointing to the right class + for model in models: + assert model.__name__ in class_paths + assert model._cls == get_kind(model.__name__) + + +def test_setup_pydantic_models(): + # light, lenses + pydantic_models_annotated = setup_pydantic_models() + + registry_dict = { + "light": _registry.light, + "lenses": { + **_registry.single_lenses, + **_registry.multi_lenses, + }, + } + + pm_anno_dict = { + k: v for (k, v) in zip(list(registry_dict.keys()), pydantic_models_annotated) + } + + for key, pydantic_model_anno in pm_anno_dict.items(): + class_paths = registry_dict[key] + _check_nested_discriminated_union(pydantic_model_anno, class_paths) + + +def test_setup_simulator_models(): + simulators = setup_simulator_models() + + class_paths = _registry.simulators + _check_nested_discriminated_union(simulators, class_paths) diff --git a/tests/test_epl.py b/tests/test_epl.py index 5a099fc6..f0d79491 100644 --- a/tests/test_epl.py +++ b/tests/test_epl.py @@ -1,4 +1,5 @@ from math import pi +import yaml import lenstronomy.Util.param_util as param_util import torch @@ -9,10 +10,25 @@ from caustics.lenses import EPL -def test_lenstronomy(device): - # Models - cosmology = FlatLambdaCDM(name="cosmo") - lens = EPL(name="epl", cosmology=cosmology) +def test_lenstronomy(sim_source, device, lens_models): + if sim_source == "yaml": + yaml_str = """\ + cosmology: &cosmology + name: cosmo + kind: FlatLambdaCDM + lens: &lens + name: epl + kind: EPL + init_kwargs: + cosmology: *cosmology + """ + yaml_dict = yaml.safe_load(yaml_str.encode("utf-8")) + mod = lens_models.get("EPL") + lens = mod(**yaml_dict["lens"]).model_obj() + else: + # Models + cosmology = FlatLambdaCDM(name="cosmo") + lens = EPL(name="epl", cosmology=cosmology) lens = lens.to(device=device) # There is also an EPL_NUMBA class lenstronomy, but it shouldn't matter much lens_model_list = ["EPL"] diff --git a/tests/test_external_shear.py b/tests/test_external_shear.py index cec365f9..8af65ee4 100644 --- a/tests/test_external_shear.py +++ b/tests/test_external_shear.py @@ -1,4 +1,5 @@ import torch +import yaml from lenstronomy.LensModel.lens_model import LensModel from utils import lens_test_helper @@ -6,13 +7,28 @@ from caustics.lenses import ExternalShear -def test(device): +def test(sim_source, device, lens_models): atol = 1e-5 rtol = 1e-5 - # Models - cosmology = FlatLambdaCDM(name="cosmo") - lens = ExternalShear(name="shear", cosmology=cosmology) + if sim_source == "yaml": + yaml_str = """\ + cosmology: &cosmology + name: cosmo + kind: FlatLambdaCDM + lens: &lens + name: shear + kind: ExternalShear + init_kwargs: + cosmology: *cosmology + """ + yaml_dict = yaml.safe_load(yaml_str.encode("utf-8")) + mod = lens_models.get("ExternalShear") + lens = mod(**yaml_dict["lens"]).model_obj() + else: + # Models + cosmology = FlatLambdaCDM(name="cosmo") + lens = ExternalShear(name="shear", cosmology=cosmology) lens.to(device=device) lens_model_list = ["SHEAR"] lens_ls = LensModel(lens_model_list=lens_model_list) diff --git a/tests/test_masssheet.py b/tests/test_masssheet.py index 81750ea9..051ddc52 100644 --- a/tests/test_masssheet.py +++ b/tests/test_masssheet.py @@ -1,14 +1,30 @@ import torch +import yaml from caustics.cosmology import FlatLambdaCDM from caustics.lenses import MassSheet from caustics.utils import get_meshgrid -def test(device): - # Models - cosmology = FlatLambdaCDM(name="cosmo") - lens = MassSheet(name="sheet", cosmology=cosmology) +def test(sim_source, device, lens_models): + if sim_source == "yaml": + yaml_str = """\ + cosmology: &cosmology + name: cosmo + kind: FlatLambdaCDM + lens: &lens + name: sheet + kind: MassSheet + init_kwargs: + cosmology: *cosmology + """ + yaml_dict = yaml.safe_load(yaml_str.encode("utf-8")) + mod = lens_models.get("MassSheet") + lens = mod(**yaml_dict["lens"]).model_obj() + else: + # Models + cosmology = FlatLambdaCDM(name="cosmo") + lens = MassSheet(name="sheet", cosmology=cosmology) lens.to(device=device) diff --git a/tests/test_multiplane.py b/tests/test_multiplane.py index 25c4415f..5aa13898 100644 --- a/tests/test_multiplane.py +++ b/tests/test_multiplane.py @@ -1,4 +1,5 @@ from math import pi +import yaml import lenstronomy.Util.param_util as param_util import torch @@ -12,14 +13,12 @@ from caustics.utils import get_meshgrid -def test(device): +def test(sim_source, device, lens_models): rtol = 0 atol = 5e-3 # Setup z_s = torch.tensor(1.5, dtype=torch.float32) - cosmology = FlatLambdaCDM(name="cosmo") - cosmology.to(dtype=torch.float32, device=device) # Parameters xs = [ @@ -29,11 +28,50 @@ def test(device): ] x = torch.tensor([p for _xs in xs for p in _xs], dtype=torch.float32, device=device) - lens = Multiplane( - name="multiplane", - cosmology=cosmology, - lenses=[SIE(name=f"sie_{i}", cosmology=cosmology) for i in range(len(xs))], - ) + if sim_source == "yaml": + yaml_str = """\ + cosmology: &cosmology + name: cosmo + kind: FlatLambdaCDM + sie1: &sie1 + name: sie_1 + kind: SIE + init_kwargs: + cosmology: *cosmology + sie2: &sie2 + name: sie_2 + kind: SIE + init_kwargs: + cosmology: *cosmology + sie3: &sie3 + name: sie_3 + kind: SIE + init_kwargs: + cosmology: *cosmology + + lens: &lens + name: multiplane + kind: Multiplane + init_kwargs: + cosmology: *cosmology + lenses: + - *sie1 + - *sie2 + - *sie3 + """ + yaml_dict = yaml.safe_load(yaml_str.encode("utf-8")) + mod = lens_models.get("Multiplane") + lens = mod(**yaml_dict["lens"]).model_obj() + lens.to(dtype=torch.float32, device=device) + cosmology = lens.cosmology + else: + cosmology = FlatLambdaCDM(name="cosmo") + cosmology.to(dtype=torch.float32, device=device) + lens = Multiplane( + name="multiplane", + cosmology=cosmology, + lenses=[SIE(name=f"sie_{i}", cosmology=cosmology) for i in range(len(xs))], + ) # lenstronomy kwargs_ls = [] diff --git a/tests/test_nfw.py b/tests/test_nfw.py index 6dbf33dc..a9c30191 100644 --- a/tests/test_nfw.py +++ b/tests/test_nfw.py @@ -2,6 +2,7 @@ # import lenstronomy.Util.param_util as param_util import torch +import yaml from astropy.cosmology import FlatLambdaCDM as FlatLambdaCDM_AP from astropy.cosmology import default_cosmology @@ -18,14 +19,31 @@ Ob0_default = float(default_cosmology.get().Ob0) -def test(device): +def test(sim_source, device, lens_models): atol = 1e-5 rtol = 3e-2 - - # Models - cosmology = CausticFlatLambdaCDM(name="cosmo") z_l = torch.tensor(0.1) - lens = NFW(name="nfw", cosmology=cosmology, z_l=z_l) + + if sim_source == "yaml": + yaml_str = f"""\ + cosmology: &cosmology + name: cosmo + kind: FlatLambdaCDM + lens: &lens + name: nfw + kind: NFW + params: + z_l: {float(z_l)} + init_kwargs: + cosmology: *cosmology + """ + yaml_dict = yaml.safe_load(yaml_str.encode("utf-8")) + mod = lens_models.get("NFW") + lens = mod(**yaml_dict["lens"]).model_obj() + else: + # Models + cosmology = CausticFlatLambdaCDM(name="cosmo") + lens = NFW(name="nfw", cosmology=cosmology, z_l=z_l) lens_model_list = ["NFW"] lens_ls = LensModel(lens_model_list=lens_model_list) @@ -53,10 +71,29 @@ def test(device): lens_test_helper(lens, lens_ls, z_s, x, kwargs_ls, atol, rtol, device=device) -def test_runs(device): - cosmology = CausticFlatLambdaCDM(name="cosmo") +def test_runs(sim_source, device, lens_models): z_l = torch.tensor(0.1) - lens = NFW(name="nfw", cosmology=cosmology, z_l=z_l, use_case="differentiable") + if sim_source == "yaml": + yaml_str = f"""\ + cosmology: &cosmology + name: cosmo + kind: FlatLambdaCDM + lens: &lens + name: nfw + kind: NFW + params: + z_l: {float(z_l)} + init_kwargs: + cosmology: *cosmology + use_case: differentiable + """ + yaml_dict = yaml.safe_load(yaml_str.encode("utf-8")) + mod = lens_models.get("NFW") + lens = mod(**yaml_dict["lens"]).model_obj() + else: + # Models + cosmology = CausticFlatLambdaCDM(name="cosmo") + lens = NFW(name="nfw", cosmology=cosmology, z_l=z_l, use_case="differentiable") lens.to(device=device) # Parameters z_s = torch.tensor(0.5) diff --git a/tests/test_point.py b/tests/test_point.py index 3513ac6d..2584c944 100644 --- a/tests/test_point.py +++ b/tests/test_point.py @@ -1,4 +1,5 @@ import torch +import yaml from lenstronomy.LensModel.lens_model import LensModel from utils import lens_test_helper @@ -6,13 +7,31 @@ from caustics.lenses import Point -def test(device): +def test(sim_source, device, lens_models): atol = 1e-5 rtol = 1e-5 - - # Models - cosmology = FlatLambdaCDM(name="cosmo") - lens = Point(name="point", cosmology=cosmology, z_l=torch.tensor(0.9)) + z_l = torch.tensor(0.9) + + if sim_source == "yaml": + yaml_str = f"""\ + cosmology: &cosmology + name: cosmo + kind: FlatLambdaCDM + lens: &lens + name: point + kind: Point + params: + z_l: {float(z_l)} + init_kwargs: + cosmology: *cosmology + """ + yaml_dict = yaml.safe_load(yaml_str.encode("utf-8")) + mod = lens_models.get("Point") + lens = mod(**yaml_dict["lens"]).model_obj() + else: + # Models + cosmology = FlatLambdaCDM(name="cosmo") + lens = Point(name="point", cosmology=cosmology, z_l=z_l) lens_model_list = ["POINT_MASS"] lens_ls = LensModel(lens_model_list=lens_model_list) diff --git a/tests/test_pseudo_jaffe.py b/tests/test_pseudo_jaffe.py index b3b6c328..84d75fe7 100644 --- a/tests/test_pseudo_jaffe.py +++ b/tests/test_pseudo_jaffe.py @@ -1,4 +1,5 @@ import torch +import yaml from lenstronomy.LensModel.lens_model import LensModel from utils import lens_test_helper @@ -6,13 +7,29 @@ from caustics.lenses import PseudoJaffe -def test(device): +def test(sim_source, device, lens_models): atol = 1e-5 rtol = 1e-5 - # Models - cosmology = FlatLambdaCDM(name="cosmo") - lens = PseudoJaffe(name="pj", cosmology=cosmology) + if sim_source == "yaml": + yaml_str = """\ + cosmology: &cosmology + name: cosmo + kind: FlatLambdaCDM + lens: + name: pj + kind: PseudoJaffe + init_kwargs: + cosmology: *cosmology + """ + yaml_dict = yaml.safe_load(yaml_str.encode("utf-8")) + mod = lens_models.get("PseudoJaffe") + lens = mod(**yaml_dict["lens"]).model_obj() + cosmology = lens.cosmology + else: + # Models + cosmology = FlatLambdaCDM(name="cosmo") + lens = PseudoJaffe(name="pj", cosmology=cosmology) lens_model_list = ["PJAFFE"] lens_ls = LensModel(lens_model_list=lens_model_list) diff --git a/tests/test_sersic.py b/tests/test_sersic.py index f3a1c74b..b024df9c 100644 --- a/tests/test_sersic.py +++ b/tests/test_sersic.py @@ -1,5 +1,6 @@ import lenstronomy.Util.param_util as param_util import numpy as np +import yaml import torch from lenstronomy.Data.pixel_grid import PixelGrid from lenstronomy.LightModel.light_model import LightModel @@ -8,13 +9,26 @@ from caustics.utils import get_meshgrid -def test(device): +def test(sim_source, device, light_models): # Caustics setup res = 0.05 nx = 200 ny = 200 thx, thy = get_meshgrid(res, nx, ny, device=device) - sersic = Sersic(name="sersic", use_lenstronomy_k=True) + + if sim_source == "yaml": + yaml_str = """\ + light: + name: sersic + kind: Sersic + init_kwargs: + use_lenstronomy_k: true + """ + yaml_dict = yaml.safe_load(yaml_str.encode("utf-8")) + mod = light_models.get("Sersic") + sersic = mod(**yaml_dict["light"]).model_obj() + else: + sersic = Sersic(name="sersic", use_lenstronomy_k=True) sersic.to(device=device) # Lenstronomy setup ra_at_xy_0, dec_at_xy_0 = (-5 + res / 2, -5 + res / 2) diff --git a/tests/test_sie.py b/tests/test_sie.py index 6afd894e..91748496 100644 --- a/tests/test_sie.py +++ b/tests/test_sie.py @@ -1,4 +1,5 @@ from math import pi +import yaml import lenstronomy.Util.param_util as param_util import torch @@ -10,13 +11,28 @@ from caustics.utils import get_meshgrid -def test(device): +def test(sim_source, device, lens_models): atol = 1e-5 rtol = 1e-5 - # Models - cosmology = FlatLambdaCDM(name="cosmo") - lens = SIE(name="sie", cosmology=cosmology) + if sim_source == "yaml": + yaml_str = """\ + cosmology: &cosmology + name: cosmo + kind: FlatLambdaCDM + lens: &lens + name: sie + kind: SIE + init_kwargs: + cosmology: *cosmology + """ + yaml_dict = yaml.safe_load(yaml_str.encode("utf-8")) + mod = lens_models.get("SIE") + lens = mod(**yaml_dict["lens"]).model_obj() + else: + # Models + cosmology = FlatLambdaCDM(name="cosmo") + lens = SIE(name="sie", cosmology=cosmology) lens_model_list = ["SIE"] lens_ls = LensModel(lens_model_list=lens_model_list) diff --git a/tests/test_simulator_runs.py b/tests/test_simulator_runs.py index c0915254..91569067 100644 --- a/tests/test_simulator_runs.py +++ b/tests/test_simulator_runs.py @@ -7,40 +7,114 @@ from caustics.lenses import SIE from caustics.light import Sersic from caustics.utils import gaussian +from caustics import build_simulator +from utils import mock_from_file -def test_simulator_runs(device): - # Model - cosmology = FlatLambdaCDM(name="cosmo") - lensmass = SIE( - name="lens", - cosmology=cosmology, - z_l=1.0, - x0=0.0, - y0=0.01, - q=0.5, - phi=pi / 3.0, - b=1.0, - ) - source = Sersic( - name="source", x0=0.01, y0=-0.03, q=0.6, phi=-pi / 4, n=2.0, Re=0.5, Ie=1.0 - ) - lenslight = Sersic( - name="lenslight", x0=0.0, y0=0.01, q=0.7, phi=pi / 4, n=3.0, Re=0.7, Ie=1.0 - ) +def test_simulator_runs(sim_source, device, mocker): + if sim_source == "yaml": + yaml_str = """\ + cosmology: &cosmology + name: "cosmo" + kind: FlatLambdaCDM - psf = gaussian(0.05, 11, 11, 0.2, upsample=2) + lensmass: &lensmass + name: lens + kind: SIE + params: + z_l: 1.0 + x0: 0.0 + y0: 0.01 + q: 0.5 + phi: pi / 3.0 + b: 1.0 + init_kwargs: + cosmology: *cosmology + + source: &source + name: source + kind: Sersic + params: + x0: 0.01 + y0: -0.03 + q: 0.6 + phi: -pi / 4 + n: 2.0 + Re: 0.5 + Ie: 1.0 + + lenslight: &lenslight + name: lenslight + kind: Sersic + params: + x0: 0.0 + y0: 0.01 + q: 0.7 + phi: pi / 4 + n: 3.0 + Re: 0.7 + Ie: 1.0 + + psf: &psf + func: caustics.utils.gaussian + kwargs: + pixelscale: 0.05 + nx: 11 + ny: 12 + sigma: 0.2 + upsample: 2 + + simulator: + name: simulator + kind: Lens_Source + params: + z_s: 2.0 + init_kwargs: + # Single lense + lens: *lensmass + source: *source + lens_light: *lenslight + pixelscale: 0.05 + pixels_x: 50 + psf: *psf + """ + mock_from_file(mocker, yaml_str) + sim = build_simulator("/path/to/sim.yaml") # Path doesn't actually exists + else: + # Model + cosmology = FlatLambdaCDM(name="cosmo") + lensmass = SIE( + name="lens", + cosmology=cosmology, + z_l=1.0, + x0=0.0, + y0=0.01, + q=0.5, + phi=pi / 3.0, + b=1.0, + ) + + source = Sersic( + name="source", x0=0.01, y0=-0.03, q=0.6, phi=-pi / 4, n=2.0, Re=0.5, Ie=1.0 + ) + lenslight = Sersic( + name="lenslight", x0=0.0, y0=0.01, q=0.7, phi=pi / 4, n=3.0, Re=0.7, Ie=1.0 + ) + + psf = gaussian(0.05, 11, 11, 0.2, upsample=2) + + sim = Lens_Source( + name="simulator", + lens=lensmass, + source=source, + pixelscale=0.05, + pixels_x=50, + lens_light=lenslight, + psf=psf, + z_s=2.0, + ) - sim = Lens_Source( - lens=lensmass, - source=source, - pixelscale=0.05, - pixels_x=50, - lens_light=lenslight, - psf=psf, - z_s=2.0, - ) sim.to(device=device) assert torch.all(torch.isfinite(sim())) diff --git a/tests/test_sis.py b/tests/test_sis.py index b240deb0..a4594e55 100644 --- a/tests/test_sis.py +++ b/tests/test_sis.py @@ -1,18 +1,37 @@ import torch from lenstronomy.LensModel.lens_model import LensModel from utils import lens_test_helper +import yaml from caustics.cosmology import FlatLambdaCDM from caustics.lenses import SIS -def test(device): +def test(sim_source, device, lens_models): atol = 1e-5 rtol = 1e-5 - - # Models - cosmology = FlatLambdaCDM(name="cosmo") - lens = SIS(name="sis", cosmology=cosmology, z_l=torch.tensor(0.5)) + z_l = torch.tensor(0.5) + + if sim_source == "yaml": + yaml_str = f"""\ + cosmology: &cosmology + name: cosmo + kind: FlatLambdaCDM + lens: &lens + name: sis + kind: SIS + params: + z_l: {float(z_l)} + init_kwargs: + cosmology: *cosmology + """ + yaml_dict = yaml.safe_load(yaml_str.encode("utf-8")) + mod = lens_models.get("SIS") + lens = mod(**yaml_dict["lens"]).model_obj() + else: + # Models + cosmology = FlatLambdaCDM(name="cosmo") + lens = SIS(name="sis", cosmology=cosmology, z_l=z_l) lens_model_list = ["SIS"] lens_ls = LensModel(lens_model_list=lens_model_list) diff --git a/tests/test_tnfw.py b/tests/test_tnfw.py index 8b2e314e..33a2b5eb 100644 --- a/tests/test_tnfw.py +++ b/tests/test_tnfw.py @@ -1,6 +1,7 @@ # from math import pi # import lenstronomy.Util.param_util as param_util +import yaml import torch from astropy.cosmology import FlatLambdaCDM as FlatLambdaCDM_AP from astropy.cosmology import default_cosmology @@ -18,14 +19,35 @@ Ob0_default = float(default_cosmology.get().Ob0) -def test(device): +def test(sim_source, device, lens_models): atol = 1e-5 rtol = 3e-2 - - # Models - cosmology = CausticFlatLambdaCDM(name="cosmo") z_l = torch.tensor(0.1) - lens = TNFW(name="tnfw", cosmology=cosmology, z_l=z_l, interpret_m_total_mass=False) + + if sim_source == "yaml": + yaml_str = f"""\ + cosmology: &cosmology + name: cosmo + kind: FlatLambdaCDM + lens: &lens + name: tnfw + kind: TNFW + params: + z_l: {float(z_l)} + init_kwargs: + cosmology: *cosmology + interpret_m_total_mass: false + """ + yaml_dict = yaml.safe_load(yaml_str.encode("utf-8")) + mod = lens_models.get("TNFW") + lens = mod(**yaml_dict["lens"]).model_obj() + else: + # Models + cosmology = CausticFlatLambdaCDM(name="cosmo") + lens = TNFW( + name="tnfw", cosmology=cosmology, z_l=z_l, interpret_m_total_mass=False + ) + lens_model_list = ["TNFW"] lens_ls = LensModel(lens_model_list=lens_model_list) diff --git a/tests/utils/__init__.py b/tests/utils/__init__.py index 57444540..b8cc9102 100644 --- a/tests/utils/__init__.py +++ b/tests/utils/__init__.py @@ -14,7 +14,11 @@ from caustics.utils import get_meshgrid from caustics.sims import Simulator from caustics.cosmology import FlatLambdaCDM +from .models import mock_from_file +__all__ = ( + "mock_from_file", +) def setup_simulator(cosmo_static=False, use_nfw=True, simulator_static=False, batched_params=False, device=None): n_pix = 20 diff --git a/tests/utils/models.py b/tests/utils/models.py new file mode 100644 index 00000000..51ab882b --- /dev/null +++ b/tests/utils/models.py @@ -0,0 +1,178 @@ +import yaml +import torch +import numpy as np + +import caustics + + +def mock_from_file(mocker, yaml_str): + # Mock the from_file function + # this way, we don't need to use a real file + mocker.patch("caustics.models.api.from_file", return_value=yaml_str.encode("utf-8")) + + +def obj_to_yaml(obj_dict: dict): + yaml_string = yaml.safe_dump(obj_dict, sort_keys=False) + string_list = yaml_string.split("\n") + id_str = string_list[0] + f" &{string_list[0]}".strip(":") + string_list[0] = id_str + return "\n".join(string_list).replace("'", "") + + +def setup_complex_multiplane_yaml(): + # initialization stuff for lenses + cosmology = caustics.FlatLambdaCDM(name="cosmo") + cosmo = { + cosmology.name: { + "name": cosmology.name, + "kind": cosmology.__class__.__name__, + } + } + cosmology.to(dtype=torch.float32) + n_pix = 100 + res = 0.05 + upsample_factor = 2 + fov = res * n_pix + thx, thy = caustics.utils.get_meshgrid( + res / upsample_factor, + upsample_factor * n_pix, + upsample_factor * n_pix, + dtype=torch.float32, + ) + z_s = torch.tensor(1.5, dtype=torch.float32) + all_lenses = [] + all_single_planes = [] + + N_planes = 10 + N_lenses = 2 # per plane + + z_plane = np.linspace(0.1, 1.0, N_planes) + planes = [] + + for p, z_p in enumerate(z_plane): + lenses = [] + lens_keys = [] + + if p == N_planes // 2: + lens = caustics.NFW( + cosmology=cosmology, + z_l=z_p, + x0=torch.tensor(0.0), + y0=torch.tensor(0.0), + m=torch.tensor(10**11), + c=torch.tensor(10.0), + s=torch.tensor(0.001), + ) + lenses.append(lens) + all_lenses.append( + { + lens.name: { + "name": lens.name, + "kind": lens.__class__.__name__, + "params": { + k: float(v.value) + for k, v in lens.module_params.static.items() + }, + "init_kwargs": {"cosmology": f"*{cosmology.name}"}, + } + } + ) + lens_keys.append(f"*{lens.name}") + else: + for _ in range(N_lenses): + lens = caustics.NFW( + cosmology=cosmology, + z_l=z_p, + x0=torch.tensor(np.random.uniform(-fov / 2.0, fov / 2.0)), + y0=torch.tensor(np.random.uniform(-fov / 2.0, fov / 2.0)), + m=torch.tensor(10 ** np.random.uniform(8, 9)), + c=torch.tensor(np.random.uniform(4, 40)), + s=torch.tensor(0.001), + ) + lenses.append(lens) + all_lenses.append( + { + lens.name: { + "name": lens.name, + "kind": lens.__class__.__name__, + "params": { + k: float(v.value) + for k, v in lens.module_params.static.items() + }, + "init_kwargs": {"cosmology": f"*{cosmology.name}"}, + } + } + ) + lens_keys.append(f"*{lens.name}") + + single_plane = caustics.lenses.SinglePlane( + z_l=z_p, cosmology=cosmology, lenses=lenses, name=f"plane_{p}" + ) + planes.append(single_plane) + all_single_planes.append( + { + single_plane.name: { + "name": single_plane.name, + "kind": single_plane.__class__.__name__, + "params": { + k: float(v.value) + for k, v in single_plane.module_params.static.items() + }, + "init_kwargs": { + "lenses": lens_keys, + "cosmology": f"*{cosmology.name}", + }, + } + } + ) + + lens = caustics.lenses.Multiplane( + name="multiplane", cosmology=cosmology, lenses=planes + ) + multi_dict = { + lens.name: { + "name": lens.name, + "kind": lens.__class__.__name__, + "init_kwargs": { + "lenses": [f"*{p.name}" for p in planes], + "cosmology": f"*{cosmology.name}", + }, + } + } + lenses_yaml = ( + [obj_to_yaml(cosmo)] + + [obj_to_yaml(lens) for lens in all_lenses] + + [obj_to_yaml(plane) for plane in all_single_planes] + + [obj_to_yaml(multi_dict)] + ) + + source_yaml = obj_to_yaml({ + "source": { + "name": "source", + "kind": "Sersic", + } + }) + + lenslight_yaml = obj_to_yaml({ + "lnslight": { + "name": "lnslight", + "kind": "Sersic", + } + }) + + sim_yaml = obj_to_yaml({ + "simulator": { + "name": "sim", + "kind": "Lens_Source", + "init_kwargs": { + "lens": f"*{lens.name}", + "source": "*source", + "lens_light": "*lnslight", + "pixelscale": 0.05, + "pixels_x": 100, + } + } + }) + + all_yaml_list = lenses_yaml + [source_yaml, lenslight_yaml, sim_yaml] + return "\n".join(all_yaml_list)