From 883add6ebaa3be0b0cac05e9ac08787064132730 Mon Sep 17 00:00:00 2001 From: Steven Murray Date: Wed, 3 Apr 2024 14:05:32 +0200 Subject: [PATCH 01/13] fix: import blackmanharris from scipy.signal.windows --- hera_sim/sigchain.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/hera_sim/sigchain.py b/hera_sim/sigchain.py index 5774da86..f5f44b91 100644 --- a/hera_sim/sigchain.py +++ b/hera_sim/sigchain.py @@ -16,7 +16,7 @@ from pyuvdata import UVBeam from pyuvsim import AnalyticBeam from scipy import stats -from scipy.signal import blackmanharris +from scipy.signal.windows import blackmanharris from typing import Callable from . import DATA_PATH, interpolators, utils From 7b1b5cf492b19967afc3018a31e73a91857bc049 Mon Sep 17 00:00:00 2001 From: Steven Murray Date: Wed, 3 Apr 2024 14:31:40 +0200 Subject: [PATCH 02/13] refactor: fix some upcoming numpy 2 bugs --- hera_sim/cli_utils.py | 4 ++-- .../tutorials_data/end_to_end/make_mock_catalog.py | 7 +++++-- hera_sim/foregrounds.py | 12 ++++++------ hera_sim/interpolators.py | 4 ++-- hera_sim/rfi.py | 9 +++++---- hera_sim/simulate.py | 12 ++++++------ hera_sim/tests/test_simulator.py | 10 ++++------ hera_sim/tests/test_utils.py | 4 ++-- pyproject.toml | 13 ++++++++++++- scripts/hera-sim-simulate.py | 2 +- 10 files changed, 45 insertions(+), 32 deletions(-) diff --git a/hera_sim/cli_utils.py b/hera_sim/cli_utils.py index f1b95ed5..95181d66 100644 --- a/hera_sim/cli_utils.py +++ b/hera_sim/cli_utils.py @@ -61,7 +61,7 @@ def validate_config(config: dict): is not valid. """ if config.get("defaults") is not None: - if type(config["defaults"]) is not str: + if not isinstance(config["defaults"], str): raise ValueError( "Defaults in the CLI may only be specified using a string. " "The string used may specify either a path to a configuration " @@ -155,7 +155,7 @@ def write_calfits( # Update gain keys to conform to write_cal assumptions. # New Simulator gains have keys (ant, pol), so shouldn't need # special pre-processing. - if all(np.issctype(type(ant)) for ant in gains.keys()): + if all(issubclass(type(ant), np.generic) for ant in gains.keys()): # Old-style, single polarization assumption. gains = {(ant, "Jee"): gain for ant, gain in gains.items()} diff --git a/hera_sim/data/tutorials_data/end_to_end/make_mock_catalog.py b/hera_sim/data/tutorials_data/end_to_end/make_mock_catalog.py index 81d11705..ac568b65 100644 --- a/hera_sim/data/tutorials_data/end_to_end/make_mock_catalog.py +++ b/hera_sim/data/tutorials_data/end_to_end/make_mock_catalog.py @@ -18,9 +18,12 @@ def make_mock_catalog( sigma=5, index_low=-3, index_high=-1, + seed=None, ): """Generate and svae a mock point source catalog.""" # Easiset to load the metadata this way. + rng = np.random.default_rng(seed) + temp_uvdata = initialize_uvdata_from_params(str(obsparam_file))[0] center_time = np.mean(np.unique(temp_uvdata.time_array)) ref_freq = np.mean(np.unique(temp_uvdata.freq_array)) @@ -38,11 +41,11 @@ def make_mock_catalog( decs = np.array([row[2] for row in sky_model_recarray]) # Randomly assign fluxes (whether this is realistic or not). - ref_fluxes = np.random.lognormal(mean=1, sigma=sigma, size=len(ras)) + ref_fluxes = rng.lognormal(mean=1, sigma=sigma, size=len(ras)) # Don't include super bright sources. ref_fluxes[ref_fluxes > flux_cut] = 1 / ref_fluxes[ref_fluxes > flux_cut] # Assign spectral indices. - indices = np.random.uniform(low=index_low, high=index_high, size=len(ras)) + indices = rng.uniform(low=index_low, high=index_high, size=len(ras)) # Actually add in the spectral structure. freqs = np.unique(temp_uvdata.freq_array) diff --git a/hera_sim/foregrounds.py b/hera_sim/foregrounds.py index f8e54863..fdebad62 100644 --- a/hera_sim/foregrounds.py +++ b/hera_sim/foregrounds.py @@ -264,13 +264,13 @@ def __call__(self, lsts, freqs, bl_vec, **kwargs): # get baseline length (it should already be in ns) bl_len_ns = np.linalg.norm(bl_vec) + rng = np.random.default_rng() + # randomly generate source RAs - ras = np.random.uniform(0, 2 * np.pi, nsrcs) + ras = rng.uniform(0, 2 * np.pi, nsrcs) # draw spectral indices from normal distribution - spec_indices = np.random.normal( - spectral_index_mean, spectral_index_std, size=nsrcs - ) + spec_indices = rng.normal(spectral_index_mean, spectral_index_std, size=nsrcs) # calculate beam width, hardcoded for HERA beam_width = (40 * 60) * (f0 / freqs) / units.sday.to("s") * 2 * np.pi @@ -278,7 +278,7 @@ def __call__(self, lsts, freqs, bl_vec, **kwargs): # draw flux densities from a power law alpha = beta + 1 flux_densities = ( - Smax**alpha + Smin**alpha * (1 - np.random.uniform(size=nsrcs)) + Smax**alpha + Smin**alpha * (1 - rng.uniform(size=nsrcs)) ) ** (1 / alpha) # initialize the visibility array @@ -290,7 +290,7 @@ def __call__(self, lsts, freqs, bl_vec, **kwargs): lst_ind = np.argmin(np.abs(utils.compute_ha(lsts, ra))) # slight offset in delay? why?? - dtau = np.random.uniform(-1, 1) * 0.1 * bl_len_ns + dtau = rng.uniform(-1, 1) * 0.1 * bl_len_ns # fill in the corresponding region of the visibility array vis[lst_ind, :] += flux * (freqs / f0) ** index diff --git a/hera_sim/interpolators.py b/hera_sim/interpolators.py index 0ece1937..432e441a 100644 --- a/hera_sim/interpolators.py +++ b/hera_sim/interpolators.py @@ -316,8 +316,8 @@ def _check_format(self): ) assert self._obj in self._data.keys() and "freqs" in self._data.keys(), ( "You've chosen to use an interp1d object for modeling the " - "{}. Please ensure that the `.npz` archive has the following " - "keys: 'freqs', '{}'".format(self._obj, self._obj) + f"{self._obj}. Please ensure that the `.npz` archive has the following " + f"keys: 'freqs', '{self._obj}'" ) else: # we can relax this a bit and allow for users to also pass a npz diff --git a/hera_sim/rfi.py b/hera_sim/rfi.py index d800bf8a..a859232b 100644 --- a/hera_sim/rfi.py +++ b/hera_sim/rfi.py @@ -96,14 +96,15 @@ def __call__(self, lsts, freqs): ch2 = ch1 + 1 if self.f0 > freqs[ch1] else ch1 - 1 # generate some random phases - phs1, phs2 = np.random.uniform(0, 2 * np.pi, size=2) + rng = np.random.default_rng() + phs1, phs2 = rng.uniform(0, 2 * np.pi, size=2) # find out when the station is broadcasting is_on = 0.999 * np.cos(lsts * u.sday.to("s") / self.timescale + phs1) is_on = is_on > (1 - 2 * self.duty_cycle) # generate a signal and filter it according to when it's on - signal = np.random.normal(self.strength, self.std, lsts.size) + signal = rng.normal(self.strength, self.std, lsts.size) signal = np.where(is_on, signal, 0) * np.exp(1j * phs2) # now add the signal to the rfi array @@ -484,8 +485,8 @@ def _listify_params(self, bands, *args): "values with the same length as the number of DTV " "bands specified. For reference, the DTV bands you " "specified have the following characteristics: \n" - "f_min : {fmin} \nf_max : {fmax}\n N_bands : " - "{Nchan}".format(fmin=bands[0], fmax=bands[-1], Nchan=Nchan) + f"f_min : {bands[0]} \nf_max : {bands[-1]}\n N_bands : " + f"{Nchan}" ) # everything should be in order now, so diff --git a/hera_sim/simulate.py b/hera_sim/simulate.py index 805e97f7..803f234a 100644 --- a/hera_sim/simulate.py +++ b/hera_sim/simulate.py @@ -727,7 +727,7 @@ def _apply_filter(vis_filter, ant1, ant2, pol): return not pol == vis_filter[0] # Otherwise, assume that this specifies an antenna. else: - return not vis_filter[0] in (ant1, ant2) + return vis_filter[0] not in (ant1, ant2) elif len(vis_filter) == 2: # TODO: This will need to be updated when we support ant strings. # Three cases: two pols; an ant+pol; a baseline. @@ -756,7 +756,7 @@ def _apply_filter(vis_filter, ant1, ant2, pol): for key in vis_filter: if isinstance(key, str): pols.append(key) - elif type(key) is int: + elif isinstance(key, int): ants.append(key) # We want polarization and ant1 or ant2 in the filter. # This would be used in simulating e.g. a few feeds that have an @@ -1228,7 +1228,7 @@ def _seed_rng(self, seed, model, ant1=None, ant2=None, pol=None): """ if seed is None: return - if type(seed) is int: + if isinstance(seed, int): np.random.seed(seed) return seed if not isinstance(seed, str): @@ -1390,7 +1390,7 @@ def _get_component( component: Union[str, type[SimulationComponent], SimulationComponent] ) -> Union[SimulationComponent, type[SimulationComponent]]: """Normalize a component to be either a class or instance.""" - if np.issubclass_(component, SimulationComponent): + if issubclass(component, SimulationComponent): return component elif isinstance(component, str): try: @@ -1435,7 +1435,7 @@ def _get_model_name(model): """Find out the (lowercase) name of a provided model.""" if isinstance(model, str): return model.lower() - elif np.issubclass_(model, SimulationComponent): + elif issubclass(model, SimulationComponent): return model.__name__.lower() elif isinstance(model, SimulationComponent): return model.__class__.__name__.lower() @@ -1472,7 +1472,7 @@ def _parse_key(self, key: Union[int, str, AntPair, AntPairPol]) -> AntPairPol: "pair with a polarization string." ) if len(key) == 2: - if all(type(val) is int for val in key): + if all(isinstance(val, int) for val in key): ant1, ant2 = key pol = None else: diff --git a/hera_sim/tests/test_simulator.py b/hera_sim/tests/test_simulator.py index 8f928bb7..aaa71223 100644 --- a/hera_sim/tests/test_simulator.py +++ b/hera_sim/tests/test_simulator.py @@ -418,10 +418,10 @@ def test_run_sim(): # write something to it with open(tmp_sim_file, "w") as sim_file: sim_file.write( - """ + f""" diffuse_foreground: Tsky_mdl: !Tsky - datafile: {}/HERA_Tsky_Reformatted.npz + datafile: {DATA_PATH}/HERA_Tsky_Reformatted.npz pol: yy pntsrc_foreground: nsrcs: 500 @@ -436,16 +436,14 @@ def test_run_sim(): phs: 2.1123 thermal_noise: Tsky_mdl: !Tsky - datafile: {}/HERA_Tsky_Reformatted.npz + datafile: {DATA_PATH}/HERA_Tsky_Reformatted.npz pol: xx integration_time: 9.72 rfi_scatter: scatter_chance: 0.99 scatter_strength: 5.7 scatter_std: 2.2 - """.format( - DATA_PATH, DATA_PATH - ) + """ ) sim = create_sim(autos=True) sim.run_sim(tmp_sim_file) diff --git a/hera_sim/tests/test_utils.py b/hera_sim/tests/test_utils.py index 4ac51817..c00ad5f7 100644 --- a/hera_sim/tests/test_utils.py +++ b/hera_sim/tests/test_utils.py @@ -326,7 +326,7 @@ def test_use_pre_computed_filter(freqs, lsts, filter_type): @pytest.mark.parametrize("shape", [100, (100, 200)]) def test_gen_white_noise_shape(shape): noise = utils.gen_white_noise(shape) - if type(shape) is int: + if isinstance(shape, int): shape = (shape,) assert noise.shape == shape @@ -384,7 +384,7 @@ def test_Jy2T(freqs, omega_p): @pytest.mark.parametrize("obj", [1, (1, 2), "abc", np.array([13])]) def test_listify(obj): - assert type(utils._listify(obj)) is list + assert isinstance(utils._listify(obj), list) @pytest.mark.parametrize("jit", [True, False]) diff --git a/pyproject.toml b/pyproject.toml index fac1440f..8e595f40 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -6,7 +6,6 @@ build-backend = "setuptools.build_meta" [tool.black] line-length = 88 -py36 = false exclude = ''' /( \.eggs @@ -21,3 +20,15 @@ exclude = ''' | dist )/ ''' + +[tool.ruff] +line-length = 88 +target-version = "py39" + +[tool.ruff.lint] +select = [ + "UP", # pyupgrade + "E", # pycodestyle + "W", # pycodestyle warning + "NPY", # numpy-specific rules +] diff --git a/scripts/hera-sim-simulate.py b/scripts/hera-sim-simulate.py index d91e9136..2076114a 100644 --- a/scripts/hera-sim-simulate.py +++ b/scripts/hera-sim-simulate.py @@ -157,7 +157,7 @@ this_sim.data.data_array = data if bda_params: this_sim.data = bda_tools.apply_bda(this_sim.data, **bda_params) - if type(data) is dict: + if isinstance(data, dict): # The component is a gain-like term, so save as a calfits file. ext = os.path.splitext(filename)[1] if ext == "": From 8b1b79d6be2c8e215e9b96c634445862a1208c19 Mon Sep 17 00:00:00 2001 From: Steven Murray Date: Wed, 3 Apr 2024 14:36:45 +0200 Subject: [PATCH 03/13] ci: add ruff check for npy --- .pre-commit-config.yaml | 10 +++++++++- 1 file changed, 9 insertions(+), 1 deletion(-) diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 360c3b90..d712b921 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -57,4 +57,12 @@ repos: - repo: https://github.com/asottile/setup-cfg-fmt rev: v2.5.0 hooks: - - id: setup-cfg-fmt \ No newline at end of file + - id: setup-cfg-fmt + +- repo: https://github.com/astral-sh/ruff-pre-commit + # Ruff version. + rev: v0.1.4 + hooks: + # Run the linter. + - id: ruff + args: [--fix] From 80ab80a081387883799b331a51899d4de3196d3e Mon Sep 17 00:00:00 2001 From: Steven Murray Date: Wed, 3 Apr 2024 16:33:53 +0200 Subject: [PATCH 04/13] fix: bugs in updating numpy stuff --- hera_sim/cli_utils.py | 2 +- hera_sim/simulate.py | 14 +++++++++----- hera_sim/visibilities/simulators.py | 14 ++++++++------ 3 files changed, 18 insertions(+), 12 deletions(-) diff --git a/hera_sim/cli_utils.py b/hera_sim/cli_utils.py index 95181d66..1ea14a33 100644 --- a/hera_sim/cli_utils.py +++ b/hera_sim/cli_utils.py @@ -155,7 +155,7 @@ def write_calfits( # Update gain keys to conform to write_cal assumptions. # New Simulator gains have keys (ant, pol), so shouldn't need # special pre-processing. - if all(issubclass(type(ant), np.generic) for ant in gains.keys()): + if all(np.isscalar(ant) for ant in gains.keys()): # Old-style, single polarization assumption. gains = {(ant, "Jee"): gain for ant, gain in gains.items()} diff --git a/hera_sim/simulate.py b/hera_sim/simulate.py index 803f234a..e4ef5d3d 100644 --- a/hera_sim/simulate.py +++ b/hera_sim/simulate.py @@ -6,6 +6,7 @@ :class:`Simulator`, please refer to the tutorials. """ +import contextlib import functools import inspect import numpy as np @@ -1390,9 +1391,7 @@ def _get_component( component: Union[str, type[SimulationComponent], SimulationComponent] ) -> Union[SimulationComponent, type[SimulationComponent]]: """Normalize a component to be either a class or instance.""" - if issubclass(component, SimulationComponent): - return component - elif isinstance(component, str): + if isinstance(component, str): try: return get_model(component) except KeyError: @@ -1403,6 +1402,9 @@ def _get_component( elif isinstance(component, SimulationComponent): return component else: + with contextlib.suppress(TypeError): + if issubclass(component, SimulationComponent): + return component raise TypeError( "The input type for the component was not understood. " "Must be a string, or a class/instance of type 'SimulationComponent'. " @@ -1435,11 +1437,13 @@ def _get_model_name(model): """Find out the (lowercase) name of a provided model.""" if isinstance(model, str): return model.lower() - elif issubclass(model, SimulationComponent): - return model.__name__.lower() elif isinstance(model, SimulationComponent): return model.__class__.__name__.lower() else: + with contextlib.suppress(TypeError): + if issubclass(model, SimulationComponent): + return model.__name__.lower() + raise TypeError( "You are trying to simulate an effect using a custom function. " "Please refer to the tutorial for instructions regarding how " diff --git a/hera_sim/visibilities/simulators.py b/hera_sim/visibilities/simulators.py index 18f35b28..53f64a58 100644 --- a/hera_sim/visibilities/simulators.py +++ b/hera_sim/visibilities/simulators.py @@ -494,11 +494,13 @@ def load_simulator_from_yaml(config: Path | str) -> VisibilitySimulator: module = importlib.import_module(module) simulator_cls = getattr(module, simulator_cls.split(".")[-1]) - if not issubclass(simulator_cls, VisibilitySimulator): - raise TypeError( - f"Specified simulator {simulator_cls} is not a subclass of" - "VisibilitySimulator!" - ) + try: + if not issubclass(simulator_cls, VisibilitySimulator): + raise TypeError( + f"Specified simulator {simulator_cls} is not a subclass of" + "VisibilitySimulator!" + ) + except TypeError as e: + raise TypeError(f"Specified simulator {simulator_cls} is not a class!") from e - assert issubclass(simulator_cls, VisibilitySimulator) return simulator_cls.from_yaml(cfg) From 976c9173a27aa2c98eff74823c095c465a6321ea Mon Sep 17 00:00:00 2001 From: Steven Murray Date: Fri, 19 Apr 2024 11:22:20 +0200 Subject: [PATCH 05/13] maint: revert rng changes --- hera_sim/foregrounds.py | 12 ++++++------ hera_sim/rfi.py | 5 ++--- pyproject.toml | 4 ++++ 3 files changed, 12 insertions(+), 9 deletions(-) diff --git a/hera_sim/foregrounds.py b/hera_sim/foregrounds.py index fdebad62..f8e54863 100644 --- a/hera_sim/foregrounds.py +++ b/hera_sim/foregrounds.py @@ -264,13 +264,13 @@ def __call__(self, lsts, freqs, bl_vec, **kwargs): # get baseline length (it should already be in ns) bl_len_ns = np.linalg.norm(bl_vec) - rng = np.random.default_rng() - # randomly generate source RAs - ras = rng.uniform(0, 2 * np.pi, nsrcs) + ras = np.random.uniform(0, 2 * np.pi, nsrcs) # draw spectral indices from normal distribution - spec_indices = rng.normal(spectral_index_mean, spectral_index_std, size=nsrcs) + spec_indices = np.random.normal( + spectral_index_mean, spectral_index_std, size=nsrcs + ) # calculate beam width, hardcoded for HERA beam_width = (40 * 60) * (f0 / freqs) / units.sday.to("s") * 2 * np.pi @@ -278,7 +278,7 @@ def __call__(self, lsts, freqs, bl_vec, **kwargs): # draw flux densities from a power law alpha = beta + 1 flux_densities = ( - Smax**alpha + Smin**alpha * (1 - rng.uniform(size=nsrcs)) + Smax**alpha + Smin**alpha * (1 - np.random.uniform(size=nsrcs)) ) ** (1 / alpha) # initialize the visibility array @@ -290,7 +290,7 @@ def __call__(self, lsts, freqs, bl_vec, **kwargs): lst_ind = np.argmin(np.abs(utils.compute_ha(lsts, ra))) # slight offset in delay? why?? - dtau = rng.uniform(-1, 1) * 0.1 * bl_len_ns + dtau = np.random.uniform(-1, 1) * 0.1 * bl_len_ns # fill in the corresponding region of the visibility array vis[lst_ind, :] += flux * (freqs / f0) ** index diff --git a/hera_sim/rfi.py b/hera_sim/rfi.py index a859232b..b0d62996 100644 --- a/hera_sim/rfi.py +++ b/hera_sim/rfi.py @@ -96,15 +96,14 @@ def __call__(self, lsts, freqs): ch2 = ch1 + 1 if self.f0 > freqs[ch1] else ch1 - 1 # generate some random phases - rng = np.random.default_rng() - phs1, phs2 = rng.uniform(0, 2 * np.pi, size=2) + phs1, phs2 = np.random.uniform(0, 2 * np.pi, size=2) # find out when the station is broadcasting is_on = 0.999 * np.cos(lsts * u.sday.to("s") / self.timescale + phs1) is_on = is_on > (1 - 2 * self.duty_cycle) # generate a signal and filter it according to when it's on - signal = rng.normal(self.strength, self.std, lsts.size) + signal = np.random.normal(self.strength, self.std, lsts.size) signal = np.where(is_on, signal, 0) * np.exp(1j * phs2) # now add the signal to the rfi array diff --git a/pyproject.toml b/pyproject.toml index 8e595f40..894995d5 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -32,3 +32,7 @@ select = [ "W", # pycodestyle warning "NPY", # numpy-specific rules ] + +ignore = [ + "NPY002", # RNG -- fix soon! +] From 097636db0be65ce4f7efda2d4b620587f9001146 Mon Sep 17 00:00:00 2001 From: Steven Murray Date: Tue, 23 Apr 2024 10:19:11 +0200 Subject: [PATCH 06/13] maint: drop 3.9, add 3.12 --- .github/workflows/test_suite.yaml | 2 +- CHANGELOG.rst | 5 +++++ 2 files changed, 6 insertions(+), 1 deletion(-) diff --git a/.github/workflows/test_suite.yaml b/.github/workflows/test_suite.yaml index eb598903..dba2d54e 100644 --- a/.github/workflows/test_suite.yaml +++ b/.github/workflows/test_suite.yaml @@ -20,7 +20,7 @@ jobs: fail-fast: false matrix: os: [ubuntu-latest, macos-latest] - python-version: [3.9, "3.10", "3.11"] + python-version: ["3.10", "3.11", "3.12"] steps: - uses: actions/checkout@main diff --git a/CHANGELOG.rst b/CHANGELOG.rst index 309bc491..de32cff9 100644 --- a/CHANGELOG.rst +++ b/CHANGELOG.rst @@ -5,6 +5,11 @@ Changelog dev === +Deprecated +---------- + +- Support for Python 3.9 has been dropped. + Fixed ----- - API calls for pyuvdata v2.4.0. From e86694dc371206038dc55742e9a811d5b7712362 Mon Sep 17 00:00:00 2001 From: Steven Murray Date: Tue, 23 Apr 2024 10:40:05 +0200 Subject: [PATCH 07/13] test: add tests of bad specification of simulator.yaml --- hera_sim/tests/test_vis.py | 12 ++++++++++++ hera_sim/visibilities/simulators.py | 4 ++-- 2 files changed, 14 insertions(+), 2 deletions(-) diff --git a/hera_sim/tests/test_vis.py b/hera_sim/tests/test_vis.py index 7cc13e47..a094c129 100644 --- a/hera_sim/tests/test_vis.py +++ b/hera_sim/tests/test_vis.py @@ -630,3 +630,15 @@ def test_bad_load(tmpdir): with pytest.raises(AttributeError, match="The given simulator"): load_simulator_from_yaml(tmpdir / "bad_sim.yaml") + + with open(tmpdir / "bad_sim.yaml", "w") as fl: + fl.write("""simulator: hera_sim.foregrounds.DiffuseForeground\n""") + + with pytest.raises(ValueError, match="is not a subclass of VisibilitySimulator"): + load_simulator_from_yaml(tmpdir / "bad_sim.yaml") + + with open(tmpdir / "bad_sim.yaml", "w") as fl: + fl.write("""simulator: hera_sim.foregrounds.diffuse_foreground\n""") + + with pytest.raises(TypeError, match="is not a class"): + load_simulator_from_yaml(tmpdir / "bad_sim.yaml") diff --git a/hera_sim/visibilities/simulators.py b/hera_sim/visibilities/simulators.py index 53f64a58..6d20ca64 100644 --- a/hera_sim/visibilities/simulators.py +++ b/hera_sim/visibilities/simulators.py @@ -496,8 +496,8 @@ def load_simulator_from_yaml(config: Path | str) -> VisibilitySimulator: try: if not issubclass(simulator_cls, VisibilitySimulator): - raise TypeError( - f"Specified simulator {simulator_cls} is not a subclass of" + raise ValueError( + f"Specified simulator {simulator_cls} is not a subclass of " "VisibilitySimulator!" ) except TypeError as e: From c082fe1c759fcc8d829b2b10db37e5f748d82347 Mon Sep 17 00:00:00 2001 From: Steven Murray Date: Tue, 23 Apr 2024 11:25:14 +0200 Subject: [PATCH 08/13] test: add tests of corner-cases --- hera_sim/simulate.py | 19 +++++++++++-- hera_sim/tests/test_simulator.py | 46 ++++++++++++++++++++++++++++++++ 2 files changed, 63 insertions(+), 2 deletions(-) diff --git a/hera_sim/simulate.py b/hera_sim/simulate.py index e4ef5d3d..114ff3ef 100644 --- a/hera_sim/simulate.py +++ b/hera_sim/simulate.py @@ -19,6 +19,7 @@ from deprecation import deprecated from pathlib import Path from pyuvdata import UVData +from pyuvdata import utils as uvutils from typing import Optional, Union from . import __version__, io, utils @@ -1464,16 +1465,30 @@ def _parse_key(self, key: Union[int, str, AntPair, AntPairPol]) -> AntPairPol: elif isinstance(key, str): if key.lower() in ("auto", "cross"): raise NotImplementedError("Functionality not yet supported.") + if key.lower() not in { + **uvutils.POL_STR2NUM_DICT, + **uvutils.JONES_STR2NUM_DICT, + **uvutils.CONJ_POL_DICT, + }: + raise ValueError(f"Invalid polarization string: {key}.") ant1, ant2, pol = None, None, key else: try: iter(key) - if len(key) not in (2, 3): + if len(key) not in (2, 3) or ( + len(key) == 3 + and not ( + isinstance(key[0], int) + and isinstance(key[1], int) + and isinstance(key[2], str) + ) + ): raise TypeError + except TypeError: raise ValueError( "Key must be an integer, string, antenna pair, or antenna " - "pair with a polarization string." + f"pair with a polarization string. Got {key}" ) if len(key) == 2: if all(isinstance(val, int) for val in key): diff --git a/hera_sim/tests/test_simulator.py b/hera_sim/tests/test_simulator.py index aaa71223..24d37d93 100644 --- a/hera_sim/tests/test_simulator.py +++ b/hera_sim/tests/test_simulator.py @@ -702,3 +702,49 @@ def test_cached_filters(): sim2.add("diffuse_foreground", seed=seed) defaults.deactivate() assert np.allclose(sim1.data.data_array, sim2.data.data_array) + + +def test_get_model_name(): + assert Simulator._get_model_name("noiselike_eor") == "noiselike_eor" + assert Simulator._get_model_name("NOISELIKE_EOR") == "noiselike_eor" + + assert Simulator._get_model_name(DiffuseForeground) == "diffuseforeground" + assert Simulator._get_model_name(diffuse_foreground) == "diffuseforeground" + + with pytest.raises( + TypeError, match="You are trying to simulate an effect using a custom function" + ): + Simulator._get_model_name(lambda x: x) + + with pytest.raises( + TypeError, match="You are trying to simulate an effect using a custom function" + ): + Simulator._get_model_name(3) + + +def test_parse_key(base_sim: Simulator): + assert base_sim._parse_key(None) == (None, None, None) + assert base_sim._parse_key(1) == (1, None, None) + assert base_sim._parse_key( + base_sim.data.baseline_array[-1] + ) == base_sim.data.baseline_to_antnums(base_sim.data.baseline_array[-1]) + (None,) + + with pytest.raises(NotImplementedError, match="Functionality not yet supported"): + base_sim._parse_key("auto") + + assert base_sim._parse_key("ee") == (None, None, "ee") + + for badkey in [3.14, [1, 2, 3], (1,)]: + print(badkey) + with pytest.raises( + ValueError, + match="Key must be an integer, string, antenna pair, or antenna pair with", + ): + base_sim._parse_key(badkey) + + with pytest.raises(ValueError, match="Invalid polarization string"): + base_sim._parse_key("bad_pol") + + assert base_sim._parse_key((1, 2)) == (1, 2, None) + assert base_sim._parse_key((1, "Jee")) == (1, None, "Jee") + assert base_sim._parse_key((1, 2, "ee")) == (1, 2, "ee") From 5c6f74bb80684eb5e5c406f1b83c2782442c5e54 Mon Sep 17 00:00:00 2001 From: Steven Murray Date: Tue, 23 Apr 2024 11:37:33 +0200 Subject: [PATCH 09/13] fix: _parse_key for (int, int, None) --- hera_sim/simulate.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/hera_sim/simulate.py b/hera_sim/simulate.py index 114ff3ef..88c66b52 100644 --- a/hera_sim/simulate.py +++ b/hera_sim/simulate.py @@ -1478,9 +1478,9 @@ def _parse_key(self, key: Union[int, str, AntPair, AntPairPol]) -> AntPairPol: if len(key) not in (2, 3) or ( len(key) == 3 and not ( - isinstance(key[0], int) - and isinstance(key[1], int) - and isinstance(key[2], str) + isinstance(key[0], (int, None)) + and isinstance(key[1], (int, None)) + and isinstance(key[2], (str, None)) ) ): raise TypeError From c23e390069687c2681eff3fff70f8202cface4df Mon Sep 17 00:00:00 2001 From: Steven Murray Date: Tue, 23 Apr 2024 15:41:18 +0200 Subject: [PATCH 10/13] fix: corner cases for _parse_key --- hera_sim/simulate.py | 65 +++++++++++++++++++++++++++----------------- 1 file changed, 40 insertions(+), 25 deletions(-) diff --git a/hera_sim/simulate.py b/hera_sim/simulate.py index 88c66b52..9f10abd4 100644 --- a/hera_sim/simulate.py +++ b/hera_sim/simulate.py @@ -1453,6 +1453,28 @@ def _get_model_name(model): def _parse_key(self, key: Union[int, str, AntPair, AntPairPol]) -> AntPairPol: """Convert a key of at-most length-3 to an (ant1, ant2, pol) tuple.""" + valid_pols = { + k.lower() + for k in { + **uvutils.POL_STR2NUM_DICT, + **uvutils.JONES_STR2NUM_DICT, + **uvutils.CONJ_POL_DICT, + } + } + valid_pols.update({"jee", "jen", "jne", "jnn"}) + + def checkpol(pol): + if pol is None: + return None + + if not isinstance(pol, str): + raise TypeError(f"Invalid polarization type: {type(pol)}.") + + if pol.lower() not in valid_pols: + raise ValueError(f"Invalid polarization string: {pol}.") + + return pol + if key is None: ant1, ant2, pol = None, None, None elif np.issubdtype(type(key), int): @@ -1465,40 +1487,33 @@ def _parse_key(self, key: Union[int, str, AntPair, AntPairPol]) -> AntPairPol: elif isinstance(key, str): if key.lower() in ("auto", "cross"): raise NotImplementedError("Functionality not yet supported.") - if key.lower() not in { - **uvutils.POL_STR2NUM_DICT, - **uvutils.JONES_STR2NUM_DICT, - **uvutils.CONJ_POL_DICT, - }: - raise ValueError(f"Invalid polarization string: {key}.") + key = checkpol(key) ant1, ant2, pol = None, None, key else: + + def intify(x): + return x if x is None else int(x) + try: - iter(key) - if len(key) not in (2, 3) or ( - len(key) == 3 - and not ( - isinstance(key[0], (int, None)) - and isinstance(key[1], (int, None)) - and isinstance(key[2], (str, None)) - ) - ): + iter(key) # ensure it's iterable + if len(key) not in (2, 3): raise TypeError + if len(key) == 2: + if all(isinstance(val, int) for val in key): + ant1, ant2 = key + pol = None + else: + ant1, pol = intify(key[0]), checkpol(key[1]) + ant2 = None + else: + ant1, ant2, pol = intify(key[0]), intify(key[1]), checkpol(key[2]) + except TypeError: raise ValueError( "Key must be an integer, string, antenna pair, or antenna " - f"pair with a polarization string. Got {key}" + f"pair with a polarization string. Got {key}." ) - if len(key) == 2: - if all(isinstance(val, int) for val in key): - ant1, ant2 = key - pol = None - else: - ant1, pol = key - ant2 = None - else: - ant1, ant2, pol = key return ant1, ant2, pol def _sanity_check(self, model): From e2729a609335845614ba5a417c747d4e8cb212d1 Mon Sep 17 00:00:00 2001 From: Steven Murray Date: Tue, 23 Apr 2024 19:02:26 +0200 Subject: [PATCH 11/13] test: more resilient equality testing --- hera_sim/tests/test_simulator.py | 9 +++++---- 1 file changed, 5 insertions(+), 4 deletions(-) diff --git a/hera_sim/tests/test_simulator.py b/hera_sim/tests/test_simulator.py index 24d37d93..a15cd424 100644 --- a/hera_sim/tests/test_simulator.py +++ b/hera_sim/tests/test_simulator.py @@ -285,18 +285,19 @@ def test_get_multiplicative_effect(base_sim, pol, ant1): gains = base_sim.add("gains", seed="once", ret_vis=True) _gains = base_sim.get("gains", key=(ant1, pol)) if pol is not None and ant1 is not None: - assert np.all(gains[(ant1, pol)] == _gains) + assert np.allclose(gains[(ant1, pol)] == _gains) elif pol is None and ant1 is not None: assert all( - np.all(gains[(ant1, _pol)] == _gains[(ant1, _pol)]) + np.allclose(gains[(ant1, _pol)] == _gains[(ant1, _pol)]) for _pol in base_sim.data.get_feedpols() ) elif pol is not None and ant1 is None: assert all( - np.all(gains[(ant, pol)] == _gains[(ant, pol)]) for ant in base_sim.antpos + np.allclose(gains[(ant, pol)] == _gains[(ant, pol)]) + for ant in base_sim.antpos ) else: - assert all(np.all(gains[antpol] == _gains[antpol]) for antpol in gains) + assert all(np.allclose(gains[antpol] == _gains[antpol]) for antpol in gains) def test_not_add_vis(base_sim): From 9455392c5c80e0dffa11aee9db39149353b4d195 Mon Sep 17 00:00:00 2001 From: Steven Murray Date: Tue, 23 Apr 2024 19:52:29 +0200 Subject: [PATCH 12/13] test: more resilient equality testing --- hera_sim/tests/test_simulator.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/hera_sim/tests/test_simulator.py b/hera_sim/tests/test_simulator.py index a15cd424..8905164e 100644 --- a/hera_sim/tests/test_simulator.py +++ b/hera_sim/tests/test_simulator.py @@ -285,19 +285,19 @@ def test_get_multiplicative_effect(base_sim, pol, ant1): gains = base_sim.add("gains", seed="once", ret_vis=True) _gains = base_sim.get("gains", key=(ant1, pol)) if pol is not None and ant1 is not None: - assert np.allclose(gains[(ant1, pol)] == _gains) + assert np.allclose(gains[(ant1, pol)], _gains) elif pol is None and ant1 is not None: assert all( - np.allclose(gains[(ant1, _pol)] == _gains[(ant1, _pol)]) + np.allclose(gains[(ant1, _pol)], _gains[(ant1, _pol)]) for _pol in base_sim.data.get_feedpols() ) elif pol is not None and ant1 is None: assert all( - np.allclose(gains[(ant, pol)] == _gains[(ant, pol)]) + np.allclose(gains[(ant, pol)], _gains[(ant, pol)]) for ant in base_sim.antpos ) else: - assert all(np.allclose(gains[antpol] == _gains[antpol]) for antpol in gains) + assert all(np.allclose(gains[antpol], _gains[antpol]) for antpol in gains) def test_not_add_vis(base_sim): From 921110aeb4f38b654a6469f724ba3fe42f2fe6e2 Mon Sep 17 00:00:00 2001 From: Steven Murray Date: Tue, 23 Apr 2024 22:16:02 +0200 Subject: [PATCH 13/13] style: pre-commit autoupdate --- .flake8 | 2 ++ .pre-commit-config.yaml | 9 ++++----- hera_sim/tests/test_sim_red_data.py | 4 ++-- hera_sim/visibilities/cli.py | 4 ++-- hera_sim/visibilities/matvis.py | 2 +- 5 files changed, 11 insertions(+), 10 deletions(-) diff --git a/.flake8 b/.flake8 index 267e45f2..04fb6718 100644 --- a/.flake8 +++ b/.flake8 @@ -17,6 +17,8 @@ ignore = G004, # Logging statement uses + (this makes no sense...) G003, + # Allow builtin module names + A005, max-line-length = 88 # Should be 18. max-complexity = 35 diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index d712b921..4269a723 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -2,7 +2,7 @@ exclude: 'config.yaml|config_examples/.*.yaml|hera_sim/config/H1C.yaml|hera_sim/ repos: - repo: https://github.com/pre-commit/pre-commit-hooks - rev: v4.5.0 + rev: v4.6.0 hooks: - id: trailing-whitespace - id: check-added-large-files @@ -25,7 +25,6 @@ repos: - flake8-rst-docstrings #- flake8-docstrings # not available for flake8>5 - flake8-builtins - - flake8-logging-format - flake8-rst-docstrings - flake8-rst # - flake8-markdown # not available for flake8>5 (check later...) @@ -34,7 +33,7 @@ repos: - flake8-print - repo: https://github.com/psf/black-pre-commit-mirror - rev: 24.2.0 + rev: 24.4.0 hooks: - id: black @@ -49,7 +48,7 @@ repos: - id: isort - repo: https://github.com/asottile/pyupgrade - rev: v3.15.1 + rev: v3.15.2 hooks: - id: pyupgrade args: [--py39-plus] @@ -61,7 +60,7 @@ repos: - repo: https://github.com/astral-sh/ruff-pre-commit # Ruff version. - rev: v0.1.4 + rev: v0.4.1 hooks: # Run the linter. - id: ruff diff --git a/hera_sim/tests/test_sim_red_data.py b/hera_sim/tests/test_sim_red_data.py index 4137ee7f..f4ee2806 100644 --- a/hera_sim/tests/test_sim_red_data.py +++ b/hera_sim/tests/test_sim_red_data.py @@ -58,13 +58,13 @@ def test_sim_red_data_4pol(antpos, gain_data): for pol in ["xx", "xy", "yx", "yy"]: ai, aj, pol = bl0 ans0 = data[ai, aj, pol] / ( - gains[(ai, f"J{pol[0]*2}")] * gains[(aj, f"J{pol[1]*2}")].conj() + gains[(ai, f"J{pol[0] * 2}")] * gains[(aj, f"J{pol[1] * 2}")].conj() ) for bl in bls[1:]: ai, aj, pol = bl ans = data[ai, aj, pol] / ( - gains[(ai, f"J{pol[0]*2}")] * gains[(aj, f"J{pol[1]*2}")].conj() + gains[(ai, f"J{pol[0] * 2}")] * gains[(aj, f"J{pol[1] * 2}")].conj() ) # compare calibrated visibilities knowing the input gains diff --git a/hera_sim/visibilities/cli.py b/hera_sim/visibilities/cli.py index 1b035dba..fa9c005e 100644 --- a/hera_sim/visibilities/cli.py +++ b/hera_sim/visibilities/cli.py @@ -139,8 +139,8 @@ def run_vis_sim(args): ram_avail = psutil.virtual_memory().available / 1024**3 cprint( - f"[bold {'red' if ram < 1.5*ram_avail else 'green'}] This simulation will use " - f"at least {ram:.2f}GB of RAM (Available: {ram_avail:.2f}GB).[/]" + f"[bold {'red' if ram < 1.5 * ram_avail else 'green'}] This simulation will use" + f" at least {ram:.2f}GB of RAM (Available: {ram_avail:.2f}GB).[/]" ) if args.object_name is None: diff --git a/hera_sim/visibilities/matvis.py b/hera_sim/visibilities/matvis.py index fe7f54b2..e602e47e 100644 --- a/hera_sim/visibilities/matvis.py +++ b/hera_sim/visibilities/matvis.py @@ -378,7 +378,7 @@ def simulate(self, data_model): if self.mpi_comm is not None and i % nproc != myid: continue - logger.info(f"Simulating Frequency {i+1}/{len(data_model.freqs)}") + logger.info(f"Simulating Frequency {i + 1}/{len(data_model.freqs)}") # Call matvis function to simulate visibilities vis = self._matvis(