diff --git a/CHANGELOG.md b/CHANGELOG.md index 3f4712a9..92d2ff14 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -6,8 +6,12 @@ The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/), and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0.html). -## [0.2.2] - Unreleased +## [0.3.0] - Unreleased ### Added +- Support for averaging incoming visibilities in time and frequency *before* + doing any work on them. +- When writing out visibilities, it is now possible to write out the smallest + contiguous band of unflagged channels. - Plots can be written to a specific directory, not only the CWD. Fixes #18. - Support for visibilities using the ant2-ant1 ordering rather than ant1-ant2. - Add new errors @@ -18,6 +22,10 @@ Versioning](https://semver.org/spec/v2.0.0.html). - Benchmarks - Raw MWA, uvfits and measurement set reading. - More CUDA benchmarks for the modelling code. +- Support for "argument files". This is an advanced feature that most users + should probably avoid. Previously, argument files were supported only for the + di-calibrate subcommand, but now they are supported more consistently among the + "big" subcommands. ### Fixed - When raw MWA data is missing gpubox files in what is otherwise a contiguous @@ -34,9 +42,16 @@ Versioning](https://semver.org/spec/v2.0.0.html). even with the explicit "ignore input data tile flags". - Some aspects of hyperdrive weren't using user-specified array positions correctly. The help text also indicated the wrong units. +- Fine-channel flags and fine-channel-per-coarse-channel flags are now checked + for validity. ### Changed - The performance of CPU visibility modelling has been dramatically improved. +- The command-line interface has been overhauled. Some things may be different, + but generally the options and flags are much more consistent between + subcommands. +- The preamble to "big" subcommands, like di-calibrate, has been overhauled to + be much easier to read. - Plotting - Legend labels have changed to $g_x$, $D_x$, $D_y$, $g_y$ ($g$ for gain, $D$ for leakage). Thanks Jack Line.
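The first new CHANGELOG entry above (averaging visibilities in time and frequency before any other processing) is backed by the new `vis_average` function added to `src/averaging/mod.rs` later in this patch. The sketch below is not part of the patch; it mirrors the new tests in `src/averaging/tests.rs` to show the call shape. `vis_average` is crate-private, so the snippet assumes it is compiled inside the crate, and the array shapes (2 timesteps × 4 channels × 1 baseline, averaged by 2 in both time and frequency) are illustrative placeholders.

```rust
// Sketch only: exercising the new crate-private `vis_average`, following the
// shape of the tests added in src/averaging/tests.rs.
use std::collections::HashSet;

use marlu::Jones;
use ndarray::prelude::*;

use crate::averaging::vis_average;

fn average_example() {
    // Input: 2 timesteps x 4 channels x 1 baseline, all identity Jones
    // matrices with unit weights (placeholder data).
    let jones_from_tfb: Array3<Jones<f32>> = Array3::from_elem((2, 4, 1), Jones::identity());
    let weight_from_tfb: Array3<f32> = Array3::ones(jones_from_tfb.dim());

    // Output: 1 timeblock x 2 chanblocks x 1 baseline, i.e. averaging by a
    // factor of 2 in both time and frequency.
    let mut jones_to_fb: Array2<Jones<f32>> = Array2::default((2, 1));
    let mut weight_to_fb: Array2<f32> = Array2::default(jones_to_fb.dim());

    vis_average(
        jones_from_tfb.view(),
        jones_to_fb.view_mut(),
        weight_from_tfb.view(),
        weight_to_fb.view_mut(),
        &HashSet::new(), // no flagged chanblocks
    );

    // Each output visibility is the weighted mean of a 2x2 (time x frequency)
    // chunk of inputs, and the weights of the chunk are summed.
    assert_eq!(weight_to_fb[(0, 0)], 4.0);
}
```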
diff --git a/Cargo.toml b/Cargo.toml index 9b1a8a3a..613b91a0 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -113,7 +113,7 @@ approx = "0.5.1" assert_cmd = "2.0.0" criterion = { version = "0.4.0", default_features = false } indoc = "2.0.1" -marlu = { version = "0.10.0", features = ["approx"] } +marlu = { version = "0.10.1", features = ["approx"] } ndarray = { version = "0.15.4", features = ["approx-0_5"] } serial_test = "2.0.0" tar = "0.4.38" diff --git a/benches/bench.rs b/benches/bench.rs index 66fe16c2..9fec522c 100644 --- a/benches/bench.rs +++ b/benches/bench.rs @@ -18,16 +18,14 @@ use tempfile::Builder; use vec1::{vec1, Vec1}; use mwa_hyperdrive::{ - averaging::{Chanblock, Timeblock}, - beam::{create_fee_beam_object, Delays}, - di_calibrate::calibrate_timeblocks, - model, + calibrate_timeblocks, create_fee_beam_object, + model::{self, SkyModeller}, srclist::{ get_instrumental_flux_densities, ComponentType, FluxDensity, FluxDensityType, ShapeletCoeff, Source, SourceComponent, SourceList, }, - CrossData, MsReader, Polarisations, RawDataCorrections, RawDataReader, TileBaselineFlags, - UvfitsReader, + Chanblock, CrossData, Delays, MsReader, Polarisations, RawDataCorrections, RawDataReader, + TileBaselineFlags, Timeblock, UvfitsReader, }; fn model_benchmarks(c: &mut Criterion) { @@ -73,9 +71,7 @@ fn model_benchmarks(c: &mut Criterion) { }, ); } - let modeller = model::new_sky_modeller( - #[cfg(feature = "cuda")] - true, + let modeller = model::SkyModellerCpu::new( &*beam, &source_list, Polarisations::default(), @@ -87,8 +83,7 @@ fn model_benchmarks(c: &mut Criterion) { MWA_LAT_RAD, dut1, apply_precession, - ) - .unwrap(); + ); b.iter(|| { modeller @@ -137,8 +132,7 @@ fn model_benchmarks(c: &mut Criterion) { }, ); } - let modeller = model::new_sky_modeller( - false, + let modeller = model::SkyModellerCuda::new( &*beam, &source_list, Polarisations::default(), @@ -196,9 +190,7 @@ fn model_benchmarks(c: &mut Criterion) { }, ); } - let modeller = model::new_sky_modeller( - #[cfg(feature = "cuda")] - true, + let modeller = model::SkyModellerCpu::new( &*beam, &source_list, Polarisations::default(), @@ -210,8 +202,7 @@ fn model_benchmarks(c: &mut Criterion) { MWA_LAT_RAD, dut1, apply_precession, - ) - .unwrap(); + ); b.iter(|| { modeller @@ -262,8 +253,7 @@ fn model_benchmarks(c: &mut Criterion) { }, ); } - let modeller = model::new_sky_modeller( - false, + let modeller = model::SkyModellerCuda::new( &*beam, &source_list, Polarisations::default(), @@ -332,9 +322,7 @@ fn model_benchmarks(c: &mut Criterion) { }, ); } - let modeller = model::new_sky_modeller( - #[cfg(feature = "cuda")] - true, + let modeller = model::SkyModellerCpu::new( &*beam, &source_list, Polarisations::default(), @@ -346,8 +334,7 @@ fn model_benchmarks(c: &mut Criterion) { MWA_LAT_RAD, dut1, apply_precession, - ) - .unwrap(); + ); b.iter(|| { modeller @@ -408,8 +395,7 @@ fn model_benchmarks(c: &mut Criterion) { }, ); } - let modeller = model::new_sky_modeller( - false, + let modeller = model::SkyModellerCuda::new( &*beam, &source_list, Polarisations::default(), @@ -443,16 +429,16 @@ fn calibrate_benchmarks(c: &mut Criterion) { index: 0, range: 0..num_timesteps, timestamps: vec1![ - Epoch::from_gpst_seconds(1090008640.0), - Epoch::from_gpst_seconds(1090008641.0), - Epoch::from_gpst_seconds(1090008642.0), - Epoch::from_gpst_seconds(1090008643.0), - Epoch::from_gpst_seconds(1090008644.0), - Epoch::from_gpst_seconds(1090008645.0), - Epoch::from_gpst_seconds(1090008646.0), - Epoch::from_gpst_seconds(1090008647.0), - 
Epoch::from_gpst_seconds(1090008648.0), - Epoch::from_gpst_seconds(1090008649.0), + (Epoch::from_gpst_seconds(1090008640.0), 0), + (Epoch::from_gpst_seconds(1090008641.0), 1), + (Epoch::from_gpst_seconds(1090008642.0), 2), + (Epoch::from_gpst_seconds(1090008643.0), 3), + (Epoch::from_gpst_seconds(1090008644.0), 4), + (Epoch::from_gpst_seconds(1090008645.0), 5), + (Epoch::from_gpst_seconds(1090008646.0), 6), + (Epoch::from_gpst_seconds(1090008647.0), 7), + (Epoch::from_gpst_seconds(1090008648.0), 8), + (Epoch::from_gpst_seconds(1090008649.0), 9), ], median: Epoch::from_gpst_seconds(1090008644.5), }); @@ -464,7 +450,7 @@ fn calibrate_benchmarks(c: &mut Criterion) { chanblocks.push(Chanblock { chanblock_index: i_chanblock as _, unflagged_index: i_chanblock as _, - _freq: 150e6 + i_chanblock as f64, + freq: 150e6 + i_chanblock as f64, }) } let chanblocks = Vec1::try_from_vec(chanblocks).unwrap(); @@ -490,7 +476,6 @@ fn calibrate_benchmarks(c: &mut Criterion) { 1e-4, Polarisations::default(), false, - false, ); }); } @@ -562,9 +547,10 @@ fn io_benchmarks(c: &mut Criterion) { }; // Open the readers. + let gpuboxes = [gpubox]; let raw = RawDataReader::new( - &&metafits, - &[&gpubox], + &metafits, + &gpuboxes, None, RawDataCorrections::do_nothing(), None, @@ -619,8 +605,8 @@ fn io_benchmarks(c: &mut Criterion) { |b| { // Need to re-make the reader with the new corrections. let raw = RawDataReader::new( - &&metafits, - &[&gpubox], + &metafits, + &gpuboxes, None, RawDataCorrections::default(), None, diff --git a/mdbook/src/SUMMARY.md b/mdbook/src/SUMMARY.md index 31cdac76..ac87e3f1 100644 --- a/mdbook/src/SUMMARY.md +++ b/mdbook/src/SUMMARY.md @@ -26,6 +26,7 @@ - [Apply solutions](user/solutions_apply/intro.md) - [Simple usage](user/solutions_apply/simple.md) - [Plot solutions](user/plotting.md) +- [Convert visibilities](user/vis_convert/intro.md) - [Simulate visibilities](user/vis_simulate/intro.md) - [Subtract visibilities](user/vis_subtract/intro.md) diff --git a/mdbook/src/user/vis_convert/intro.md b/mdbook/src/user/vis_convert/intro.md new file mode 100644 index 00000000..46742887 --- /dev/null +++ b/mdbook/src/user/vis_convert/intro.md @@ -0,0 +1,21 @@ +# Convert visibilities + +`vis-convert` reads in visibilities and writes them out, performing whatever +transformations were requested on the way (e.g. ignore autos, average to a +particular time resolution, flag some tiles, etc.). + +~~~admonish info title="Simple examples" +```shell +hyperdrive vis-convert \ + -d *gpubox* *.metafits \ + --tile-flags Tile011 Tile012 \ + -o hyp_converted.uvfits hyp_converted.ms +``` + +```shell +hyperdrive vis-convert \ + -d *.uvfits \ + --no-autos \ + -o hyp_converted.ms +``` +~~~ diff --git a/src/averaging/mod.rs b/src/averaging/mod.rs index 7dae8e3d..8a7e700d 100644 --- a/src/averaging/mod.rs +++ b/src/averaging/mod.rs @@ -10,10 +10,12 @@ mod tests; pub(crate) use error::AverageFactorError; -use std::collections::HashSet; -use std::ops::Range; +use std::{collections::HashSet, num::NonZeroUsize, ops::Range}; use hifitime::{Duration, Epoch}; +use itertools::Itertools; +use marlu::Jones; +use ndarray::prelude::*; use vec1::Vec1; use crate::unit_parsing::{parse_freq, parse_time, FreqFormat, TimeFormat}; @@ -22,7 +24,7 @@ use crate::unit_parsing::{parse_freq, parse_time, FreqFormat, TimeFormat}; #[derive(Debug, Clone)] pub struct Timeblock { /// The timeblock index. e.g. If all observation timesteps are being used in - /// a single calibration timeblock, then its index is 0. 
+ /// a single timeblock, then this index is 0. pub index: usize, /// The range of indices into an *unflagged* array of visibilities. @@ -30,10 +32,9 @@ pub struct Timeblock { /// The timesteps comprising a timeblock need not be contiguous, however, we /// want the timestep visibilities to be contiguous. Here, `range` indicates /// the *unflagged* timestep indices *for this timeblock*. e.g. If timeblock - /// 0 represents timestep 10 and timeblock 1 represents timesteps 15 and 16 - /// (and these are the only timesteps used for calibration), then timeblock - /// 0's range is 0..1 (only one index, 0), whereas timeblock 1's range is - /// 1..3 (two indices starting at 1). + /// 0 represents timestep 10 and timeblock 1 represents timesteps 15 and + /// 16 , then timeblock 0's range is 0..1 (only one index, 0), whereas + /// timeblock 1's range is 1..3 (indices 1 and 2). /// /// We can use a range because the timesteps belonging to a timeblock are /// always contiguous. @@ -43,6 +44,10 @@ pub struct Timeblock { /// timesteps into all available timestamps. pub timestamps: Vec1, + /// These are the indices (0 indexed) that map the incoming timestamps to + /// the timestamps that are available in this `Timeblock`. + pub timesteps: Vec1, + /// The median timestamp of the *ideal* timeblock. /// /// e.g. If we have 9 timesteps and we're averaging 3, the averaged @@ -64,82 +69,91 @@ pub struct Timeblock { #[derive(Debug, Clone)] pub struct Chanblock { /// The chanblock index, regardless of flagging. e.g. If the first two - /// calibration chanblocks are flagged, then the first unflagged chanblock - /// has a chanblock_index of 2 but an unflagged_index of 0. + /// chanblocks are flagged, then the first unflagged chanblock has a + /// chanblock_index of 2 but an unflagged_index of 0. pub chanblock_index: u16, /// The index into an *unflagged* array of visibilities. Regardless of the - /// first unflagged chanblock's index, its unflagged index is 0. + /// first unflagged chanblock's `chanblock_index`, its `unflagged_index` + /// is 0. pub unflagged_index: u16, - // TODO: Use frequency information. May become important for calibration - // solutions and what frequencies they apply to. /// The centroid frequency for this chanblock \[Hz\]. - pub _freq: f64, + pub freq: f64, } /// A spectral windows, a.k.a. a contiguous-band of fine-frequency channels -/// (possibly made up of multiple contiguous coarse channels). Multiple `Fence`s -/// allow a "picket fence" observation to be represented. Calibration is run on -/// each independent `Fence`. +/// (possibly made up of multiple contiguous coarse channels). Multiple `Spw`s +/// allow a "picket fence" observation to be represented. #[derive(Debug)] -pub(crate) struct Fence { - /// The unflagged calibration [Chanblock]s in this [Fence]. +pub(crate) struct Spw { + /// The unflagged [`Chanblock`]s in this [`Spw`]. pub(crate) chanblocks: Vec, + /// The indices of the flagged channels in the un-averaged input data. + /// + /// The type is `u16` to keep the memory usage down; these probably need to + /// be promoted to `usize` when being used. + pub(crate) flagged_chan_indices: HashSet, + /// The indices of the flagged chanblocks. /// /// The type is `u16` to keep the memory usage down; these probably need to /// be promoted to `usize` when being used. - pub(crate) flagged_chanblock_indices: Vec, + pub(crate) flagged_chanblock_indices: HashSet, - /// The first chanblock's centroid frequency (may be flagged) \[Hz\]. 
- pub(crate) _first_freq: f64, + /// The number of channels to average per chanblock. + pub(crate) chans_per_chanblock: NonZeroUsize, - /// The frequency gap between consecutive chanblocks \[Hz\]. If this isn't - /// defined, it's because there's only one chanblock. - pub(crate) _freq_res: Option, -} + /// The frequency gap between consecutive chanblocks \[Hz\]. + pub(crate) freq_res: f64, -impl Fence { - fn _get_total_num_chanblocks(&self) -> usize { - self.chanblocks.len() + self.flagged_chanblock_indices.len() - } + /// The first chanblock's centroid frequency (may be flagged) \[Hz\]. + pub(crate) first_freq: f64, +} - fn _get_freqs(&self) -> Vec { - if let Some(freq_res) = self._freq_res { - (0..self._get_total_num_chanblocks()) - .map(|i_chanblock| self._first_freq + i_chanblock as f64 * freq_res) - .collect() - } else { - vec![self._first_freq] +impl Spw { + /// Get all the frequencies of a spectral window (flagged and unflagged). + pub(crate) fn get_all_freqs(&self) -> Vec1 { + let n = self.chanblocks.len() + self.flagged_chanblock_indices.len(); + let mut freqs = Vec::with_capacity(n); + for i in 0..n { + freqs.push((i as f64).mul_add(self.freq_res, self.first_freq)); } + Vec1::try_from_vec(freqs).expect("unlikely to fail as a SPW should have at least 1 channel") } } /// Given *all* the available timestamps in some input data, the number of /// timesteps to average together into a timeblock and which timesteps to use, -/// return timeblocks to be used for calibration. Timestamps and timesteps must -/// be ascendingly sorted. +/// return timeblocks. Timestamps and timesteps must be ascendingly sorted. If +/// `timesteps_to_use` isn't given, this function assumes all timestamps will +/// be used. /// /// The timestamps must be regular in some time resolution, but gaps are /// allowed; e.g. [100, 101, 103, 104] is valid, can the code will determine a /// time resolution of 1. pub(super) fn timesteps_to_timeblocks( all_timestamps: &Vec1, - time_average_factor: usize, - timesteps_to_use: &Vec1, + time_resolution: Duration, + time_average_factor: NonZeroUsize, + timesteps_to_use: Option<&Vec1>, ) -> Vec1 { - let time_res = all_timestamps - .windows(2) - .fold(Duration::from_seconds(f64::INFINITY), |a, t| { - a.min(t[1] - t[0]) - }); - let timestamps_to_use = timesteps_to_use.mapped_ref( - |&t_step| - // TODO: Handle incorrect timestep indices. - *all_timestamps.get(t_step).unwrap(), // Could use square brackets, but this way the unwrap is clear. - ); + let (timestamps_to_use, timesteps_to_use) = match timesteps_to_use { + Some(timesteps_to_use) => { + let timestamps_to_use = timesteps_to_use.mapped_ref( + |&t_step| + // TODO: Handle incorrect timestep indices. + *all_timestamps.get(t_step).expect("timestep correctly indexes timestamps"), // Could use square brackets, but this way the potential error is clear. + ); + (timestamps_to_use, timesteps_to_use.clone()) + } + None => ( + all_timestamps.clone(), + Vec1::try_from_vec((0..all_timestamps.len()).collect::>()) + .expect("cannot be empty"), + ), + }; // Populate the median timestamps of all timeblocks based off of the first // timestamp. e.g. 
If there are 10 timestamps with an averaging factor of 3, @@ -162,14 +176,13 @@ pub(super) fn timesteps_to_timeblocks( let mut timeblocks = vec![]; let timeblock_length = Duration::from_total_nanoseconds( - // time_average_factor as i128 * time_res.total_nanoseconds(), - (time_average_factor - 1) as i128 * time_res.total_nanoseconds(), + (time_average_factor.get() - 1) as i128 * time_resolution.total_nanoseconds(), ); let half_a_timeblock = timeblock_length / 2; let first_timestamp = *timestamps_to_use.first(); let last_timestamp = *timestamps_to_use.last(); - let time_res = time_res.total_nanoseconds() as u128; - let time_average_factor = time_average_factor as u128; + let time_res = time_resolution.total_nanoseconds() as u128; + let time_average_factor = time_average_factor.get() as u128; let mut timeblock_index = 0; let mut timestep_index = 0; for i in 0.. { @@ -190,17 +203,25 @@ pub(super) fn timesteps_to_timeblocks( break; } - let timeblock_timestamps = timestamps_to_use - .iter() - .filter(|ts| (timeblock_start..=timeblock_end).contains(ts)) - .copied() - .collect::>(); + let (timeblock_timestamps, timeblock_timesteps): (Vec, Vec) = + timestamps_to_use + .iter() + .zip(timesteps_to_use.iter()) + .filter_map(|(timestamp, timestep)| { + if (timeblock_start..=timeblock_end).contains(timestamp) { + Some((*timestamp, *timestep)) + } else { + None + } + }) + .unzip(); if !timeblock_timestamps.is_empty() { let num_timeblock_timestamps = timeblock_timestamps.len(); timeblocks.push(Timeblock { index: timeblock_index, range: timestep_index..timestep_index + num_timeblock_timestamps, - timestamps: Vec1::try_from_vec(timeblock_timestamps).unwrap(), + timestamps: Vec1::try_from_vec(timeblock_timestamps).expect("cannot be empty"), + timesteps: Vec1::try_from_vec(timeblock_timesteps).expect("cannot be empty"), median: timeblock_median, }); timeblock_index += 1; @@ -208,83 +229,76 @@ pub(super) fn timesteps_to_timeblocks( } } - Vec1::try_from_vec(timeblocks).unwrap() + Vec1::try_from_vec(timeblocks).expect("cannot be empty") } -/// Returns a vector of [Fence]s (potentially multiple contiguous-bands of fine -/// channels) to use in calibration. If there's more than one [Fence], then this -/// is a "picket fence" observation. +/// Returns a vector of [`Spw`]s (potentially multiple contiguous-bands of fine +/// channels). If there's more than one [`Spw`], then this is a "picket fence" +/// observation. pub(super) fn channels_to_chanblocks( all_channel_freqs: &[u64], - frequency_resolution: Option, - freq_average_factor: usize, - flagged_channels: &HashSet, -) -> Vec { + freq_resolution: u64, + freq_average_factor: NonZeroUsize, + flagged_chan_indices: &HashSet, +) -> Vec { // Handle 0 or 1 provided frequencies here. 
match all_channel_freqs { [] => return vec![], [f] => { - let (chanblocks, flagged_chanblock_indices) = if flagged_channels.contains(&0) { - (vec![], vec![0]) + let spw = if flagged_chan_indices.contains(&0) { + Spw { + chanblocks: vec![], + flagged_chan_indices: HashSet::from([0]), + flagged_chanblock_indices: HashSet::from([0]), + chans_per_chanblock: freq_average_factor, + freq_res: freq_resolution as f64, + first_freq: *f as f64, + } } else { - ( - vec![Chanblock { + Spw { + chanblocks: vec![Chanblock { chanblock_index: 0, unflagged_index: 0, - _freq: *f as f64, + freq: *f as f64, }], - vec![], - ) + flagged_chan_indices: HashSet::new(), + flagged_chanblock_indices: HashSet::new(), + chans_per_chanblock: freq_average_factor, + freq_res: freq_resolution as f64, + first_freq: *f as f64, + } }; - return vec![Fence { - chanblocks, - flagged_chanblock_indices, - _first_freq: *f as f64, - _freq_res: None, - }]; + return vec![spw]; } _ => (), // More complicated logic needed. } - // If the frequency resolution wasn't provided, we find the minimum gap - // between consecutive frequencies and use this instead. - let freq_res = frequency_resolution - .map(|f| f.round() as u64) - .unwrap_or_else(|| { - // Iterate over all the frequencies and find the smallest gap between - // any pair. - all_channel_freqs.windows(2).fold(u64::MAX, |acc, window| { - let diff = window[1] - window[0]; - acc.min(diff) - }) - }); - - // Find any picket fences here. - let mut fence_index_ends = vec![]; - all_channel_freqs - .windows(2) - .enumerate() + // Find any picket SPWs here. + let mut spw_index_ends = vec![]; + (0..) + .zip(all_channel_freqs.windows(2)) .for_each(|(i, window)| { - if window[1] - window[0] > freq_res { - fence_index_ends.push(i + 1); + if window[1] - window[0] > freq_resolution { + spw_index_ends.push(i + 1); } }); - let mut fences = Vec::with_capacity(fence_index_ends.len() + 1); - let biggest_freq_diff = freq_res * freq_average_factor as u64; + let mut spws = Vec::with_capacity(spw_index_ends.len() + 1); + let biggest_freq_diff = freq_resolution * freq_average_factor.get() as u64; let mut chanblocks = vec![]; - let mut flagged_chanblock_indices = vec![]; + let mut flagged_chanblock_indices = HashSet::new(); let mut i_chanblock = 0; let mut i_unflagged_chanblock = 0; let mut current_freqs = vec![]; - let mut first_fence_freq = None; + let mut first_spw_freq = None; let mut first_freq = None; let mut all_flagged = true; + let mut this_spw_flagged_chans = HashSet::new(); - for (i_chan, &freq) in all_channel_freqs.iter().enumerate() { - match first_fence_freq { + for (i_chan, &freq) in (0..).zip(all_channel_freqs.iter()) { + match first_spw_freq { Some(_) => (), - None => first_fence_freq = Some(freq), + None => first_spw_freq = Some(freq), } match first_freq { Some(_) => (), @@ -293,14 +307,14 @@ pub(super) fn channels_to_chanblocks( if freq - first_freq.unwrap() >= biggest_freq_diff { if all_flagged { - flagged_chanblock_indices.push(i_chanblock); + flagged_chanblock_indices.insert(i_chanblock); } else { - let centroid_freq = - first_freq.unwrap() + freq_res / 2 * (freq_average_factor - 1) as u64; + let centroid_freq = first_freq.unwrap() + + freq_resolution / 2 * (freq_average_factor.get() - 1) as u64; chanblocks.push(Chanblock { chanblock_index: i_chanblock, unflagged_index: i_unflagged_chanblock, - _freq: centroid_freq as f64, + freq: centroid_freq as f64, }); i_unflagged_chanblock += 1; } @@ -311,43 +325,55 @@ pub(super) fn channels_to_chanblocks( } current_freqs.push(freq as f64); - if 
!flagged_channels.contains(&i_chan) { + if flagged_chan_indices.contains(&i_chan) { + this_spw_flagged_chans.insert(i_chan); + } else { all_flagged = false; } - if fence_index_ends.contains(&i_chan) { - fences.push(Fence { + if spw_index_ends.contains(&i_chan) { + spws.push(Spw { chanblocks: chanblocks.clone(), + flagged_chan_indices: this_spw_flagged_chans.clone(), flagged_chanblock_indices: flagged_chanblock_indices.clone(), - _first_freq: first_fence_freq.unwrap() as f64, - _freq_res: Some(biggest_freq_diff as f64), + chans_per_chanblock: freq_average_factor, + freq_res: biggest_freq_diff as f64, + first_freq: (first_spw_freq.unwrap() + + freq_resolution / 2 * (freq_average_factor.get() - 1) as u64) + as f64, }); - first_fence_freq = Some(freq); + first_spw_freq = Some(freq); chanblocks.clear(); flagged_chanblock_indices.clear(); + this_spw_flagged_chans.clear(); } } // Deal with any leftover data. if let Some(first_freq) = first_freq { if all_flagged { - flagged_chanblock_indices.push(i_chanblock); + flagged_chanblock_indices.insert(i_chanblock); } else { - let centroid_freq = first_freq + freq_res / 2 * (freq_average_factor - 1) as u64; + let centroid_freq = + first_freq + freq_resolution / 2 * (freq_average_factor.get() - 1) as u64; chanblocks.push(Chanblock { chanblock_index: i_chanblock, unflagged_index: i_unflagged_chanblock, - _freq: centroid_freq as f64, + freq: centroid_freq as f64, }); } - fences.push(Fence { + spws.push(Spw { chanblocks, + flagged_chan_indices: this_spw_flagged_chans, flagged_chanblock_indices, - _first_freq: first_fence_freq.unwrap() as f64, - _freq_res: Some(biggest_freq_diff as f64), + chans_per_chanblock: freq_average_factor, + freq_res: biggest_freq_diff as f64, + first_freq: (first_spw_freq.unwrap() + + freq_resolution / 2 * (freq_average_factor.get() - 1) as u64) + as f64, }); } - fences + spws } /// Determine a time average factor given a time resolution and user input. Use @@ -358,13 +384,13 @@ pub(super) fn channels_to_chanblocks( pub(super) fn parse_time_average_factor( time_resolution: Option, user_input_time_factor: Option<&str>, - default: usize, -) -> Result { + default: NonZeroUsize, +) -> Result { match (time_resolution, user_input_time_factor.map(parse_time)) { (None, _) => { // If the time resolution is unknown, we assume it's because there's // only one timestep. - Ok(1) + Ok(NonZeroUsize::new(1).unwrap()) } (_, None) => { // "None" indicates we should follow default behaviour. @@ -384,7 +410,8 @@ pub(super) fn parse_time_average_factor( return Err(AverageFactorError::NotInteger); } - Ok(factor.round() as _) + let u = factor.round() as usize; + Ok(NonZeroUsize::new(u).expect("is not 0")) } // User input is OK and has a unit. @@ -408,7 +435,8 @@ pub(super) fn parse_time_average_factor( }); } - Ok(factor.round() as _) + let u = factor.round() as usize; + Ok(NonZeroUsize::new(u).expect("is not 0")) } } } @@ -421,13 +449,13 @@ pub(super) fn parse_time_average_factor( pub(super) fn parse_freq_average_factor( freq_resolution: Option, user_input_freq_factor: Option<&str>, - default: usize, -) -> Result { + default: NonZeroUsize, +) -> Result { match (freq_resolution, user_input_freq_factor.map(parse_freq)) { (None, _) => { // If the freq. resolution is unknown, we assume it's because // there's only one channel. - Ok(1) + Ok(NonZeroUsize::new(1).unwrap()) } (_, None) => { // "None" indicates we should follow default behaviour. 
@@ -447,7 +475,8 @@ pub(super) fn parse_freq_average_factor( return Err(AverageFactorError::NotInteger); } - Ok(factor.round() as _) + let u = factor.round() as usize; + Ok(NonZeroUsize::new(u).expect("is not 0")) } // User input is OK and has a unit. @@ -471,7 +500,129 @@ pub(super) fn parse_freq_average_factor( }); } - Ok(factor.round() as _) + let u = factor.round() as usize; + Ok(NonZeroUsize::new(u).expect("is not 0")) } } } + +pub(crate) fn vis_average( + jones_from_tfb: ArrayView3>, + mut jones_to_fb: ArrayViewMut2>, + weight_from_tfb: ArrayView3, + mut weight_to_fb: ArrayViewMut2, + flagged_chanblock_indices: &HashSet, +) { + let avg_time = jones_from_tfb.len_of(Axis(0)); + let avg_freq = (jones_from_tfb.len_of(Axis(1)) as f64 + / (jones_to_fb.len_of(Axis(0)) + flagged_chanblock_indices.len()) as f64) + .ceil() as usize; + + // { + // assert_eq!(jones_from_tfb.dim(), weight_from_tfb.dim()); + // assert_eq!(jones_to_fb.dim(), weight_to_fb.dim()); + // let (_time_from, freq_from, baseline_from) = jones_from_tfb.dim(); + // let (freqs_to, baseline_to) = jones_to_fb.dim(); + // assert_eq!( + // (freq_from as f64 / avg_freq as f64).floor() as usize, + // freqs_to + flagged_chan_indices.len(), + // ); + // assert_eq!( + // avg_freq * (freqs_to + flagged_chan_indices.len()), + // freq_from + // ); + // assert_eq!(baseline_from, baseline_to); + // } + + // iterate along time axis in chunks of avg_time + jones_from_tfb + .axis_chunks_iter(Axis(0), avg_time) + .zip(weight_from_tfb.axis_chunks_iter(Axis(0), avg_time)) + .for_each(|(jones_chunk_tfb, weight_chunk_tfb)| { + jones_chunk_tfb + .axis_iter(Axis(2)) + .zip(weight_chunk_tfb.axis_iter(Axis(2))) + .zip(jones_to_fb.axis_iter_mut(Axis(1))) + .zip(weight_to_fb.axis_iter_mut(Axis(1))) + .for_each( + |(((jones_chunk_tf, weight_chunk_tf), mut jones_to_f), mut weight_to_f)| { + jones_chunk_tf + .axis_chunks_iter(Axis(1), avg_freq) + .zip(weight_chunk_tf.axis_chunks_iter(Axis(1), avg_freq)) + .enumerate() + .filter(|(i, _)| !flagged_chanblock_indices.contains(&(*i as u16))) + .map(|(_, d)| d) + .zip(jones_to_f.iter_mut()) + .zip(weight_to_f.iter_mut()) + .for_each( + |(((jones_chunk_tf, weight_chunk_tf), jones_to), weight_to)| { + vis_average_weights_non_zero( + jones_chunk_tf, + weight_chunk_tf, + jones_to, + weight_to, + ); + }, + ); + }, + ); + }); +} + +/// Average a chunk of visibilities and weights (both must have the same +/// dimensions) into an output vis and weight. This function allows the weights +/// to be negative; if all of the weights in the chunk are negative or 0, the +/// averaged visibility is considered "flagged". +#[inline] +fn vis_average_weights_non_zero( + jones_chunk_tf: ArrayView2>, + weight_chunk_tf: ArrayView2, + jones_to: &mut Jones, + weight_to: &mut f32, +) { + let mut jones_weighted_sum = Jones::default(); + let mut weight_sum = 0.0; + let mut flagged = true; + + // iterate through time chunks + jones_chunk_tf + .iter() + .zip_eq(weight_chunk_tf.iter()) + .for_each(|(jones, weight)| { + let jones = Jones::::from(*jones); + let weight = *weight as f64; + + if weight > 0.0 { + // This visibility is not flagged. + if flagged { + // If previous visibilities were flagged, we need to discard + // that information. + jones_weighted_sum = jones * weight; + weight_sum = weight; + flagged = false; + } else { + // Otherwise, we're accumulating this unflagged vis. + jones_weighted_sum += jones * weight; + weight_sum += weight; + } + } else { + // This visibility is flagged. 
+ if flagged { + // If all prior vis were also flagged, we accumulate here. + jones_weighted_sum += jones * weight; + weight_sum += weight; + } + // Nothing needs to be done if there were preceding unflagged + // vis. + } + }); + + if weight_sum == 0.0 { + // If the weight is 0, we can't divide the accumulated vis by the + // accumulated weight. So, divide by the chunk size instead. + *jones_to = Jones::from(jones_weighted_sum / jones_chunk_tf.len() as f64); + } else { + *jones_to = Jones::from(jones_weighted_sum / weight_sum); + } + *weight_to = weight_sum as f32; +} diff --git a/src/averaging/tests.rs b/src/averaging/tests.rs index 39c36bb4..4ed698dc 100644 --- a/src/averaging/tests.rs +++ b/src/averaging/tests.rs @@ -17,10 +17,15 @@ fn test_timesteps_to_timeblocks() { .collect(); let all_timestamps = Vec1::try_from_vec(all_timestamps).unwrap(); - let time_average_factor = 1; + let time_res = Duration::from_seconds(2.0); + let time_average_factor = NonZeroUsize::new(1).unwrap(); let timesteps_to_use = vec1![2, 3, 4, 5]; - let timeblocks = - timesteps_to_timeblocks(&all_timestamps, time_average_factor, ×teps_to_use); + let timeblocks = timesteps_to_timeblocks( + &all_timestamps, + time_res, + time_average_factor, + Some(×teps_to_use), + ); // Time average factor is 1; 1 timestep per timeblock. assert_eq!(timeblocks.len(), 4); for ((timeblock, expected_indices), expected_timestamp) in timeblocks @@ -34,16 +39,20 @@ fn test_timesteps_to_timeblocks() { } assert_eq!( - average_epoch(&timeblock.timestamps).to_gpst_seconds(), + average_epoch(timeblock.timestamps).to_gpst_seconds(), expected_timestamp ); assert_eq!(timeblock.median.to_gpst_seconds(), expected_timestamp); } - let time_average_factor = 2; + let time_average_factor = NonZeroUsize::new(2).unwrap(); let timesteps_to_use = vec1![2, 3, 4, 5]; - let timeblocks = - timesteps_to_timeblocks(&all_timestamps, time_average_factor, ×teps_to_use); + let timeblocks = timesteps_to_timeblocks( + &all_timestamps, + time_res, + time_average_factor, + Some(×teps_to_use), + ); // 2 timesteps per timeblock. assert_eq!(timeblocks.len(), 2); for ((timeblock, expected_indices), expected_timestamp) in timeblocks @@ -57,16 +66,20 @@ fn test_timesteps_to_timeblocks() { } assert_eq!( - average_epoch(&timeblock.timestamps).to_gpst_seconds(), + average_epoch(timeblock.timestamps).to_gpst_seconds(), expected_timestamp ); assert_eq!(timeblock.median.to_gpst_seconds(), expected_timestamp); } - let time_average_factor = 3; + let time_average_factor = NonZeroUsize::new(3).unwrap(); let timesteps_to_use = vec1![2, 3, 4, 5]; - let timeblocks = - timesteps_to_timeblocks(&all_timestamps, time_average_factor, ×teps_to_use); + let timeblocks = timesteps_to_timeblocks( + &all_timestamps, + time_res, + time_average_factor, + Some(×teps_to_use), + ); // 3 timesteps per timeblock, but the last timeblock has only one timestep. assert_eq!(timeblocks.len(), 2); for ((timeblock, expected_indices), expected_timestamp) in timeblocks @@ -80,7 +93,7 @@ fn test_timesteps_to_timeblocks() { } assert_eq!( - average_epoch(&timeblock.timestamps).to_gpst_seconds(), + average_epoch(timeblock.timestamps).to_gpst_seconds(), expected_timestamp ); // The median is different from the average for the second timeblock. @@ -94,10 +107,15 @@ fn test_timesteps_to_timeblocks() { let timesteps_to_use = vec1![2, 15, 16]; // Average all the timesteps together. This is what is used to calculate the // time average factor in this case. 
- let time_average_factor = *timesteps_to_use.last() - *timesteps_to_use.first() + 1; - assert_eq!(time_average_factor, 15); - let timeblocks = - timesteps_to_timeblocks(&all_timestamps, time_average_factor, ×teps_to_use); + let time_average_factor = + NonZeroUsize::new(*timesteps_to_use.last() - *timesteps_to_use.first() + 1).unwrap(); + assert_eq!(time_average_factor.get(), 15); + let timeblocks = timesteps_to_timeblocks( + &all_timestamps, + time_res, + time_average_factor, + Some(×teps_to_use), + ); assert_eq!(timeblocks.len(), 1); for ((timeblock, expected_indices), expected_timestamp) in timeblocks.into_iter().zip([[0, 1, 2]]).zip([1065880150.0]) @@ -108,7 +126,7 @@ fn test_timesteps_to_timeblocks() { } assert_eq!( - average_epoch(&timeblock.timestamps).to_gpst_seconds(), + average_epoch(timeblock.timestamps).to_gpst_seconds(), expected_timestamp ); // (2 + 16) / 2 = 9 is the median timestep @@ -120,146 +138,153 @@ fn test_timesteps_to_timeblocks() { #[test] fn test_channels_to_chanblocks() { let all_channel_freqs = [12000]; - let freq_average_factor = 1; + let freq_average_factor = NonZeroUsize::new(1).unwrap(); let mut flagged_channels = HashSet::new(); - let fences = channels_to_chanblocks( + let freq_res = 1000; + let spws = channels_to_chanblocks( &all_channel_freqs, - None, + freq_res, freq_average_factor, &flagged_channels, ); - assert_eq!(fences.len(), 1); - assert_eq!(fences[0].chanblocks.len(), 1); - assert!(fences[0].flagged_chanblock_indices.is_empty()); - assert_abs_diff_eq!(fences[0].chanblocks[0]._freq, 12000.0); - assert_abs_diff_eq!(fences[0]._first_freq, 12000.0); - assert!(fences[0]._freq_res.is_none()); + assert_eq!(spws.len(), 1); + assert_eq!(spws[0].chanblocks.len(), 1); + assert!(spws[0].flagged_chanblock_indices.is_empty()); + assert_abs_diff_eq!(spws[0].chanblocks[0].freq, 12000.0); + assert_abs_diff_eq!(spws[0].freq_res, freq_res as f64); + assert_abs_diff_eq!(spws[0].first_freq, 12000.0); let all_channel_freqs = [10000, 11000, 12000, 13000, 14000]; - let fences = channels_to_chanblocks( + let spws = channels_to_chanblocks( &all_channel_freqs, - None, + freq_res, freq_average_factor, &flagged_channels, ); - assert_eq!(fences.len(), 1); - assert_eq!(fences[0].chanblocks.len(), 5); - assert!(fences[0].flagged_chanblock_indices.is_empty()); - assert_abs_diff_eq!(fences[0].chanblocks[0]._freq, 10000.0); - assert_abs_diff_eq!(fences[0].chanblocks[1]._freq, 11000.0); - assert_abs_diff_eq!(fences[0].chanblocks[2]._freq, 12000.0); - assert_abs_diff_eq!(fences[0].chanblocks[3]._freq, 13000.0); - assert_abs_diff_eq!(fences[0].chanblocks[4]._freq, 14000.0); - assert_abs_diff_eq!(fences[0]._first_freq, 10000.0); - assert_abs_diff_eq!(fences[0]._freq_res.unwrap(), 1000.0); + assert_eq!(spws.len(), 1); + assert_eq!(spws[0].chanblocks.len(), 5); + assert!(spws[0].flagged_chanblock_indices.is_empty()); + assert_abs_diff_eq!(spws[0].chanblocks[0].freq, 10000.0); + assert_abs_diff_eq!(spws[0].chanblocks[1].freq, 11000.0); + assert_abs_diff_eq!(spws[0].chanblocks[2].freq, 12000.0); + assert_abs_diff_eq!(spws[0].chanblocks[3].freq, 13000.0); + assert_abs_diff_eq!(spws[0].chanblocks[4].freq, 14000.0); + assert_abs_diff_eq!(spws[0].freq_res, 1000.0); + assert_abs_diff_eq!(spws[0].first_freq, 10000.0); let all_channel_freqs = [10000, 11000, 12000, 13000, 14000, 20000]; - let fences = channels_to_chanblocks( + let spws = channels_to_chanblocks( &all_channel_freqs, - None, + freq_res, freq_average_factor, &flagged_channels, ); - assert_eq!(fences.len(), 2); - 
assert_eq!(fences[0].chanblocks.len(), 5); - assert_eq!(fences[1].chanblocks.len(), 1); - assert!(fences[0].flagged_chanblock_indices.is_empty()); - assert!(fences[1].flagged_chanblock_indices.is_empty()); - assert_abs_diff_eq!(fences[0].chanblocks[0]._freq, 10000.0); - assert_abs_diff_eq!(fences[0].chanblocks[1]._freq, 11000.0); - assert_abs_diff_eq!(fences[0].chanblocks[2]._freq, 12000.0); - assert_abs_diff_eq!(fences[0].chanblocks[3]._freq, 13000.0); - assert_abs_diff_eq!(fences[0].chanblocks[4]._freq, 14000.0); - assert_abs_diff_eq!(fences[1].chanblocks[0]._freq, 20000.0); - assert_abs_diff_eq!(fences[0]._first_freq, 10000.0); - assert_abs_diff_eq!(fences[1]._first_freq, 20000.0); - assert_abs_diff_eq!(fences[0]._freq_res.unwrap(), 1000.0); - assert_abs_diff_eq!(fences[1]._freq_res.unwrap(), 1000.0); + assert_eq!(spws.len(), 2); + assert_eq!(spws[0].chanblocks.len(), 5); + assert_eq!(spws[1].chanblocks.len(), 1); + assert!(spws[0].flagged_chanblock_indices.is_empty()); + assert!(spws[1].flagged_chanblock_indices.is_empty()); + assert_abs_diff_eq!(spws[0].chanblocks[0].freq, 10000.0); + assert_abs_diff_eq!(spws[0].chanblocks[1].freq, 11000.0); + assert_abs_diff_eq!(spws[0].chanblocks[2].freq, 12000.0); + assert_abs_diff_eq!(spws[0].chanblocks[3].freq, 13000.0); + assert_abs_diff_eq!(spws[0].chanblocks[4].freq, 14000.0); + assert_abs_diff_eq!(spws[1].chanblocks[0].freq, 20000.0); + assert_abs_diff_eq!(spws[0].freq_res, 1000.0); + assert_abs_diff_eq!(spws[1].freq_res, 1000.0); + assert_abs_diff_eq!(spws[0].first_freq, 10000.0); + assert_abs_diff_eq!(spws[1].first_freq, 20000.0); flagged_channels.insert(3); - let fences = channels_to_chanblocks( + let spws = channels_to_chanblocks( &all_channel_freqs, - None, + freq_res, freq_average_factor, &flagged_channels, ); - assert_eq!(fences.len(), 2); - assert_eq!(fences[0].chanblocks.len(), 4); - assert_eq!(fences[1].chanblocks.len(), 1); - assert_eq!(fences[0].flagged_chanblock_indices.len(), 1); - assert_eq!(fences[0].flagged_chanblock_indices[0], 3); - assert!(fences[1].flagged_chanblock_indices.is_empty()); - assert_abs_diff_eq!(fences[0].chanblocks[0]._freq, 10000.0); - assert_abs_diff_eq!(fences[0].chanblocks[1]._freq, 11000.0); - assert_abs_diff_eq!(fences[0].chanblocks[2]._freq, 12000.0); - assert_abs_diff_eq!(fences[0].chanblocks[3]._freq, 14000.0); - assert_abs_diff_eq!(fences[1].chanblocks[0]._freq, 20000.0); - assert_abs_diff_eq!(fences[0]._first_freq, 10000.0); - assert_abs_diff_eq!(fences[1]._first_freq, 20000.0); - assert_abs_diff_eq!(fences[0]._freq_res.unwrap(), 1000.0); - assert_abs_diff_eq!(fences[1]._freq_res.unwrap(), 1000.0); - - let freq_average_factor = 2; - let fences = channels_to_chanblocks( + assert_eq!(spws.len(), 2); + assert_eq!(spws[0].chanblocks.len(), 4); + assert_eq!(spws[1].chanblocks.len(), 1); + assert_eq!(spws[0].flagged_chanblock_indices.len(), 1); + let mut sorted = spws[0] + .flagged_chanblock_indices + .iter() + .copied() + .collect::>(); + sorted.sort_unstable(); + assert_eq!(sorted[0], 3); + assert!(spws[1].flagged_chanblock_indices.is_empty()); + assert_abs_diff_eq!(spws[0].chanblocks[0].freq, 10000.0); + assert_abs_diff_eq!(spws[0].chanblocks[1].freq, 11000.0); + assert_abs_diff_eq!(spws[0].chanblocks[2].freq, 12000.0); + assert_abs_diff_eq!(spws[0].chanblocks[3].freq, 14000.0); + assert_abs_diff_eq!(spws[1].chanblocks[0].freq, 20000.0); + assert_abs_diff_eq!(spws[0].freq_res, 1000.0); + assert_abs_diff_eq!(spws[1].freq_res, 1000.0); + assert_abs_diff_eq!(spws[0].first_freq, 10000.0); + 
assert_abs_diff_eq!(spws[1].first_freq, 20000.0); + + let freq_average_factor = NonZeroUsize::new(2).unwrap(); + let spws = channels_to_chanblocks( &all_channel_freqs, - None, + freq_res, freq_average_factor, &flagged_channels, ); - assert_eq!(fences.len(), 2); - assert_eq!(fences[0].chanblocks.len(), 3); - assert_eq!(fences[1].chanblocks.len(), 1); - assert!(fences[0].flagged_chanblock_indices.is_empty()); - assert!(fences[1].flagged_chanblock_indices.is_empty()); - assert_abs_diff_eq!(fences[0].chanblocks[0]._freq, 10500.0); - assert_abs_diff_eq!(fences[0].chanblocks[1]._freq, 12500.0); - assert_abs_diff_eq!(fences[0].chanblocks[2]._freq, 14500.0); - assert_abs_diff_eq!(fences[1].chanblocks[0]._freq, 20500.0); - assert_abs_diff_eq!(fences[0]._first_freq, 10000.0); - assert_abs_diff_eq!(fences[1]._first_freq, 20000.0); - assert_abs_diff_eq!(fences[0]._freq_res.unwrap(), 2000.0); - assert_abs_diff_eq!(fences[1]._freq_res.unwrap(), 2000.0); - - let freq_average_factor = 3; - let fences = channels_to_chanblocks( + assert_eq!(spws.len(), 2); + assert_eq!(spws[0].chanblocks.len(), 3); + assert_eq!(spws[1].chanblocks.len(), 1); + assert!(spws[0].flagged_chanblock_indices.is_empty()); + assert!(spws[1].flagged_chanblock_indices.is_empty()); + assert_abs_diff_eq!(spws[0].chanblocks[0].freq, 10500.0); + assert_abs_diff_eq!(spws[0].chanblocks[1].freq, 12500.0); + assert_abs_diff_eq!(spws[0].chanblocks[2].freq, 14500.0); + assert_abs_diff_eq!(spws[1].chanblocks[0].freq, 20500.0); + assert_abs_diff_eq!(spws[0].freq_res, 2000.0); + assert_abs_diff_eq!(spws[1].freq_res, 2000.0); + assert_abs_diff_eq!(spws[0].first_freq, 10500.0); + assert_abs_diff_eq!(spws[1].first_freq, 20500.0); + + let freq_average_factor = NonZeroUsize::new(3).unwrap(); + let spws = channels_to_chanblocks( &all_channel_freqs, - None, + freq_res, freq_average_factor, &flagged_channels, ); - assert_eq!(fences.len(), 2); - assert_eq!(fences[0].chanblocks.len(), 2); - assert_eq!(fences[1].chanblocks.len(), 1); - assert!(fences[0].flagged_chanblock_indices.is_empty()); - assert!(fences[1].flagged_chanblock_indices.is_empty()); - assert_abs_diff_eq!(fences[0].chanblocks[0]._freq, 11000.0); - assert_abs_diff_eq!(fences[0].chanblocks[1]._freq, 14000.0); - assert_abs_diff_eq!(fences[1].chanblocks[0]._freq, 21000.0); - assert_abs_diff_eq!(fences[0]._first_freq, 10000.0); - assert_abs_diff_eq!(fences[1]._first_freq, 20000.0); - assert_abs_diff_eq!(fences[0]._freq_res.unwrap(), 3000.0); - assert_abs_diff_eq!(fences[1]._freq_res.unwrap(), 3000.0); + assert_eq!(spws.len(), 2); + assert_eq!(spws[0].chanblocks.len(), 2); + assert_eq!(spws[1].chanblocks.len(), 1); + assert!(spws[0].flagged_chanblock_indices.is_empty()); + assert!(spws[1].flagged_chanblock_indices.is_empty()); + assert_abs_diff_eq!(spws[0].chanblocks[0].freq, 11000.0); + assert_abs_diff_eq!(spws[0].chanblocks[1].freq, 14000.0); + assert_abs_diff_eq!(spws[1].chanblocks[0].freq, 21000.0); + assert_abs_diff_eq!(spws[0].freq_res, 3000.0); + assert_abs_diff_eq!(spws[1].freq_res, 3000.0); + assert_abs_diff_eq!(spws[0].first_freq, 11000.0); + assert_abs_diff_eq!(spws[1].first_freq, 21000.0); } -// No frequencies, no fences. +// No frequencies, no spws. 
#[test] fn test_no_channels_to_chanblocks() { let all_channel_freqs = []; - let freq_average_factor = 2; + let freq_average_factor = NonZeroUsize::new(2).unwrap(); let flagged_channels = HashSet::new(); - let fences = channels_to_chanblocks( + let spws = channels_to_chanblocks( &all_channel_freqs, - None, + 10e3 as u64, freq_average_factor, &flagged_channels, ); - assert!(fences.is_empty()); + assert!(spws.is_empty()); } fn test_time( time_resolution: Option, user_input_time_factor: Option<&str>, - default: usize, + default: NonZeroUsize, expected: Option, ) { let result = parse_time_average_factor(time_resolution, user_input_time_factor, default); @@ -267,8 +292,7 @@ fn test_time( match expected { // If expected is Some, then we expect the result to match. Some(expected) => { - assert!(result.is_ok(), "res={time_resolution:?}, input={user_input_time_factor:?}, default={default}, expected={expected:?}, error={}", result.unwrap_err()); - assert_eq!(result.unwrap(), expected, "res={time_resolution:?}, input={user_input_time_factor:?}, default={default}, expected={expected:?}"); + assert_eq!(result.unwrap().get(), expected, "res={time_resolution:?}, input={user_input_time_factor:?}, default={default}, expected={expected:?}"); } // Otherwise, we expect failure. None => { @@ -280,7 +304,7 @@ fn test_time( #[test] fn test_parse_time_average_factor() { let time_resolution = Some(Duration::from_seconds(2.0)); - let default = 100; + let default = NonZeroUsize::new(100).unwrap(); let user_input_time_factor = Some("2"); let expected = Some(2); @@ -319,14 +343,14 @@ fn test_parse_time_average_factor() { test_time(time_resolution, user_input_time_factor, default, expected); let user_input_time_factor = None; - let expected = Some(default); + let expected = Some(default.get()); test_time(time_resolution, user_input_time_factor, default, expected); } fn test_freq( freq_resolution: Option, user_input_freq_factor: Option<&str>, - default: usize, + default: NonZeroUsize, expected: Option, ) { let result = parse_freq_average_factor(freq_resolution, user_input_freq_factor, default); @@ -334,8 +358,7 @@ fn test_freq( match expected { // If expected is Some, then we expect the result to match. Some(expected) => { - assert!(result.is_ok(), "res={freq_resolution:?}, input={user_input_freq_factor:?}, default={default}, expected={expected:?}, error={}", result.unwrap_err()); - assert_eq!(result.unwrap(), expected, "res={freq_resolution:?}, input={user_input_freq_factor:?}, default={default}, expected={expected:?}"); + assert_eq!(result.unwrap().get(), expected, "res={freq_resolution:?}, input={user_input_freq_factor:?}, default={default}, expected={expected:?}"); } // Otherwise, we expect failure. None => { @@ -347,7 +370,7 @@ fn test_freq( #[test] fn test_parse_freq_average_factor() { let freq_resolution = Some(40000.0); // Hz - let default = 1; + let default = NonZeroUsize::new(1).unwrap(); let user_input_freq_factor = Some("2"); let expected = Some(2); @@ -378,6 +401,328 @@ fn test_parse_freq_average_factor() { test_freq(freq_resolution, user_input_freq_factor, default, expected); let user_input_freq_factor = None; - let expected = Some(default); + let expected = Some(default.get()); test_freq(freq_resolution, user_input_freq_factor, default, expected); } + +#[test] +fn test_vis_average_1d_time() { + // 1 timestep, 4 channels, 1 baseline. 
+ let jones_from_tfb = array![[ + [Jones::identity()], + [Jones::identity() * 2.0], + [Jones::identity() * 3.0], + [Jones::identity() * 4.0] + ]]; + let weight_from_tfb = Array3::ones(jones_from_tfb.dim()); + let mut jones_to_fb = Array2::default(( + jones_from_tfb.len_of(Axis(1)), + jones_from_tfb.len_of(Axis(2)), + )); + let mut weight_to_fb = Array2::default(jones_to_fb.dim()); + + vis_average( + jones_from_tfb.view(), + jones_to_fb.view_mut(), + weight_from_tfb.view(), + weight_to_fb.view_mut(), + &HashSet::new(), + ); + + for (jones_from, jones_to) in jones_from_tfb.iter().zip_eq(jones_to_fb.iter()) { + assert_abs_diff_eq!(jones_from, jones_to); + } + for (weight_from, weight_to) in weight_from_tfb.iter().zip_eq(weight_to_fb.iter()) { + assert_abs_diff_eq!(weight_from, weight_to); + } +} + +#[test] +fn test_vis_average() { + // 2 timesteps, 4 channels, 1 baseline. + let jones_from_tfb = array![ + [ + [Jones::identity()], + [Jones::identity() * 2.0], + [Jones::identity() * 3.0], + [Jones::identity() * 4.0] + ], + [ + [Jones::identity() * 5.0], + [Jones::identity() * 6.0], + [Jones::identity() * 7.0], + [Jones::identity() * 8.0] + ] + ]; + let mut weight_from_tfb = Array3::ones(jones_from_tfb.dim()); + let mut jones_to_fb = Array2::default(( + jones_from_tfb.len_of(Axis(1)), + jones_from_tfb.len_of(Axis(2)), + )); + let mut weight_to_fb = Array2::default(jones_to_fb.dim()); + let num_chans = jones_from_tfb.len_of(Axis(1)); + + vis_average( + jones_from_tfb.view(), + jones_to_fb.view_mut(), + weight_from_tfb.view(), + weight_to_fb.view_mut(), + &HashSet::new(), + ); + for (i, (jones, weight)) in jones_to_fb + .iter() + .copied() + .zip(weight_to_fb.iter().copied()) + .enumerate() + { + // i -> ((i + 1) + (num_chans + i + 1)) / 2 -> (2*i + num_chans + 2) / 2 + let ii = (2 * i + num_chans + 2) as f32 / 2.0; + assert_abs_diff_eq!(jones, Jones::identity() * ii); + assert_abs_diff_eq!(weight, 2.0); + } + + // Make the first channel's weights negative. + weight_from_tfb.slice_mut(s![.., 0, ..]).fill(-1.0); + + vis_average( + jones_from_tfb.view(), + jones_to_fb.view_mut(), + weight_from_tfb.view(), + weight_to_fb.view_mut(), + &HashSet::new(), + ); + + for (i, (jones, weight)) in jones_to_fb + .iter() + .copied() + .zip(weight_to_fb.iter().copied()) + .enumerate() + { + let ii = (2 * i + num_chans + 2) as f32 / 2.0; + assert_abs_diff_eq!(jones, Jones::identity() * ii); + // The first channel's weight accumulates only negatives. + if i == 0 { + assert_abs_diff_eq!(weight, -2.0); + } else { + assert_abs_diff_eq!(weight, 2.0); + } + } + + // Make all weights positive, except for the very first one. + weight_from_tfb.fill(1.0); + weight_from_tfb[(0, 0, 0)] = -1.0; + + vis_average( + jones_from_tfb.view(), + jones_to_fb.view_mut(), + weight_from_tfb.view(), + weight_to_fb.view_mut(), + &HashSet::new(), + ); + + for (i, (jones, weight)) in jones_to_fb + .iter() + .copied() + .zip(weight_to_fb.iter().copied()) + .enumerate() + { + let ii = (2 * i + num_chans + 2) as f32 / 2.0; + if i == 0 { + // The first channel uses only data corresponding to the positive + // weight. + assert_abs_diff_eq!(jones, Jones::identity() * 5.0); + assert_abs_diff_eq!(weight, 1.0); + } else { + assert_abs_diff_eq!(jones, Jones::identity() * ii); + assert_abs_diff_eq!(weight, 2.0); + } + } + + // Now let's average in time and frequency. 
+ weight_from_tfb.fill(1.0); + let mut jones_to_fb = Array2::default((num_chans / 2, jones_from_tfb.len_of(Axis(2)))); + let mut weight_to_fb = Array2::default(jones_to_fb.dim()); + + vis_average( + jones_from_tfb.view(), + jones_to_fb.view_mut(), + weight_from_tfb.view(), + weight_to_fb.view_mut(), + &HashSet::new(), + ); + + assert_abs_diff_eq!(jones_to_fb[(0, 0)], Jones::identity() * 14.0 / 4.0); + assert_abs_diff_eq!(weight_to_fb[(0, 0)], 4.0); + + assert_abs_diff_eq!(jones_to_fb[(1, 0)], Jones::identity() * 22.0 / 4.0); + assert_abs_diff_eq!(weight_to_fb[(1, 0)], 4.0); +} + +#[test] +fn test_vis_average_non_uniform_weights() { + // 2 timesteps, 4 channels, 1 baseline. + let jones_from_tfb = array![ + [ + [Jones::identity()], + [Jones::identity() * 2.0], + [Jones::identity() * 3.0], + [Jones::identity() * 4.0] + ], + [ + [Jones::identity() * 5.0], + [Jones::identity() * 6.0], + [Jones::identity() * 7.0], + [Jones::identity() * 8.0] + ] + ]; + let mut weight_from_tfb = array![ + [[2.0], [3.0], [5.0], [7.0]], + [[11.0], [13.0], [17.0], [19.0]] + ]; + let mut jones_to_fb = Array2::default(( + jones_from_tfb.len_of(Axis(1)), + jones_from_tfb.len_of(Axis(2)), + )); + let mut weight_to_fb = Array2::default(jones_to_fb.dim()); + let num_chans = jones_from_tfb.len_of(Axis(1)); + + vis_average( + jones_from_tfb.view(), + jones_to_fb.view_mut(), + weight_from_tfb.view(), + weight_to_fb.view_mut(), + &HashSet::new(), + ); + + assert_abs_diff_eq!(jones_to_fb[(0, 0)], Jones::identity() * 57.0 / 13.0); + assert_abs_diff_eq!(weight_to_fb[(0, 0)], 13.0); + + assert_abs_diff_eq!(jones_to_fb[(1, 0)], Jones::identity() * 84.0 / 16.0); + assert_abs_diff_eq!(weight_to_fb[(1, 0)], 16.0); + + assert_abs_diff_eq!(jones_to_fb[(2, 0)], Jones::identity() * 134.0 / 22.0); + assert_abs_diff_eq!(weight_to_fb[(2, 0)], 22.0); + + assert_abs_diff_eq!(jones_to_fb[(3, 0)], Jones::identity() * 180.0 / 26.0); + assert_abs_diff_eq!(weight_to_fb[(3, 0)], 26.0); + + // Make the first channel's weights negative. + weight_from_tfb[(0, 0, 0)] *= -1.0; + weight_from_tfb[(1, 0, 0)] *= -1.0; + + vis_average( + jones_from_tfb.view(), + jones_to_fb.view_mut(), + weight_from_tfb.view(), + weight_to_fb.view_mut(), + &HashSet::new(), + ); + + // The first channel's weight accumulates only negatives. + assert_abs_diff_eq!(jones_to_fb[(0, 0)], Jones::identity() * 57.0 / 13.0); + assert_abs_diff_eq!(weight_to_fb[(0, 0)], -13.0); + + assert_abs_diff_eq!(jones_to_fb[(1, 0)], Jones::identity() * 84.0 / 16.0); + assert_abs_diff_eq!(weight_to_fb[(1, 0)], 16.0); + + assert_abs_diff_eq!(jones_to_fb[(2, 0)], Jones::identity() * 134.0 / 22.0); + assert_abs_diff_eq!(weight_to_fb[(2, 0)], 22.0); + + assert_abs_diff_eq!(jones_to_fb[(3, 0)], Jones::identity() * 180.0 / 26.0); + assert_abs_diff_eq!(weight_to_fb[(3, 0)], 26.0); + + // Make all weights positive, except for the very first one. + weight_from_tfb[(1, 0, 0)] *= -1.0; + + vis_average( + jones_from_tfb.view(), + jones_to_fb.view_mut(), + weight_from_tfb.view(), + weight_to_fb.view_mut(), + &HashSet::new(), + ); + + // The first channel uses only data corresponding to the positive weight. 
+ assert_abs_diff_eq!(jones_to_fb[(0, 0)], Jones::identity() * 5.0); + assert_abs_diff_eq!(weight_to_fb[(0, 0)], 11.0); + + assert_abs_diff_eq!(jones_to_fb[(1, 0)], Jones::identity() * 84.0 / 16.0); + assert_abs_diff_eq!(weight_to_fb[(1, 0)], 16.0); + + assert_abs_diff_eq!(jones_to_fb[(2, 0)], Jones::identity() * 134.0 / 22.0); + assert_abs_diff_eq!(weight_to_fb[(2, 0)], 22.0); + + assert_abs_diff_eq!(jones_to_fb[(3, 0)], Jones::identity() * 180.0 / 26.0); + assert_abs_diff_eq!(weight_to_fb[(3, 0)], 26.0); + + // Now let's average in time and frequency. + weight_from_tfb[(0, 0, 0)] *= -1.0; + let mut jones_to_fb = Array2::default((num_chans / 2, jones_from_tfb.len_of(Axis(2)))); + let mut weight_to_fb = Array2::default(jones_to_fb.dim()); + + vis_average( + jones_from_tfb.view(), + jones_to_fb.view_mut(), + weight_from_tfb.view(), + weight_to_fb.view_mut(), + &HashSet::new(), + ); + + assert_abs_diff_eq!(jones_to_fb[(0, 0)], Jones::identity() * 141.0 / 29.0); + assert_abs_diff_eq!(weight_to_fb[(0, 0)], 29.0); + + assert_abs_diff_eq!(jones_to_fb[(1, 0)], Jones::identity() * 314.0 / 48.0); + assert_abs_diff_eq!(weight_to_fb[(1, 0)], 48.0); +} + +#[test] +fn test_vis_average_non_uniform_weights_non_integral_array_shapes() { + // 2 timesteps, 3 channels, 1 baseline. + let jones_from_tfb = array![ + [ + [Jones::identity()], + [Jones::identity() * 2.0], + [Jones::identity() * 3.0] + ], + [ + [Jones::identity() * 4.0], + [Jones::identity() * 5.0], + [Jones::identity() * 6.0] + ] + ]; + let mut weight_from_tfb = array![[[2.0], [3.0], [5.0]], [[7.0], [11.0], [13.0]]]; + let mut jones_to_fb = Array2::default((2, jones_from_tfb.len_of(Axis(2)))); + let mut weight_to_fb = Array2::default(jones_to_fb.dim()); + + vis_average( + jones_from_tfb.view(), + jones_to_fb.view_mut(), + weight_from_tfb.view(), + weight_to_fb.view_mut(), + &HashSet::new(), + ); + + assert_abs_diff_eq!(jones_to_fb[(0, 0)], Jones::identity() * 91.0 / 23.0); + assert_abs_diff_eq!(weight_to_fb[(0, 0)], 23.0); + + assert_abs_diff_eq!(jones_to_fb[(1, 0)], Jones::identity() * 93.0 / 18.0); + assert_abs_diff_eq!(weight_to_fb[(1, 0)], 18.0); + + // Make the first channel's weights negative. + weight_from_tfb[(0, 0, 0)] *= -1.0; + weight_from_tfb[(1, 0, 0)] *= -1.0; + + vis_average( + jones_from_tfb.view(), + jones_to_fb.view_mut(), + weight_from_tfb.view(), + weight_to_fb.view_mut(), + &HashSet::new(), + ); + + assert_abs_diff_eq!(jones_to_fb[(0, 0)], Jones::identity() * 61.0 / 14.0); + assert_abs_diff_eq!(weight_to_fb[(0, 0)], 14.0); + + assert_abs_diff_eq!(jones_to_fb[(1, 0)], Jones::identity() * 93.0 / 18.0); + assert_abs_diff_eq!(weight_to_fb[(1, 0)], 18.0); +} diff --git a/src/beam/error.rs b/src/beam/error.rs index 833f1c7f..1a618af7 100644 --- a/src/beam/error.rs +++ b/src/beam/error.rs @@ -8,6 +8,14 @@ use thiserror::Error; #[derive(Error, Debug)] pub enum BeamError { + #[error("Tried to create a beam object, but MWA dipole delay information isn't available!")] + NoDelays, + + #[error( + "The specified MWA dipole delays aren't valid; there should be 16 values between 0 and 32" + )] + BadDelays, + #[error("The number of delays per tile ({delays}) didn't match the number of gains per tile ({gains})")] DelayGainsDimensionMismatch { delays: usize, gains: usize }, diff --git a/src/beam/mod.rs b/src/beam/mod.rs index 103ae21c..1f61ee91 100644 --- a/src/beam/mod.rs +++ b/src/beam/mod.rs @@ -215,6 +215,14 @@ impl Delays { Delays::Partial { .. } => (), } } + + // Parse user-provided dipole delays. 
+ pub(crate) fn parse(delays: Vec) -> Result { + if delays.len() != 16 || delays.iter().any(|&v| v > 32) { + return Err(BeamError::BadDelays); + } + Ok(Delays::Partial(delays)) + } } /// A beam implementation that returns only identity Jones matrices for all beam diff --git a/src/bin/hyperdrive.rs b/src/bin/hyperdrive.rs index 9da09336..496360e2 100644 --- a/src/bin/hyperdrive.rs +++ b/src/bin/hyperdrive.rs @@ -4,340 +4,16 @@ //! The main hyperdrive binary. -use clap::{AppSettings, Parser}; -use log::info; - -use mwa_hyperdrive::HyperdriveError; - -// Add build-time information from the "built" crate. -include!(concat!(env!("OUT_DIR"), "/built.rs")); +use clap::Parser; fn main() { + // Run hyperdrive, only performing extra steps if it returns an error. + // // Stolen from BurntSushi. We don't return Result from main because it // prints the debug representation of the error. The code below prints the // "display" or human readable representation of the error. - if let Err(e) = cli() { + if let Err(e) = mwa_hyperdrive::Hyperdrive::parse().run() { eprintln!("Error: {e}"); std::process::exit(1); } } - -#[derive(Parser)] -#[clap( - version, - author, - about = r#"Calibration software for the Murchison Widefield Array (MWA) radio telescope -Documentation: https://mwatelescope.github.io/mwa_hyperdrive -Source: https://github.com/MWATelescope/mwa_hyperdrive"# -)] -#[clap(global_setting(AppSettings::DeriveDisplayOrder))] -#[clap(disable_help_subcommand = true)] -#[clap(infer_subcommands = true)] -#[clap(propagate_version = true)] -enum Args { - #[clap(alias = "calibrate")] - #[clap( - about = r#"Perform direction-independent calibration on the input MWA data. -https://mwatelescope.github.io/mwa_hyperdrive/user/di_cal/intro.html"# - )] - #[clap(arg_required_else_help = true)] - #[clap(infer_long_args = true)] - DiCalibrate { - // Share the arguments that could be passed in via a parameter file. - #[clap(flatten)] - cli_args: Box, - - /// The verbosity of the program. Increase by specifying multiple times - /// (e.g. -vv). The default is to print only high-level information. - #[clap(short, long, parse(from_occurrences))] - verbosity: u8, - - /// Don't actually do calibration; just verify that arguments were - /// correctly ingested and print out high-level information. - #[clap(long)] - dry_run: bool, - }, - - #[clap(alias = "simulate-vis")] - #[clap(about = r#"Simulate visibilities of a sky-model source list. -https://mwatelescope.github.io/mwa_hyperdrive/user/vis_simulate/intro.html"#)] - #[clap(arg_required_else_help = true)] - #[clap(infer_long_args = true)] - VisSimulate { - #[clap(flatten)] - args: mwa_hyperdrive::VisSimulateArgs, - - /// The verbosity of the program. The default is to print high-level - /// information. - #[clap(short, long, parse(from_occurrences))] - verbosity: u8, - - /// Don't actually do any work; just verify that the input arguments - /// were correctly ingested and print out high-level information. - #[clap(long)] - dry_run: bool, - }, - - #[clap(alias = "subtract-vis")] - #[clap(about = "Subtract sky-model sources from supplied visibilities. -https://mwatelescope.github.io/mwa_hyperdrive/user/vis_subtract/intro.html")] - #[clap(arg_required_else_help = true)] - #[clap(infer_long_args = true)] - VisSubtract { - #[clap(flatten)] - args: mwa_hyperdrive::VisSubtractArgs, - - /// The verbosity of the program. The default is to print high-level - /// information. 
- #[clap(short, long, parse(from_occurrences))] - verbosity: u8, - - /// Don't actually do any work; just verify that the input arguments - /// were correctly ingested and print out high-level information. - #[clap(long)] - dry_run: bool, - }, - - #[clap(alias = "apply-solutions")] - #[clap(about = r#"Apply calibration solutions to input data. -https://mwatelescope.github.io/mwa_hyperdrive/user/solutions_apply/intro.html"#)] - #[clap(arg_required_else_help = true)] - #[clap(infer_long_args = true)] - SolutionsApply { - #[clap(flatten)] - args: mwa_hyperdrive::SolutionsApplyArgs, - - /// The verbosity of the program. The default is to print high-level - /// information. - #[clap(short, long, parse(from_occurrences))] - verbosity: u8, - - /// Don't actually do any work; just verify that the input arguments - /// were correctly ingested and print out high-level information. - #[clap(long)] - dry_run: bool, - }, - - #[clap(alias = "convert-solutions")] - #[clap(about = "Convert between calibration solution file formats.")] - #[clap(arg_required_else_help = true)] - #[clap(infer_long_args = true)] - SolutionsConvert { - #[clap(flatten)] - args: mwa_hyperdrive::SolutionsConvertArgs, - - /// The verbosity of the program. Increase by specifying multiple times - /// (e.g. -vv). The default is to print only high-level information. - #[clap(short, long, parse(from_occurrences))] - verbosity: u8, - }, - - #[clap(alias = "plot-solutions")] - #[clap( - about = "Plot calibration solutions. Only available if compiled with the \"plotting\" feature." - )] - #[clap(arg_required_else_help = true)] - #[clap(infer_long_args = true)] - SolutionsPlot { - #[clap(flatten)] - args: mwa_hyperdrive::SolutionsPlotArgs, - - /// The verbosity of the program. Increase by specifying multiple times - /// (e.g. -vv). The default is to print only high-level information. - #[clap(short, long, parse(from_occurrences))] - verbosity: u8, - }, - - #[clap(arg_required_else_help = true)] - #[clap(infer_long_args = true)] - SrclistByBeam { - #[clap(flatten)] - args: mwa_hyperdrive::SrclistByBeamArgs, - - /// The verbosity of the program. The default is to print high-level - /// information. - #[clap(short, long, parse(from_occurrences))] - verbosity: u8, - }, - - #[clap(arg_required_else_help = true)] - #[clap(infer_long_args = true)] - SrclistConvert { - #[clap(flatten)] - args: mwa_hyperdrive::SrclistConvertArgs, - - /// The verbosity of the program. The default is to print high-level - /// information. - #[clap(short, long, parse(from_occurrences))] - verbosity: u8, - }, - - #[clap(arg_required_else_help = true)] - #[clap(infer_long_args = true)] - SrclistShift { - #[clap(flatten)] - args: mwa_hyperdrive::SrclistShiftArgs, - - /// The verbosity of the program. The default is to print high-level - /// information. - #[clap(short, long, parse(from_occurrences))] - verbosity: u8, - }, - - #[clap(arg_required_else_help = true)] - #[clap(infer_long_args = true)] - SrclistVerify { - #[clap(flatten)] - args: mwa_hyperdrive::SrclistVerifyArgs, - - /// The verbosity of the program. The default is to print high-level - /// information. - #[clap(short, long, parse(from_occurrences))] - verbosity: u8, - }, - - #[clap(arg_required_else_help = true)] - #[clap(infer_long_args = true)] - DipoleGains { - #[clap(flatten)] - args: mwa_hyperdrive::DipoleGainsArgs, - - /// The verbosity of the program. The default is to print high-level - /// information. - #[clap(short, long, parse(from_occurrences))] - verbosity: u8, - }, -} - -/// Run `hyperdrive`. 
-fn cli() -> Result<(), HyperdriveError> { - // Get the command-line arguments. - let args = Args::parse(); - - // Set up logging. - let (verbosity, sub_command) = match &args { - Args::DiCalibrate { verbosity, .. } => (verbosity, "di-calibrate"), - Args::VisSimulate { verbosity, .. } => (verbosity, "vis-simulate"), - Args::VisSubtract { verbosity, .. } => (verbosity, "vis-subtract"), - Args::SolutionsApply { verbosity, .. } => (verbosity, "solutions-apply"), - Args::SolutionsConvert { verbosity, .. } => (verbosity, "solutions-convert"), - Args::SolutionsPlot { verbosity, .. } => (verbosity, "solutions-plot"), - Args::SrclistByBeam { verbosity, .. } => (verbosity, "srclist-by-beam"), - Args::SrclistConvert { verbosity, .. } => (verbosity, "srclist-convert"), - Args::SrclistShift { verbosity, .. } => (verbosity, "srclist-shift"), - Args::SrclistVerify { verbosity, .. } => (verbosity, "srclist-verify"), - Args::DipoleGains { verbosity, .. } => (verbosity, "dipole-gains"), - }; - setup_logging(*verbosity).expect("Failed to initialise logging."); - - // Print the version of hyperdrive and its build-time information. - info!("hyperdrive {} {}", sub_command, env!("CARGO_PKG_VERSION")); - display_build_info(); - - match Args::parse() { - Args::DiCalibrate { - cli_args, - verbosity: _, - dry_run, - } => { - cli_args.run(dry_run)?; - } - - Args::VisSimulate { - args, - verbosity: _, - dry_run, - } => { - args.run(dry_run)?; - } - - Args::VisSubtract { - args, - verbosity: _, - dry_run, - } => { - args.run(dry_run)?; - } - - Args::SolutionsApply { - args, - verbosity: _, - dry_run, - } => { - args.run(dry_run)?; - } - - Args::SolutionsConvert { args, verbosity: _ } => { - args.run()?; - } - - Args::SolutionsPlot { args, verbosity: _ } => { - args.run()?; - } - - // Source list utilities. - Args::SrclistByBeam { args, .. } => args.run()?, - Args::SrclistConvert { args, .. } => args.run()?, - Args::SrclistShift { args, .. } => args.run()?, - Args::SrclistVerify { args, .. } => args.run()?, - - // Misc. utilities. - Args::DipoleGains { args, .. } => args.run()?, - } - - info!("hyperdrive {} complete.", sub_command); - Ok(()) -} - -/// Activate a logger. All log messages are put onto `stdout`. `env_logger` -/// automatically only uses colours and fancy symbols if we're on a tty (e.g. a -/// terminal); piped output will be formatted sensibly. Source code lines are -/// displayed in log messages when verbosity >= 3. -fn setup_logging(verbosity: u8) -> Result<(), log::SetLoggerError> { - let mut builder = env_logger::Builder::from_default_env(); - builder.target(env_logger::Target::Stdout); - builder.format_target(false); - match verbosity { - 0 => builder.filter_level(log::LevelFilter::Info), - 1 => builder.filter_level(log::LevelFilter::Debug), - 2 => builder.filter_level(log::LevelFilter::Trace), - _ => { - builder.filter_level(log::LevelFilter::Trace); - builder.format(|buf, record| { - use std::io::Write; - - // TODO: Add colours. - let timestamp = buf.timestamp(); - let level = record.level(); - let target = record.target(); - let line = record.line().unwrap_or(0); - let message = record.args(); - - writeln!(buf, "[{timestamp} {level} {target}:{line}] {message}") - }) - } - }; - builder.init(); - - Ok(()) -} - -/// Write many info-level log lines of how this executable was compiled. 
-fn display_build_info() { - let dirty = match GIT_DIRTY { - Some(true) => " (dirty)", - _ => "", - }; - match GIT_COMMIT_HASH_SHORT { - Some(hash) => { - info!("Compiled on git commit hash: {hash}{dirty}"); - } - None => info!("Compiled on git commit hash: "), - } - if let Some(hr) = GIT_HEAD_REF { - info!(" git head ref: {}", hr); - } - info!(" {}", BUILT_TIME_UTC); - info!(" with compiler {}", RUSTC_VERSION); - info!(""); -} diff --git a/src/cli/common/beam/mod.rs b/src/cli/common/beam/mod.rs new file mode 100644 index 00000000..eabcad51 --- /dev/null +++ b/src/cli/common/beam/mod.rs @@ -0,0 +1,190 @@ +// This Source Code Form is subject to the terms of the Mozilla Public +// License, v. 2.0. If a copy of the MPL was not distributed with this +// file, You can obtain one at http://mozilla.org/MPL/2.0/. + +#[cfg(test)] +mod tests; + +use std::path::PathBuf; + +use clap::Parser; +use log::debug; +use ndarray::Array2; +use serde::{Deserialize, Serialize}; + +use super::{InfoPrinter, Warn}; +use crate::{ + beam::{create_fee_beam_object, create_no_beam_object, Beam, BeamError, Delays}, + io::read::VisInputType, +}; + +#[derive(Parser, Debug, Clone, Default, Serialize, Deserialize)] +pub(crate) struct BeamArgs { + /// The path to the HDF5 MWA FEE beam file. If not specified, this must be + /// provided by the MWA_BEAM_FILE environment variable. + #[clap(long, help_heading = "BEAM")] + pub(crate) beam_file: Option, + + /// Pretend that all MWA dipoles are alive and well, ignoring whatever is in + /// the metafits file. + #[clap(long, help_heading = "BEAM")] + #[serde(default)] + pub(crate) unity_dipole_gains: bool, + + /// If specified, use these dipole delays for the MWA pointing. e.g. 0 1 2 3 + /// 0 1 2 3 0 1 2 3 0 1 2 3 + #[clap(long, multiple_values(true), help_heading = "BEAM")] + pub(crate) delays: Option>, + + /// Don't apply a beam response when generating a sky model. The default is + /// to use the FEE beam. + #[clap(long, help_heading = "BEAM")] + #[serde(default)] + pub(crate) no_beam: bool, +} + +impl BeamArgs { + pub(crate) fn merge(self, other: Self) -> Self { + Self { + beam_file: self.beam_file.or(other.beam_file), + unity_dipole_gains: self.unity_dipole_gains || other.unity_dipole_gains, + delays: self.delays.or(other.delays), + no_beam: self.no_beam || other.no_beam, + } + } + + pub(crate) fn parse( + self, + total_num_tiles: usize, + data_dipole_delays: Option, + dipole_gains: Option>, + input_data_type: Option, + ) -> Result, BeamError> { + let Self { + beam_file, + unity_dipole_gains, + delays: user_dipole_delays, + no_beam, + } = self; + + let mut printer = InfoPrinter::new("Beam info".into()); + debug!("Beam file: {beam_file:?}"); + + let user_dipole_delays = match user_dipole_delays { + // We have user-provided delays; check that they're are sensible, + // regardless of whether we actually need them. + Some(d) => Some(Delays::parse(d)?), + None => None, + }; + + let mut num_tiles_with_dead_dipoles = None; + let beam: Box = if no_beam { + printer.push_line("Not using any beam responses".into()); + create_no_beam_object(total_num_tiles) + } else { + printer.push_line("Type: FEE".into()); + let mut dipole_delays = user_dipole_delays + .or(data_dipole_delays) + .ok_or(BeamError::NoDelays)?; + let dipole_gains = if unity_dipole_gains { + None + } else { + // If we don't have dipole gains from the input data, then + // we issue a warning that we must assume no dead dipoles. 
+ if dipole_gains.is_none() { + match input_data_type { + Some(VisInputType::MeasurementSet) => { + [ + "Measurement sets cannot supply dead dipole information.".into(), + "Without a metafits file, we must assume all dipoles are alive.".into(), + "This will make beam Jones matrices inaccurate in sky-model generation.".into() + ].warn() + } + Some(VisInputType::Uvfits) => { + [ + "uvfits files cannot supply dead dipole information.".into(), + "Without a metafits file, we must assume all dipoles are alive.".into(), + "This will make beam Jones matrices inaccurate in sky-model generation.".into() + ].warn() + } + Some(VisInputType::Raw) => { + unreachable!("Raw data inputs always specify dipole gains") + } + None => (), + } + } + dipole_gains + }; + let ideal_delays = dipole_delays.get_ideal_delays(); + if dipole_gains.is_none() { + // If we don't have dipole gains, we must assume all dipoles are + // "alive". But, if any dipole delays are 32, then the beam code + // will still ignore those dipoles. So use ideal dipole delays + // for all tiles. + + // Warn the user if they wanted unity dipole gains but the ideal + // dipole delays contain 32. + if unity_dipole_gains && ideal_delays.iter().any(|&v| v == 32) { + "Some ideal dipole delays are 32; these dipoles will not have unity gains" + .warn() + } + dipole_delays.set_to_ideal_delays(); + } + + { + let d = ideal_delays; + printer.push_block(vec![ + format!( + "Ideal dipole delays: [{:>2} {:>2} {:>2} {:>2}", + d[0], d[1], d[2], d[3] + ) + .into(), + format!( + " {:>2} {:>2} {:>2} {:>2}", + d[4], d[5], d[6], d[7] + ) + .into(), + format!( + " {:>2} {:>2} {:>2} {:>2}", + d[8], d[9], d[10], d[11] + ) + .into(), + format!( + " {:>2} {:>2} {:>2} {:>2}]", + d[12], d[13], d[14], d[15] + ) + .into(), + ]); + } + + if let Some(dipole_gains) = dipole_gains.as_ref() { + num_tiles_with_dead_dipoles = Some( + dipole_gains + .outer_iter() + .filter(|tile_dipole_gains| { + tile_dipole_gains.iter().any(|g| g.abs() < f64::EPSILON) + }) + .count(), + ); + } + + create_fee_beam_object(beam_file, total_num_tiles, dipole_delays, dipole_gains)? + }; + if let Some(f) = beam.get_beam_file() { + printer.push_line(format!("File: {}", f.display()).into()); + } + if let Some(num_tiles_with_dead_dipoles) = num_tiles_with_dead_dipoles { + printer.push_line( + format!( + "Using dead dipole information ({num_tiles_with_dead_dipoles} tiles affected)" + ) + .into(), + ); + } else { + printer.push_line("Assuming all dipoles are \"alive\"".into()); + } + + printer.display(); + Ok(beam) + } +} diff --git a/src/cli/common/beam/tests.rs b/src/cli/common/beam/tests.rs new file mode 100644 index 00000000..b0cb9827 --- /dev/null +++ b/src/cli/common/beam/tests.rs @@ -0,0 +1,93 @@ +// This Source Code Form is subject to the terms of the Mozilla Public +// License, v. 2.0. If a copy of the MPL was not distributed with this +// file, You can obtain one at http://mozilla.org/MPL/2.0/. 
+ +use ndarray::array; + +use super::BeamArgs; +use crate::beam::{BeamError::BadDelays, BeamType}; + +#[test] +fn test_handle_delays() { + let args = BeamArgs { + // only 3 delays instead of 16 expected + delays: Some((0..3).collect::>()), + no_beam: false, + ..Default::default() + }; + + let result = args.parse(1, None, None, None); + assert!(result.is_err()); + assert!(matches!(result, Err(BadDelays))); + + let args = BeamArgs { + // delays > 32 + delays: Some((20..36).collect::>()), + no_beam: false, + ..Default::default() + }; + let result = args.parse(1, None, None, None); + + assert!(result.is_err()); + assert!(matches!(result, Err(BadDelays))); + + let delays = (0..16).collect::>(); + let args = BeamArgs { + // delays > 32 + delays: Some(delays.clone()), + no_beam: false, + ..Default::default() + }; + let result = args.parse(1, None, None, None); + + assert!(result.is_ok(), "result={:?} not Ok", result.err().unwrap()); + + let fee_beam = result.unwrap(); + assert_eq!(fee_beam.get_beam_type(), BeamType::FEE); + let beam_delays = fee_beam + .get_dipole_delays() + .expect("expected some delays to be provided from the FEE beam!"); + // Each row of the delays should be the same as the 16 input values. + for row in beam_delays.outer_iter() { + assert_eq!(row.as_slice().unwrap(), delays); + } +} + +#[test] +fn test_unity_dipole_gains() { + let args = BeamArgs { + delays: Some(vec![0; 16]), + no_beam: false, + ..Default::default() + }; + + // Let one of the dipoles be dead. + let dipole_gains = array![ + [1.0; 16], + [1.0, 0.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0] + ]; + let beam = args.parse(2, None, Some(dipole_gains), None).unwrap(); + assert_eq!(beam.get_beam_type(), BeamType::FEE); + let beam_gains = beam.get_dipole_gains(); + + // We should find that not all dipole gains are 1. + assert!(!beam_gains.iter().all(|g| (*g - 1.0).abs() < f64::EPSILON)); + + // Now ignore dead dipoles. + let args = BeamArgs { + delays: Some(vec![0; 16]), + no_beam: false, + unity_dipole_gains: true, + ..Default::default() + }; + + let dipole_gains = array![[1.0; 16], [1.0; 16]]; + let beam = args.parse(2, None, Some(dipole_gains), None).unwrap(); + assert_eq!(beam.get_beam_type(), BeamType::FEE); + let beam_gains = beam.get_dipole_gains(); + + // We expect all gains to be 1s, as we're ignoring dead dipoles. + assert!(beam_gains.iter().all(|g| (*g - 1.0).abs() < f64::EPSILON)); + // Verify that there are no dead dipoles in the delays. + assert!(beam.get_dipole_delays().unwrap().iter().all(|d| *d != 32)); +} diff --git a/src/filenames.rs b/src/cli/common/input_vis/filenames.rs similarity index 80% rename from src/filenames.rs rename to src/cli/common/input_vis/filenames.rs index 8304ff95..edb01749 100644 --- a/src/filenames.rs +++ b/src/cli/common/input_vis/filenames.rs @@ -4,46 +4,46 @@ //! Code to parse filenames. //! -//! [InputDataTypes] is the struct to be used here. It is constructed from a +//! [`InputDataTypes`] is the struct to be used here. It is constructed from a //! slice of string filenames, and it enforces things like allowing only one //! metafits file to be present. 
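As a rough picture of the classification that `InputDataTypes::parse` and `file_checker` perform below, the following sketch sorts input strings into file kinds purely by their names. It is a simplification under stated assumptions: the real code also expands globs, checks that files are readable, rejects "PPDs" metafits files, and matches gpubox/MWAX names with the regexes defined in this module. The enum, function, and example filenames here are invented for illustration.

```rust
// Illustrative only: a stripped-down version of the per-file classification
// done by file_checker.
#[derive(Debug, PartialEq)]
enum InputFileKind {
    Metafits,
    Gpubox,
    Mwaf,
    MeasurementSet,
    Uvfits,
    Solutions,
    Unrecognised,
}

fn classify(file: &str) -> InputFileKind {
    if file.ends_with(".metafits") || file.ends_with("_metafits.fits") {
        InputFileKind::Metafits
    } else if file.contains("gpubox") && file.ends_with(".fits") {
        InputFileKind::Gpubox
    } else if file.ends_with(".mwaf") {
        InputFileKind::Mwaf
    } else if file.ends_with(".ms") {
        InputFileKind::MeasurementSet
    } else if file.ends_with(".uvfits") {
        InputFileKind::Uvfits
    } else if file.ends_with(".fits") || file.ends_with(".bin") {
        // Anything else that looks like a FITS or binary file is tried as
        // calibration solutions (the new behaviour in this changeset).
        InputFileKind::Solutions
    } else {
        InputFileKind::Unrecognised
    }
}

fn main() {
    assert_eq!(classify("1090008640.metafits"), InputFileKind::Metafits);
    assert_eq!(
        classify("1090008640_20140721201027_gpubox01_00.fits"),
        InputFileKind::Gpubox
    );
    assert_eq!(classify("hyp_sols.fits"), InputFileKind::Solutions);
}
```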
-use std::fs::OpenOptions; -use std::path::{Path, PathBuf}; +use std::{ + fs::OpenOptions, + path::{Path, PathBuf}, +}; use regex::{Regex, RegexBuilder}; -use thiserror::Error; use vec1::Vec1; -use crate::io::{get_all_matches_from_glob, read::VisReadError, GlobError}; +use super::InputVisArgsError; +use crate::io::get_all_matches_from_glob; + +pub(super) const GPUBOX_REGEX: &str = r".*gpubox.*\.fits$"; +pub(super) const MWAX_REGEX: &str = r"\d{10}_\d{8}(.)?\d{6}_ch\d{3}_\d{3}\.fits$"; lazy_static::lazy_static! { // gpubox files should not be renamed in any way! This includes the case of // the letters in the filename. mwalib should complain if this is not the // case. static ref RE_GPUBOX: Regex = - RegexBuilder::new(r".*gpubox.*\.fits$") + RegexBuilder::new(GPUBOX_REGEX) .case_insensitive(false).build().unwrap(); static ref RE_MWAX: Regex = - RegexBuilder::new(r"\d{10}_\d{8}(.)?\d{6}_ch\d{3}_\d{3}\.fits$") + RegexBuilder::new(MWAX_REGEX) .case_insensitive(false).build().unwrap(); } -pub(super) const SUPPORTED_INPUT_FILE_COMBINATIONS: &str = - "gpubox + metafits (+ mwaf)\nms (+ metafits)\nuvfits (+ metafits)"; - -pub(super) const SUPPORTED_CALIBRATED_INPUT_FILE_COMBINATIONS: &str = - "ms (+ metafits)\nuvfits (+ metafits)"; - #[derive(Debug)] /// Supported input data types for calibration. -pub(crate) struct InputDataTypes { - pub(crate) metafits: Option>, - pub(crate) gpuboxes: Option>, - pub(crate) mwafs: Option>, - pub(crate) ms: Option>, - pub(crate) uvfits: Option>, +pub(super) struct InputDataTypes { + pub(super) metafits: Option>, + pub(super) gpuboxes: Option>, + pub(super) mwafs: Option>, + pub(super) ms: Option>, + pub(super) uvfits: Option>, + pub(super) solutions: Option>, } // The same as `InputDataTypes`, but all types are allowed to be multiples. This @@ -55,12 +55,13 @@ struct InputDataTypesTemp { mwafs: Vec, ms: Vec, uvfits: Vec, + solutions: Vec, } impl InputDataTypes { /// From an input collection of filename or glob strings, disentangle the - /// file types and populate [InputDataTypes]. - pub(super) fn new(files: &[String]) -> Result { + /// file types and populate [`InputDataTypes`]. + pub(super) fn parse(files: &[String]) -> Result { let mut temp = InputDataTypesTemp::default(); for file in files.iter().map(|f| f.as_str()) { @@ -93,13 +94,18 @@ impl InputDataTypes { } else { Some(Vec1::try_from_vec(temp.uvfits).unwrap()) }, + solutions: if temp.solutions.is_empty() { + None + } else { + Some(Vec1::try_from_vec(temp.solutions).unwrap()) + }, }) } } -fn exists_and_is_readable(file: &Path) -> Result<(), InputFileError> { +fn exists_and_is_readable(file: &Path) -> Result<(), InputVisArgsError> { if !file.exists() { - return Err(InputFileError::DoesNotExist(file.display().to_string())); + return Err(InputVisArgsError::DoesNotExist(file.display().to_string())); } match OpenOptions::new() .read(true) @@ -108,9 +114,9 @@ fn exists_and_is_readable(file: &Path) -> Result<(), InputFileError> { { Ok(_) => (), Err(std::io::ErrorKind::PermissionDenied) => { - return Err(InputFileError::CouldNotRead(file.display().to_string())) + return Err(InputVisArgsError::CouldNotRead(file.display().to_string())) } - Err(e) => return Err(InputFileError::IO(file.display().to_string(), e.into())), + Err(e) => return Err(InputVisArgsError::IO(file.display().to_string(), e.into())), } Ok(()) @@ -120,20 +126,20 @@ fn exists_and_is_readable(file: &Path) -> Result<(), InputFileError> { // what type it is, and add it to the provided file types struct. 
If the file // string doesn't exist, then check if it's a glob string, and act recursively // on the glob results. -fn file_checker(file_types: &mut InputDataTypesTemp, file: &str) -> Result<(), InputFileError> { +fn file_checker(file_types: &mut InputDataTypesTemp, file: &str) -> Result<(), InputVisArgsError> { let file_pb = PathBuf::from(file); // Is this a file, and is it readable? match exists_and_is_readable(&file_pb) { Ok(_) => (), // If this string isn't a file, maybe it's a glob. - Err(InputFileError::DoesNotExist(f)) => { + Err(InputVisArgsError::DoesNotExist(f)) => { match get_all_matches_from_glob(file) { Ok(glob_results) => { // If there were no glob matches, then just return the // original error (the file does not exist). if glob_results.is_empty() { - return Err(InputFileError::DoesNotExist(f)); + return Err(InputVisArgsError::DoesNotExist(f)); } // Iterate over all glob results, adding them to the file @@ -145,7 +151,7 @@ fn file_checker(file_types: &mut InputDataTypesTemp, file: &str) -> Result<(), I } // Propagate all other errors. - Err(e) => return Err(InputFileError::from(e)), + Err(e) => return Err(InputVisArgsError::from(e)), } } @@ -153,7 +159,7 @@ fn file_checker(file_types: &mut InputDataTypesTemp, file: &str) -> Result<(), I Err(e) => return Err(e), }; if file.contains("_metafits_ppds.fits") { - return Err(InputFileError::PpdMetafitsUnsupported(file.to_string())); + return Err(InputVisArgsError::PpdMetafitsUnsupported(file.to_string())); } match ( file.ends_with(".metafits") || file.ends_with("_metafits.fits"), @@ -169,33 +175,21 @@ fn file_checker(file_types: &mut InputDataTypesTemp, file: &str) -> Result<(), I (false, false, false, true, false, false) => file_types.mwafs.push(file_pb), (false, false, false, false, true, false) => file_types.ms.push(file_pb), (false, false, false, false, false, true) => file_types.uvfits.push(file_pb), - _ => return Err(InputFileError::NotRecognised(file.to_string())), + _ => { + // We don't recognise this file as a "vis input" type. Try to match + // a calibration solutions type. + if file.ends_with(".fits") || file.ends_with(".bin") { + file_types.solutions.push(file_pb); + } else { + // If that doesn't work, bail out. + return Err(InputVisArgsError::NotRecognised(file.to_string())); + } + } } Ok(()) } -#[derive(Debug, Error)] -pub enum InputFileError { - #[error("Specified file does not exist: {0}")] - DoesNotExist(String), - - #[error("Could not read specified file: {0}")] - CouldNotRead(String), - - #[error("The specified file '{0}' is a \"PPDs metafits\" and is not supported. Please use a newer metafits file.")] - PpdMetafitsUnsupported(String), - - #[error("The specified file '{0}' was not a recognised file type.")] - NotRecognised(String), - - #[error(transparent)] - Glob(#[from] GlobError), - - #[error("IO error when attempting to read file '{0}': {1}")] - IO(String, std::io::Error), -} - #[cfg(test)] mod tests { use super::*; @@ -255,7 +249,7 @@ mod tests { let result = exists_and_is_readable(&PathBuf::from("/does/not/exist.metafits")); assert!(result.is_err()); match result { - Err(InputFileError::DoesNotExist(_)) => (), + Err(InputVisArgsError::DoesNotExist(_)) => (), Err(e) => panic!("Unexpected error kind! {e:?}"), Ok(_) => unreachable!(), } @@ -284,7 +278,7 @@ mod tests { let result = exists_and_is_readable(tmp_file.path()); assert!(result.is_err()); match result { - Err(InputFileError::CouldNotRead(_)) => (), + Err(InputVisArgsError::CouldNotRead(_)) => (), Err(e) => panic!("Unexpected error kind! 
{e:?}"), Ok(_) => unreachable!(), } @@ -309,7 +303,7 @@ mod tests { let result = file_checker(&mut input, "/tmp"); assert!(result.is_err()); match result { - Err(InputFileError::NotRecognised(_)) => (), + Err(InputVisArgsError::NotRecognised(_)) => (), Err(e) => panic!("Unexpected error kind! {e:?}"), Ok(_) => unreachable!(), } @@ -321,7 +315,7 @@ mod tests { let result = file_checker(&mut input, "/does/not/exist.metafits"); assert!(result.is_err()); match result { - Err(InputFileError::DoesNotExist(_)) => (), + Err(InputVisArgsError::DoesNotExist(_)) => (), Err(e) => panic!("Unexpected error kind! {e:?}"), Ok(_) => unreachable!(), } @@ -370,7 +364,7 @@ mod tests { let dir = make_new_dir(); let gpubox = make_legacy_gpubox(dir.path()); let result = file_checker(&mut input, gpubox.to_str().unwrap()); - assert!(result.is_ok()); + result.unwrap(); assert_eq!(input.gpuboxes.len(), 1); } diff --git a/src/cli/common/input_vis/mod.rs b/src/cli/common/input_vis/mod.rs new file mode 100644 index 00000000..6a201239 --- /dev/null +++ b/src/cli/common/input_vis/mod.rs @@ -0,0 +1,1268 @@ +// This Source Code Form is subject to the terms of the Mozilla Public +// License, v. 2.0. If a copy of the MPL was not distributed with this +// file, You can obtain one at http://mozilla.org/MPL/2.0/. + +mod filenames; +#[cfg(test)] +mod tests; + +use filenames::{InputDataTypes, GPUBOX_REGEX, MWAX_REGEX}; + +use std::{collections::HashSet, num::NonZeroUsize, path::PathBuf}; + +use clap::Parser; +use console::style; +use hifitime::Duration; +use itertools::Itertools; +use log::{ + debug, info, log_enabled, + Level::{Debug, Info}, +}; +use marlu::{ + constants::{FREQ_WEIGHT_FACTOR, TIME_WEIGHT_FACTOR}, + precession::precess_time, + LatLngHeight, +}; +use ndarray::Axis; +use serde::{Deserialize, Serialize}; +use vec1::Vec1; + +use super::{InfoPrinter, ARRAY_POSITION_HELP}; +use crate::{ + averaging::{ + channels_to_chanblocks, parse_freq_average_factor, parse_time_average_factor, + timesteps_to_timeblocks, AverageFactorError, + }, + cli::Warn, + constants::DEFAULT_MS_DATA_COL_NAME, + io::read::{ + pfb_gains::{PfbFlavour, DEFAULT_PFB_FLAVOUR, PFB_FLAVOURS}, + MsReader, RawDataCorrections, RawDataReader, UvfitsReader, VisInputType, VisRead, + }, + math::TileBaselineFlags, + params::InputVisParams, + CalibrationSolutions, +}; + +lazy_static::lazy_static! { + pub(super) static ref PFB_FLAVOUR_HELP: String = + format!("The 'flavour' of poly-phase filter bank corrections applied to raw MWA data. The default is '{}'. Valid flavours are: {}", DEFAULT_PFB_FLAVOUR, *PFB_FLAVOURS); + + pub(super) static ref MS_DATA_COL_NAME_HELP: String = + format!("If reading from a measurement set, this specifies the column to use in the main table containing visibilities. Default: {DEFAULT_MS_DATA_COL_NAME}"); + + static ref SUPPORTED_INPUT_FILE_TYPES: String = format!(r#" + metafits: .metafits, _metafits.fits + measurement sets: .ms + uvfits files: .uvfits + gpubox files (regex): {GPUBOX_REGEX} + MWAX files (regex): {MWAX_REGEX}"#); +} + +#[derive(Parser, Debug, Clone, Default, Serialize, Deserialize)] +pub(crate) struct InputVisArgs { + /// Paths to input data files. These can include a metafits file, a + /// calibration solutions file, gpubox files, mwaf files, a measurement set, + /// and/or a uvfits file. + #[clap( + short = 'd', + long = "data", + multiple_values(true), + help_heading = "INPUT DATA" + )] + pub(crate) files: Option>, + + /// The timesteps to use from the input data. Any input will be ascendingly + /// sorted. 
No duplicates are allowed. The default is to use all unflagged
+ /// timesteps. e.g. The following skips the first two timesteps and uses the
+ /// following three: --timesteps 2 3 4, --timesteps {2..4} (bash shell
+ /// syntax)
+ #[clap(long, multiple_values(true), help_heading = "INPUT DATA")]
+ pub(crate) timesteps: Option<Vec<usize>>,
+
+ /// Use all timesteps in the data, including flagged ones. The default is to
+ /// use all unflagged timesteps.
+ #[clap(long, conflicts_with("timesteps"), help_heading = "INPUT DATA")]
+ #[serde(default)]
+ pub(crate) use_all_timesteps: bool,
+
+ #[clap(
+ long, help = ARRAY_POSITION_HELP.as_str(), help_heading = "INPUT DATA",
+ number_of_values = 3,
+ allow_hyphen_values = true,
+ value_names = &["LONG_DEG", "LAT_DEG", "HEIGHT_M"]
+ )]
+ pub(crate) array_position: Option<Vec<f64>>,
+
+ /// Don't read autocorrelations from the input data.
+ #[clap(long, help_heading = "INPUT DATA")]
+ #[serde(default)]
+ pub(crate) no_autos: bool,
+
+ /// Use this value as the DUT1 [seconds].
+ #[clap(long, help_heading = "INPUT DATA")]
+ #[serde(default)]
+ pub(crate) dut1: Option<f64>,
+
+ /// Ignore the weights accompanying the visibilities. Internally, this will
+ /// set all weights to 1, meaning all visibilities are equal, including
+ /// those that would be otherwise flagged.
+ #[clap(long, help_heading = "INPUT DATA")]
+ #[serde(default)]
+ pub(crate) ignore_weights: bool,
+
+ /// Use a DUT1 value of 0 seconds rather than what is in the input data.
+ #[clap(long, conflicts_with("dut1"), help_heading = "INPUT DATA")]
+ #[serde(default)]
+ pub(crate) ignore_dut1: bool,
+
+ #[clap(long, help = MS_DATA_COL_NAME_HELP.as_str(), help_heading = "INPUT DATA (MS)")]
+ pub(crate) ms_data_column_name: Option<String>,
+
+ #[clap(long, help = PFB_FLAVOUR_HELP.as_str(), help_heading = "INPUT DATA (RAW)")]
+ pub(crate) pfb_flavour: Option<String>,
+
+ /// When reading in raw MWA data, don't apply digital gains.
+ #[clap(long, help_heading = "INPUT DATA (RAW)")]
+ #[serde(default)]
+ pub(crate) no_digital_gains: bool,
+
+ /// When reading in raw MWA data, don't apply cable length corrections. Note
+ /// that some data may have already had the correction applied before it was
+ /// written.
+ #[clap(long, help_heading = "INPUT DATA (RAW)")]
+ #[serde(default)]
+ pub(crate) no_cable_length_correction: bool,
+
+ /// When reading in raw MWA data, don't apply geometric corrections. Note
+ /// that some data may have already had the correction applied before it was
+ /// written.
+ #[clap(long, help_heading = "INPUT DATA (RAW)")]
+ #[serde(default)]
+ pub(crate) no_geometric_correction: bool,
+
+ /// Additional tiles to be flagged. These values correspond to either the
+ /// values in the "Antenna" column of HDU 2 in the metafits file (e.g. 0 3
+ /// 127), or the "TileName" (e.g. Tile011).
+ #[clap(long, multiple_values(true), help_heading = "INPUT DATA (FLAGGING)")]
+ pub(crate) tile_flags: Option<Vec<String>>,
+
+ /// If specified, pretend that all tiles are unflagged in the input data.
+ #[clap(long, help_heading = "INPUT DATA (FLAGGING)")]
+ #[serde(default)]
+ pub(crate) ignore_input_data_tile_flags: bool,
+
+ /// If specified, pretend all fine channels in the input data are unflagged.
+ /// Note that this does not unset any negative weights; visibilities
+ /// associated with negative weights are still considered flagged even if
+ /// we're ignoring input data fine channel flags.
+ #[clap(long, help_heading = "INPUT DATA (FLAGGING)")]
+ #[serde(default)]
+ pub(crate) ignore_input_data_fine_channel_flags: bool,
+
+ /// The fine channels to be flagged in each coarse channel. e.g. 0 1 16 30
+ /// 31 are typical for 40 kHz data. If this is not specified, it defaults
+ /// to flagging 80 kHz for raw data (or as close to this as possible) at the
+ /// edges, as well as the centre channel for non-MWAX data. Other visibility
+ /// file formats do not use this by default.
+ #[clap(long, multiple_values(true), help_heading = "INPUT DATA (FLAGGING)")]
+ pub(crate) fine_chan_flags_per_coarse_chan: Option<Vec<u16>>,
+
+ /// The fine channels to be flagged across the whole observation band. e.g.
+ /// 0 767 are the first and last fine channels for 40 kHz data. These flags
+ /// are applied *before* any averaging is performed.
+ #[clap(long, multiple_values(true), help_heading = "INPUT DATA (FLAGGING)")]
+ pub(crate) fine_chan_flags: Option<Vec<u16>>,
+
+ /// The number of timesteps to average together while reading in data. The
+ /// value must be a multiple of the input data's time resolution, except if
+ /// this is 0, in which case all timesteps are averaged together. A target
+ /// resolution (e.g. 8s) may be used instead, in which case the specified
+ /// resolution must be a multiple of the input data's resolution. The
+ /// default is no averaging, i.e. a value of 1. Examples: If the input data
+ /// is in 0.5s resolution and this variable is 4, then we average 2s worth
+ /// of data together before performing work on it. If the variable is
+ /// instead 4s, then 8 timesteps are averaged together.
+ #[clap(long, help_heading = "INPUT DATA (AVERAGING)")]
+ pub(crate) time_average: Option<String>,
+
+ /// The number of fine-frequency channels to average together while reading
+ /// in data. The value must be a multiple of the input data's freq.
+ /// resolution, except if this is 0, in which case all channels are averaged
+ /// together. A target resolution (e.g. 80kHz) may be used instead, in which
+ /// case the specified resolution must be a multiple of the input data's
+ /// resolution. The default is no averaging, i.e. a value of 1. Examples: If
+ /// the input data is in 20kHz resolution and this variable is 2, then we
+ /// average 40kHz worth of data together before performing work with it. If
+ /// the variable is instead 80kHz, then 4 channels are averaged together.
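The averaging settings described above accept either a bare factor or a target resolution; the real parsing lives in `parse_time_average_factor` and `parse_freq_average_factor` in `crate::averaging`, which this changeset imports. The sketch below is a simplified, hypothetical version of the time-axis rule only (integer factors and second-suffixed resolutions), shown to make the 0.5s/4/4s examples in the help text concrete; the function name and error handling are not hyperdrive's.

```rust
// Illustrative only: turn a user string into a time-averaging factor, given
// the data's time resolution in seconds.
fn time_average_factor(data_res_s: f64, user: &str) -> Result<usize, String> {
    // A trailing "s" means the user gave a target resolution in seconds
    // rather than a bare factor.
    let (value, is_resolution) = match user.trim().strip_suffix('s') {
        Some(v) => (v.parse::<f64>().map_err(|e| e.to_string())?, true),
        None => (user.trim().parse::<f64>().map_err(|e| e.to_string())?, false),
    };
    if value == 0.0 {
        // 0 means "average everything together"; the real code caps this at
        // the number of available timesteps.
        return Ok(0);
    }
    let factor = if is_resolution { value / data_res_s } else { value };
    // The factor must be a whole number of input timesteps.
    if factor < 1.0 || factor.fract().abs() > 1e-9 {
        return Err(format!("{user} is not an integer multiple of {data_res_s} s"));
    }
    Ok(factor.round() as usize)
}

fn main() {
    // 2 s data, target resolution "8s" -> combine every 4 timesteps.
    assert_eq!(time_average_factor(2.0, "8s"), Ok(4));
    // 0.5 s data, plain factor "4" -> 2 s output resolution.
    assert_eq!(time_average_factor(0.5, "4"), Ok(4));
    // 0.5 s data, target resolution "4s" -> 8 timesteps per average.
    assert_eq!(time_average_factor(0.5, "4s"), Ok(8));
}
```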
+ #[clap(short, long, help_heading = "INPUT DATA (AVERAGING)")] + pub(crate) freq_average: Option, +} + +impl InputVisArgs { + pub(crate) fn merge(self, other: Self) -> Self { + InputVisArgs { + files: self.files.or(other.files), + timesteps: self.timesteps.or(other.timesteps), + use_all_timesteps: self.use_all_timesteps || other.use_all_timesteps, + array_position: self.array_position.or(other.array_position), + no_autos: self.no_autos || other.no_autos, + dut1: self.dut1.or(other.dut1), + ignore_weights: self.ignore_weights || other.ignore_weights, + ignore_dut1: self.ignore_dut1 || other.ignore_dut1, + ms_data_column_name: self.ms_data_column_name.or(other.ms_data_column_name), + pfb_flavour: self.pfb_flavour.or(other.pfb_flavour), + no_digital_gains: self.no_digital_gains || other.no_digital_gains, + no_cable_length_correction: self.no_cable_length_correction + || other.no_cable_length_correction, + no_geometric_correction: self.no_geometric_correction || other.no_geometric_correction, + tile_flags: self.tile_flags.or(other.tile_flags), + ignore_input_data_tile_flags: self.ignore_input_data_tile_flags + || other.ignore_input_data_tile_flags, + ignore_input_data_fine_channel_flags: self.ignore_input_data_fine_channel_flags + || other.ignore_input_data_fine_channel_flags, + fine_chan_flags_per_coarse_chan: self + .fine_chan_flags_per_coarse_chan + .or(other.fine_chan_flags_per_coarse_chan), + fine_chan_flags: self.fine_chan_flags.or(other.fine_chan_flags), + time_average: self.time_average.or(other.time_average), + freq_average: self.freq_average.or(other.freq_average), + } + } + + pub(crate) fn parse(self, operation_verb: &str) -> Result { + let InputVisArgs { + files, + timesteps, + use_all_timesteps, + array_position, + no_autos, + dut1, + ignore_weights, + ignore_dut1, + ms_data_column_name, + pfb_flavour, + no_digital_gains, + no_cable_length_correction, + no_geometric_correction, + tile_flags, + ignore_input_data_tile_flags, + ignore_input_data_fine_channel_flags, + fine_chan_flags_per_coarse_chan, + fine_chan_flags, + time_average, + freq_average, + } = self; + + // If the user supplied the array position, unpack it here. + let array_position = match array_position { + Some(v) => { + if v.len() != 3 { + return Err(InputVisArgsError::BadArrayPosition { pos: v }); + } + Some(LatLngHeight { + longitude_rad: v[0].to_radians(), + latitude_rad: v[1].to_radians(), + height_metres: v[2], + }) + } + None => None, + }; + + // Handle input data. We expect one of three possibilities: + // - gpubox files, a metafits file (and maybe mwaf files), + // - a measurement set (and maybe a metafits file), or + // - uvfits files. + // If none or multiple of these possibilities are met, then we must fail. + let InputDataTypes { + metafits, + gpuboxes, + mwafs, + ms, + uvfits, + solutions, + } = match files { + Some(strings) => InputDataTypes::parse(&strings)?, + None => return Err(InputVisArgsError::NoInputData), + }; + let mut data_printer = InfoPrinter::new(format!("{operation_verb} data").into()); + let mut input_files_block = vec![]; + let mut raw_data_corrections_block = vec![]; + let mut vis_reader: Box = match (metafits, gpuboxes, mwafs, ms, uvfits) { + // Valid input for reading raw data. + (Some(meta), Some(gpuboxes), mwafs, None, None) => { + // Ensure that there's only one metafits. 
+ let meta = if meta.len() > 1 { + return Err(InputVisArgsError::MultipleMetafits(meta)); + } else { + meta.to_vec().swap_remove(0) + }; + + debug!("gpubox files: {:?}", &gpuboxes); + debug!("mwaf files: {:?}", &mwafs); + + let corrections = RawDataCorrections::new( + pfb_flavour.as_deref(), + !no_digital_gains, + !no_cable_length_correction, + !no_geometric_correction, + )?; + let raw_reader = RawDataReader::new( + &meta, + &gpuboxes, + mwafs.as_deref(), + corrections, + array_position, + )?; + let obs_context = raw_reader.get_obs_context(); + let obsid = obs_context + .obsid + .expect("Raw data inputs always have the obsid specified"); + + data_printer.overwrite_title(format!("{operation_verb} obsid {obsid}").into()); + input_files_block.push(format!("from {} gpubox files", gpuboxes.len()).into()); + input_files_block.push(format!("with metafits {}", meta.display()).into()); + + match raw_reader.get_flags() { + Some(flags) => { + let software_string = match flags.software_version.as_ref() { + Some(v) => format!("{} {}", flags.software, v), + None => flags.software.to_string(), + }; + input_files_block.push( + format!( + "with {} mwaf files ({})", + flags.gpubox_nums.len(), + software_string, + ) + .into(), + ); + if let Some(s) = flags.aoflagger_version.as_deref() { + info!(" AOFlagger version: {s}"); + } + if let Some(s) = flags.aoflagger_strategy.as_deref() { + info!(" AOFlagger strategy: {s}"); + } + } + None => "No mwaf files supplied".warn(), + } + + let raw_data_corrections = raw_reader + .get_raw_data_corrections() + .expect("raw reader always has data corrections"); + match raw_data_corrections.pfb_flavour { + PfbFlavour::None => { + raw_data_corrections_block.push("Not doing any PFB correction".into()) + } + PfbFlavour::Jake => raw_data_corrections_block + .push("Correcting PFB gains with 'Jake Jones' gains".into()), + PfbFlavour::Cotter2014 => raw_data_corrections_block + .push("Correcting PFB gains with 'Cotter 2014' gains".into()), + PfbFlavour::Empirical => raw_data_corrections_block + .push("Correcting PFB gains with 'RTS empirical' gains".into()), + PfbFlavour::Levine => raw_data_corrections_block + .push("Correcting PFB gains with 'Alan Levine' gains".into()), + } + if raw_data_corrections.digital_gains { + raw_data_corrections_block.push("Correcting digital gains".into()); + } else { + raw_data_corrections_block.push("Not correcting digital gains".into()); + } + if raw_data_corrections.cable_length { + raw_data_corrections_block.push("Correcting cable lengths".into()); + } else { + raw_data_corrections_block.push("Not correcting cable lengths".into()); + } + if raw_data_corrections.geometric { + raw_data_corrections_block + .push("Correcting geometric delays (if necessary)".into()); + } else { + raw_data_corrections_block.push("Not correcting geometric delays".into()); + } + + Box::new(raw_reader) + } + + // Valid input for reading a measurement set. + (meta, None, None, Some(ms), None) => { + // Only one MS is supported at the moment. + let ms: PathBuf = if ms.len() > 1 { + return Err(InputVisArgsError::MultipleMeasurementSets(ms)); + } else { + ms.into_vec().swap_remove(0) + }; + + // Ensure that there's only one metafits. 
+ let meta: Option = match meta { + None => None, + Some(meta) => { + if meta.len() > 1 { + return Err(InputVisArgsError::MultipleMetafits(meta)); + } else { + Some(meta.into_vec().swap_remove(0)) + } + } + }; + + let ms_string = ms.display().to_string(); + let ms_reader = + MsReader::new(ms, ms_data_column_name, meta.as_deref(), array_position)?; + let obs_context = ms_reader.get_obs_context(); + + if let Some(o) = obs_context.obsid { + data_printer.overwrite_title(format!("{operation_verb} obsid {o}").into()); + input_files_block.push(format!("from measurement set {}", ms_string).into()); + } else { + data_printer.overwrite_title( + format!("{operation_verb} measurement set {}", ms_string).into(), + ); + }; + if let Some(meta) = meta.as_ref() { + input_files_block.push(format!("with metafits {}", meta.display()).into()); + } + + Box::new(ms_reader) + } + + // Valid input for reading uvfits files. + (meta, None, None, None, Some(uvfits)) => { + // Only one uvfits is supported at the moment. + let uvfits: PathBuf = if uvfits.len() > 1 { + return Err(InputVisArgsError::MultipleUvfits(uvfits)); + } else { + uvfits.into_vec().swap_remove(0) + }; + + // Ensure that there's only one metafits. + let meta: Option = match meta { + None => None, + Some(meta) => { + if meta.len() > 1 { + return Err(InputVisArgsError::MultipleMetafits(meta)); + } else { + Some(meta.into_vec().swap_remove(0)) + } + } + }; + + let uvfits_string = uvfits.display().to_string(); + let uvfits_reader = UvfitsReader::new(uvfits, meta.as_deref(), array_position)?; + let obs_context = uvfits_reader.get_obs_context(); + + if let Some(o) = obs_context.obsid { + data_printer.overwrite_title(format!("{operation_verb} obsid {o}").into()); + input_files_block.push(format!("from uvfits {}", uvfits_string).into()); + } else { + data_printer.overwrite_title( + format!("{operation_verb} uvfits {}", uvfits_string).into(), + ); + }; + if let Some(meta) = meta { + input_files_block.push(format!("with metafits {}", meta.display()).into()); + } + + Box::new(uvfits_reader) + } + + // The following matches are for invalid combinations of input + // files. Make an error message for the user. 
+ (Some(_), _, None, None, None) => { + let msg = "Received only a metafits file; a uvfits file, a measurement set or gpubox files are required."; + return Err(InputVisArgsError::InvalidDataInput(msg)); + } + (Some(_), _, Some(_), None, None) => { + let msg = + "Received only a metafits file and mwaf files; gpubox files are required."; + return Err(InputVisArgsError::InvalidDataInput(msg)); + } + (None, Some(_), _, None, None) => { + let msg = "Received gpuboxes without a metafits file; this is not supported."; + return Err(InputVisArgsError::InvalidDataInput(msg)); + } + (None, None, Some(_), None, None) => { + let msg = "Received mwaf files without gpuboxes and a metafits file; this is not supported."; + return Err(InputVisArgsError::InvalidDataInput(msg)); + } + (_, Some(_), _, Some(_), None) => { + let msg = "Received gpuboxes and measurement set files; this is not supported."; + return Err(InputVisArgsError::InvalidDataInput(msg)); + } + (_, Some(_), _, None, Some(_)) => { + let msg = "Received gpuboxes and uvfits files; this is not supported."; + return Err(InputVisArgsError::InvalidDataInput(msg)); + } + (_, _, _, Some(_), Some(_)) => { + let msg = "Received uvfits and measurement set files; this is not supported."; + return Err(InputVisArgsError::InvalidDataInput(msg)); + } + (_, _, Some(_), Some(_), _) => { + let msg = "Received mwafs and measurement set files; this is not supported."; + return Err(InputVisArgsError::InvalidDataInput(msg)); + } + (_, _, Some(_), _, Some(_)) => { + let msg = "Received mwafs and uvfits files; this is not supported."; + return Err(InputVisArgsError::InvalidDataInput(msg)); + } + (None, None, None, None, None) => return Err(InputVisArgsError::NoInputData), + }; + + let total_num_tiles = vis_reader.get_obs_context().get_total_num_tiles(); + + // Read the calibration solutions, if they were supplied. + let mut solutions_block = vec![]; + let solutions = match solutions { + Some(s) => { + let s = if s.len() > 1 { + return Err(InputVisArgsError::MultipleSolutions(s)); + } else { + s.into_vec().remove(0) + }; + // The optional metafits file is only used for reading RTS + // solutions, which we won't support here. + let sols = CalibrationSolutions::read_solutions_from_ext_inner(&s, None)?; + solutions_block + .push(format!("On-the-fly-calibrating with solutions {}", s.display()).into()); + + debug!( + "Raw data corrections in the solutions: {:?}", + sols.raw_data_corrections + ); + + // We can't do anything if the number of tiles in the data is + // different to that of the solutions. + + // TODO: Check that all unflagged input tiles are in the + // solutions; it's OK if the tile counts mismatch. + if total_num_tiles != sols.di_jones.len_of(Axis(1)) { + return Err(InputVisArgsError::TileCountMismatch { + data: total_num_tiles, + solutions: sols.di_jones.len_of(Axis(1)), + }); + } + + // Replace raw data corrections in the data args with what's in + // the solutions. + match sols.raw_data_corrections { + Some(c) => { + vis_reader.set_raw_data_corrections(c); + } + + None => { + // Warn the user if we're applying solutions to raw data + // without knowing what was applied during calibration. + if matches!(vis_reader.get_input_data_type(), VisInputType::Raw) { + [ + "The calibration solutions do not list raw data corrections." 
+ .into(), + "Defaults and any user inputs are being used.".into(), + ] + .warn(); + } + } + }; + + Some(sols) + } + None => None, + }; + data_printer.push_block(input_files_block); + data_printer.push_block(solutions_block); + data_printer.push_block(raw_data_corrections_block); + data_printer.display(); + + let obs_context = vis_reader.get_obs_context(); + + let mut coord_printer = InfoPrinter::new("Coordinates".into()); + let mut block = vec![ + style(" RA Dec") + .bold() + .to_string() + .into(), + format!( + "Phase centre: {:>8.4}° {:>8.4}° (J2000)", + obs_context.phase_centre.ra.to_degrees(), + obs_context.phase_centre.dec.to_degrees() + ) + .into(), + ]; + if let Some(pointing_centre) = obs_context.pointing_centre { + block.push( + format!( + "Pointing centre: {:>8.4}° {:>8.4}°", + pointing_centre.ra.to_degrees(), + pointing_centre.dec.to_degrees() + ) + .into(), + ); + } + coord_printer.push_block(block); + let mut block = vec![format!( + "Array position: {:>8.4}° {:>8.4}° {:.4}m", + obs_context.array_position.longitude_rad.to_degrees(), + obs_context.array_position.latitude_rad.to_degrees(), + obs_context.array_position.height_metres + ) + .into()]; + let supplied = obs_context.supplied_array_position; + let used = obs_context.array_position; + if (used.longitude_rad - supplied.longitude_rad).abs() > f64::EPSILON + || (used.latitude_rad - supplied.latitude_rad).abs() > f64::EPSILON + || (used.height_metres - supplied.height_metres).abs() > f64::EPSILON + { + block.push( + format!( + "Supplied position: {:>8.4}° {:>8.4}° {:.4}m", + supplied.longitude_rad.to_degrees(), + supplied.latitude_rad.to_degrees(), + supplied.height_metres + ) + .into(), + ); + } + block.push( + style(" Longitude Latitude Height") + .bold() + .to_string() + .into(), + ); + coord_printer.push_block(block); + coord_printer.display(); + + // Assign the tile flags. The flags depend on what's available in the + // data, whether the user wants to use input data tile flags, and any + // additional flags the user wants. + let flagged_tiles = { + let mut flagged_tiles = HashSet::new(); + + if !ignore_input_data_tile_flags { + // Add tiles that have already been flagged by the input data. + flagged_tiles.extend(obs_context.flagged_tiles.iter()); + } + // Unavailable tiles must be regarded as flagged. + flagged_tiles.extend(obs_context.unavailable_tiles.iter()); + + if let Some(flag_strings) = tile_flags { + // We need to convert the strings into antenna indices. The strings + // are either indices themselves or antenna names. + for flag_string in flag_strings { + // Try to parse a naked number. + let result = + match flag_string.trim().parse().ok() { + Some(i) => { + if i >= total_num_tiles { + Err(InputVisArgsError::BadTileIndexForFlagging { + got: i, + max: total_num_tiles - 1, + }) + } else { + flagged_tiles.insert(i); + Ok(()) + } + } + None => { + // Check if this is an antenna name. + match obs_context.tile_names.iter().enumerate().find(|(_, name)| { + name.to_lowercase() == flag_string.to_lowercase() + }) { + // If there are no matches, complain that the user input + // is no good. + None => Err(InputVisArgsError::BadTileNameForFlagging( + flag_string.to_string(), + )), + Some((i, _)) => { + flagged_tiles.insert(i); + Ok(()) + } + } + } + }; + if result.is_err() { + // If there's a problem, show all the tile names and their + // indices to help out the user. + obs_context.print_tile_statuses(Info); + // Propagate the error. 
+ result?; + } + } + } + + flagged_tiles + }; + let num_unflagged_tiles = total_num_tiles - flagged_tiles.len(); + if num_unflagged_tiles == 0 { + obs_context.print_tile_statuses(Debug); + return Err(InputVisArgsError::NoTiles); + } + let flagged_tile_names_and_indices = flagged_tiles + .iter() + .cloned() + .sorted() + .map(|i| (obs_context.tile_names[i].as_str(), i)) + .collect::>(); + let tile_baseline_flags = TileBaselineFlags::new(total_num_tiles, flagged_tiles); + + let mut tiles_printer = InfoPrinter::new("Tile info".into()); + tiles_printer.push_block(vec![ + format!("{total_num_tiles} total").into(), + format!("{num_unflagged_tiles} unflagged").into(), + ]); + if !flagged_tile_names_and_indices.is_empty() { + let mut block = vec!["Flagged tiles:".into()]; + for f in flagged_tile_names_and_indices.chunks(5) { + block.push(format!("{f:?}").into()); + } + tiles_printer.push_block(block); + } + tiles_printer.display(); + + if log_enabled!(Debug) { + obs_context.print_tile_statuses(Debug); + } + + let timesteps_to_use = { + match (use_all_timesteps, timesteps) { + (true, _) => obs_context.all_timesteps.clone(), + (false, None) => Vec1::try_from_vec(obs_context.unflagged_timesteps.clone()) + .map_err(|_| InputVisArgsError::NoTimesteps)?, + (false, Some(mut ts)) => { + // Make sure there are no duplicates. + let timesteps_hashset: HashSet<&usize> = ts.iter().collect(); + if timesteps_hashset.len() != ts.len() { + return Err(InputVisArgsError::DuplicateTimesteps); + } + + // Ensure that all specified timesteps are actually available. + for t in &ts { + if !(0..obs_context.timestamps.len()).contains(t) { + return Err(InputVisArgsError::UnavailableTimestep { + got: *t, + last: obs_context.timestamps.len() - 1, + }); + } + } + + ts.sort_unstable(); + Vec1::try_from_vec(ts).map_err(|_| InputVisArgsError::NoTimesteps)? + } + } + }; + + let timestep_span = NonZeroUsize::new( + timesteps_to_use + .last() + .checked_sub(*timesteps_to_use.first()) + .expect("last timestep index is bigger than first") + + 1, + ) + .expect("is not 0"); + let time_average_factor = match parse_time_average_factor( + obs_context.time_res, + time_average.as_deref(), + NonZeroUsize::new(1).unwrap(), + ) { + Ok(f) => { + // Check that the factor is not too big. + if f > timestep_span { + format!( + "Cannot average {} timesteps; only {} are being used. Capping.", + f, timestep_span + ) + .warn(); + timestep_span + } else { + f + } + } + // The factor was 0, average everything together. 
+ Err(AverageFactorError::Zero) => timestep_span, + Err(AverageFactorError::NotInteger) => { + return Err(InputVisArgsError::TimeFactorNotInteger) + } + Err(AverageFactorError::NotIntegerMultiple { out, inp }) => { + return Err(InputVisArgsError::TimeResNotMultiple { out, inp }) + } + Err(AverageFactorError::Parse(e)) => { + return Err(InputVisArgsError::ParseTimeAverageFactor(e)) + } + }; + + let dut1 = match (ignore_dut1, dut1) { + (true, _) => { + debug!("Ignoring input data and user DUT1"); + Duration::default() + } + (false, Some(dut1)) => { + debug!("Using user DUT1"); + Duration::from_seconds(dut1) + } + (false, None) => { + if let Some(dut1) = obs_context.dut1 { + debug!("Using input data DUT1"); + dut1 + } else { + debug!("Input data has no DUT1"); + Duration::default() + } + } + }; + + let mut time_printer = InfoPrinter::new("Time info".into()); + let time_res = match (obs_context.time_res, time_average_factor.get()) { + (_, 0) => unreachable!("cannot be 0"), + (None, _) => { + time_printer.push_line( + format!("Resolution is unknown, assuming {TIME_WEIGHT_FACTOR}").into(), + ); + obs_context + .time_res + .unwrap_or(Duration::from_seconds(TIME_WEIGHT_FACTOR)) + } + (Some(r), 1) => { + time_printer.push_line(format!("Resolution: {r}").into()); + r + } + (Some(r), f) => { + time_printer.push_block(vec![ + format!("Resolution: {r}").into(), + format!("Averaging {f}x ({})", r * f as i64).into(), + ]); + r + } + }; + time_printer + .push_line(format!("First obs timestamp: {}", obs_context.timestamps.first()).into()); + time_printer.push_block(vec![ + format!( + "Available timesteps: {}", + range_or_comma_separated(&obs_context.all_timesteps) + ) + .into(), + format!( + "Unflagged timesteps: {}", + range_or_comma_separated(&obs_context.unflagged_timesteps) + ) + .into(), + ]); + let mut block = vec![format!( + "Using timesteps: {}", + range_or_comma_separated(×teps_to_use) + ) + .into()]; + match timesteps_to_use.as_slice() { + [t] => block.push( + format!( + "Only timestamp (GPS): {:.2}", + obs_context.timestamps[*t].to_gpst_seconds() + ) + .into(), + ), + + [f, .., l] => { + block.push( + format!( + "First timestamp (GPS): {:.2}", + obs_context.timestamps[*f].to_gpst_seconds() + ) + .into(), + ); + block.push( + format!( + "Last timestamp (GPS): {:.2}", + obs_context.timestamps[*l].to_gpst_seconds() + ) + .into(), + ); + } + + [] => unreachable!("cannot be empty"), + } + { + let p = precess_time( + obs_context.array_position.longitude_rad, + obs_context.array_position.latitude_rad, + obs_context.phase_centre, + obs_context.timestamps[*timesteps_to_use.first()], + dut1, + ); + block.push(format!("First LMST: {:.6}° (J2000)", p.lmst_j2000.to_degrees()).into()); + } + time_printer.push_block(block); + time_printer.push_line(format!("DUT1: {:.10} s", dut1.to_seconds()).into()); + time_printer.display(); + + let timeblocks = timesteps_to_timeblocks( + &obs_context.timestamps, + time_res, + time_average_factor, + Some(×teps_to_use), + ); + + // Set up frequency information. Determine all of the fine-channel flags. + let mut flagged_fine_chans: HashSet = match fine_chan_flags { + Some(flags) => { + // Check that all channel flags are within the allowed range. 
+ for &f in &flags { + if usize::from(f) > obs_context.fine_chan_freqs.len() { + return Err(InputVisArgsError::FineChanFlagTooBig { + got: f, + max: obs_context.fine_chan_freqs.len() - 1, + }); + } + } + flags.into_iter().collect() + } + None => HashSet::new(), + }; + if !ignore_input_data_fine_channel_flags { + flagged_fine_chans.extend(obs_context.flagged_fine_chans.iter()); + } + // Assign the per-coarse-channel fine-channel flags. + let fine_chan_flags_per_coarse_chan = { + let mut out_flags = HashSet::new(); + // Handle user flags. + if let Some(fine_chan_flags_per_coarse_chan) = fine_chan_flags_per_coarse_chan { + out_flags.extend(fine_chan_flags_per_coarse_chan); + } + // Handle input data flags. + if let (false, Some(flags)) = ( + ignore_input_data_fine_channel_flags, + obs_context.flagged_fine_chans_per_coarse_chan.as_ref(), + ) { + out_flags.extend(flags.iter()); + } + out_flags + }; + // Take the per-coarse-channel flags and put them in the fine channel + // flags. + match ( + obs_context.mwa_coarse_chan_nums.as_ref(), + obs_context.num_fine_chans_per_coarse_chan.map(|n| n.get()), + ) { + (Some(mwa_coarse_chan_nums), Some(num_fine_chans_per_coarse_chan)) => { + for (i_cc, _) in (0..).zip(mwa_coarse_chan_nums.iter()) { + for &f in &fine_chan_flags_per_coarse_chan { + if f > num_fine_chans_per_coarse_chan { + return Err(InputVisArgsError::FineChanFlagPerCoarseChanTooBig { + got: f, + max: num_fine_chans_per_coarse_chan - 1, + }); + } + + flagged_fine_chans.insert(f + num_fine_chans_per_coarse_chan * i_cc); + } + } + } + + // We can't do anything without the number of fine channels per + // coarse channel. + (_, None) => { + "Flags per coarse channel were specified, but no information on how many fine channels per coarse channel is available; flags are being ignored.".warn(); + } + + // If we don't have MWA coarse channel numbers but we do have + // per-coarse-channel flags, warn the user. + (None, _) => { + if !fine_chan_flags_per_coarse_chan.is_empty() { + "Flags per coarse channel were specified, but no MWA coarse channel information is available; flags are being ignored.".warn(); + } + } + } + let mut unflagged_fine_chan_freqs = vec![]; + for (i_chan, &freq) in (0..).zip(obs_context.fine_chan_freqs.iter()) { + if !flagged_fine_chans.contains(&i_chan) { + unflagged_fine_chan_freqs.push(freq as f64); + } + } + + let num_unflagged_fine_chan_freqs = if unflagged_fine_chan_freqs.is_empty() { + return Err(InputVisArgsError::NoChannels); + } else { + NonZeroUsize::new(unflagged_fine_chan_freqs.len()).expect("cannot be empty here") + }; + let freq_average_factor = match parse_freq_average_factor( + obs_context.freq_res, + freq_average.as_deref(), + NonZeroUsize::new(1).unwrap(), + ) { + Ok(f) => { + // Check that the factor is not too big. + if f > num_unflagged_fine_chan_freqs { + format!( + "Cannot average {} channels; only {} are being used. Capping.", + f, + unflagged_fine_chan_freqs.len() + ) + .warn(); + num_unflagged_fine_chan_freqs + } else { + f + } + } + // The factor was 0, average everything together. 
+ Err(AverageFactorError::Zero) => num_unflagged_fine_chan_freqs, + Err(AverageFactorError::NotInteger) => { + return Err(InputVisArgsError::FreqFactorNotInteger) + } + Err(AverageFactorError::NotIntegerMultiple { out, inp }) => { + return Err(InputVisArgsError::FreqResNotMultiple { out, inp }) + } + Err(AverageFactorError::Parse(e)) => { + return Err(InputVisArgsError::ParseFreqAverageFactor(e)) + } + }; + + let mut chan_printer = InfoPrinter::new("Channel info".into()); + let freq_res = match (obs_context.freq_res, freq_average_factor.get()) { + (_, 0) => unreachable!("cannot be 0"), + (None, _) => { + chan_printer.push_line( + format!("Resolution is unknown, assuming {FREQ_WEIGHT_FACTOR}").into(), + ); + FREQ_WEIGHT_FACTOR + } + (Some(r), 1) => { + chan_printer.push_line(format!("Resolution: {:.2} kHz", r / 1e3).into()); + r + } + (Some(r), f) => { + chan_printer.push_block(vec![ + format!("Resolution: {:.2} kHz", r / 1e3).into(), + format!("Averaging {f}x ({:.2} kHz)", r / 1e3 * f as f64).into(), + ]); + r + } + }; + + // Set up the chanblocks. + let mut spws = channels_to_chanblocks( + &obs_context.fine_chan_freqs, + freq_res.round() as u64, + freq_average_factor, + &flagged_fine_chans, + ); + // There must be at least one chanblock to do anything. + let spw = match spws.as_slice() { + // No spectral windows is the same as no chanblocks. + [] => return Err(InputVisArgsError::NoChannels), + [spw] => { + // Check that the chanblocks aren't all flagged. + if spw.chanblocks.is_empty() { + return Err(InputVisArgsError::NoChannels); + } + spws.swap_remove(0) + } + [..] => { + // TODO: Allow picket fence. + eprintln!("\"Picket fence\" data detected. hyperdrive does not support this right now -- exiting."); + eprintln!("See for more info: https://MWATelescope.github.io/mwa_hyperdrive/defs/mwa/picket_fence.html"); + std::process::exit(1); + } + }; + + chan_printer.push_block(vec![ + format!( + "Total number of fine channels: {}", + obs_context.fine_chan_freqs.len() + ) + .into(), + format!( + "Number of unflagged fine channels: {}", + unflagged_fine_chan_freqs.len() + ) + .into(), + ]); + let mut block = vec![]; + if let Some(n) = obs_context.num_fine_chans_per_coarse_chan { + block.push(format!("Number of fine chans per coarse channel: {}", n.get()).into()); + } + if !fine_chan_flags_per_coarse_chan.is_empty() { + let mut sorted = fine_chan_flags_per_coarse_chan + .into_iter() + .collect::>(); + sorted.sort_unstable(); + block.push(format!("Flags per coarse channel: {sorted:?}").into()); + } + chan_printer.push_block(block); + match obs_context.fine_chan_freqs.as_slice() { + [f] => chan_printer + .push_line(format!("Only fine-channel: {:.3} MHz", *f as f64 / 1e6).into()), + + [f, .., l] => chan_printer.push_block(vec![ + format!("First fine-channel: {:.3} MHz", *f as f64 / 1e6).into(), + format!("Last fine-channel: {:.3} MHz", *l as f64 / 1e6).into(), + ]), + + [] => unreachable!("cannot be empty"), + }; + match unflagged_fine_chan_freqs.as_slice() { + [f] => chan_printer + .push_line(format!("Only unflagged fine-channel: {:.3} MHz", *f / 1e6).into()), + + [f, .., l] => chan_printer.push_block(vec![ + format!("First unflagged fine-channel: {:.3} MHz", *f / 1e6).into(), + format!("Last unflagged fine-channel: {:.3} MHz", *l / 1e6).into(), + ]), + + [] => unreachable!("cannot be empty"), + }; + chan_printer.display(); + + Ok(InputVisParams { + vis_reader, + solutions, + timeblocks, + time_res: time_res * time_average_factor.get() as i64, + spw, + tile_baseline_flags, + using_autos: !no_autos, + 
ignore_weights, + dut1, + }) + } +} + +#[derive(thiserror::Error, Debug)] +pub(crate) enum InputVisArgsError { + #[error("Specified file does not exist: {0}")] + DoesNotExist(String), + + #[error("Could not read specified file: {0}")] + CouldNotRead(String), + + #[error("The specified file '{0}' is a \"PPDs metafits\" and is not supported. Please use a newer metafits file.")] + PpdMetafitsUnsupported(String), + + #[error("The specified file '{0}' was not a recognised file type.\n\nSupported file formats:{}", *SUPPORTED_INPUT_FILE_TYPES)] + NotRecognised(String), + + #[error("No input data was given!")] + NoInputData, + + #[error("Multiple metafits files were specified: {0:?}\nThis is unsupported.")] + MultipleMetafits(Vec1), + + #[error("Multiple measurement sets were specified: {0:?}\nThis is currently unsupported.")] + MultipleMeasurementSets(Vec1), + + #[error("Multiple uvfits files were specified: {0:?}\nThis is currently unsupported.")] + MultipleUvfits(Vec1), + + #[error("Multiple calibration solutions files were specified: {0:?}\nThis is unsupported.")] + MultipleSolutions(Vec1), + + #[error("{0}\n\nSupported file formats:{}", *SUPPORTED_INPUT_FILE_TYPES)] + InvalidDataInput(&'static str), + + #[error("Array position specified as {pos:?}, not [, , ]")] + BadArrayPosition { pos: Vec }, + + #[error("The data either contains no timesteps or no timesteps are being used")] + NoTimesteps, + + #[error("Duplicate timesteps were specified; this is invalid")] + DuplicateTimesteps, + + #[error("Timestep {got} was specified but it isn't available; the last timestep is {last}")] + UnavailableTimestep { got: usize, last: usize }, + + #[error("The data either contains no tiles or all tiles are flagged")] + NoTiles, + + #[error("Got a tile flag {got}, but the biggest possible antenna index is {max}")] + BadTileIndexForFlagging { got: usize, max: usize }, + + #[error("Bad tile flag value: '{0}' is neither an integer or an available antenna name. Run with extra verbosity to see all tile statuses.")] + BadTileNameForFlagging(String), + + #[error("The data either contains no frequency channels or all channels are flagged")] + NoChannels, + + #[error("Got a fine-channel flag {got}, but the biggest possible index is {max}")] + FineChanFlagTooBig { got: u16, max: usize }, + + #[error( + "Got a fine-channel-per-coarse-channel flag {got}, but the biggest possible index is {max}" + )] + FineChanFlagPerCoarseChanTooBig { got: u16, max: u16 }, + + #[error("The input data and the solutions have different numbers of tiles (data: {data}, solutions: {solutions}); cannot continue")] + TileCountMismatch { data: usize, solutions: usize }, + + #[error("Error when parsing input data time average factor: {0}")] + ParseTimeAverageFactor(crate::unit_parsing::UnitParseError), + + #[error("Input data time average factor isn't an integer")] + TimeFactorNotInteger, + + #[error("Input data time resolution isn't a multiple of input data's: {out} seconds vs {inp} seconds")] + TimeResNotMultiple { out: f64, inp: f64 }, + + #[error("Error when parsing input data freq. average factor: {0}")] + ParseFreqAverageFactor(crate::unit_parsing::UnitParseError), + + #[error("Input data freq. average factor isn't an integer")] + FreqFactorNotInteger, + + #[error("Input data freq. 
resolution isn't a multiple of input data's: {out} Hz vs {inp} Hz")] + FreqResNotMultiple { out: f64, inp: f64 }, + + #[error(transparent)] + PfbParse(#[from] crate::io::read::pfb_gains::PfbParseError), + + #[error(transparent)] + Raw(#[from] crate::io::read::RawReadError), + + #[error(transparent)] + Ms(#[from] crate::io::read::MsReadError), + + #[error(transparent)] + Uvfits(#[from] crate::io::read::UvfitsReadError), + + #[error(transparent)] + Solutions(#[from] crate::solutions::SolutionsReadError), + + #[error(transparent)] + Glob(#[from] crate::io::GlobError), + + #[error("IO error when attempting to read file '{0}': {1}")] + IO(String, std::io::Error), +} + +// It looks a bit neater to print out a collection of numbers as a range rather +// than individual indices if they're sequential. This function inspects a +// collection and returns a string to be printed. +fn range_or_comma_separated(collection: &[usize]) -> String { + if collection.is_empty() { + return "".to_string(); + } + + let mut iter = collection.iter(); + let mut prev = *iter.next().unwrap(); + // Innocent until proven guilty. + let mut is_sequential = true; + for next in iter { + if *next == prev + 1 { + prev = *next; + } else { + is_sequential = false; + break; + } + } + + if is_sequential { + if collection.len() == 1 { + format!("[{}]", collection[0]) + } else { + format!( + "[{:?})", + (*collection.first().unwrap()..*collection.last().unwrap() + 1) + ) + } + } else { + collection + .iter() + .map(|t| t.to_string()) + .collect::>() + .join(", ") + } +} diff --git a/src/cli/common/input_vis/tests.rs b/src/cli/common/input_vis/tests.rs new file mode 100644 index 00000000..20d64646 --- /dev/null +++ b/src/cli/common/input_vis/tests.rs @@ -0,0 +1,665 @@ +// This Source Code Form is subject to the terms of the Mozilla Public +// License, v. 2.0. If a copy of the MPL was not distributed with this +// file, You can obtain one at http://mozilla.org/MPL/2.0/. + +use std::path::PathBuf; + +use approx::{assert_abs_diff_eq, assert_relative_eq}; +use crossbeam_utils::atomic::AtomicCell; +use marlu::{ + constants::{MWA_HEIGHT_M, MWA_LAT_DEG, MWA_LONG_DEG}, + Jones, LatLngHeight, +}; +use ndarray::prelude::*; +use tempfile::TempDir; + +use super::{ + InputVisArgs, + InputVisArgsError::{ + FreqFactorNotInteger, FreqResNotMultiple, NoInputData, TimeFactorNotInteger, + TimeResNotMultiple, + }, +}; +use crate::{ + cli::{ + common::{BeamArgs, ModellingArgs, OutputVisArgs, SkyModelWithVetoArgs}, + vis_simulate::{VisSimulateArgs, VisSimulateCliArgs}, + }, + tests::{ + get_reduced_1090008640_ms, get_reduced_1090008640_raw, get_reduced_1090008640_uvfits, + DataAsStrings, + }, +}; + +#[test] +fn test_handle_no_input() { + let args = InputVisArgs::default(); + let result = args.parse(""); + + assert!(result.is_err()); + assert!(matches!(result, Err(NoInputData))); +} + +#[test] +fn test_handle_multiple_metafits() { + let DataAsStrings { + metafits, mut vis, .. + } = get_reduced_1090008640_raw(); + let mut files = vec![metafits.clone()]; + files.append(&mut vis); + // Add the metafits again. + files.push(metafits); + + let args = InputVisArgs { + files: Some(files), + ..Default::default() + }; + let result = args.parse(""); + + assert!(result.is_err()); + assert!(result + .err() + .unwrap() + .to_string() + .contains("Multiple metafits files were specified")); +} + +#[test] +fn test_handle_multiple_ms() { + let DataAsStrings { + metafits, mut vis, .. 
+ } = get_reduced_1090008640_ms(); + let mut files = vec![metafits]; + let ms = vis.swap_remove(0); + files.push(ms.clone()); + files.push(ms); + + let args = InputVisArgs { + files: Some(files), + ..Default::default() + }; + let result = args.parse(""); + + assert!(result.is_err()); + assert!(result + .err() + .unwrap() + .to_string() + .contains("Multiple measurement sets were specified")); +} + +#[test] +fn test_handle_multiple_uvfits() { + let DataAsStrings { + metafits, mut vis, .. + } = get_reduced_1090008640_uvfits(); + let mut files = vec![metafits]; + let uvfits = vis.swap_remove(0); + files.push(uvfits.clone()); + files.push(uvfits); + + let args = InputVisArgs { + files: Some(files), + ..Default::default() + }; + let result = args.parse(""); + + assert!(result.is_err()); + assert!(result + .err() + .unwrap() + .to_string() + .contains("Multiple uvfits files were specified")); +} + +#[test] +fn test_handle_only_metafits() { + let DataAsStrings { metafits, .. } = get_reduced_1090008640_raw(); + let files = vec![metafits]; + + let args = InputVisArgs { + files: Some(files), + ..Default::default() + }; + let result = args.parse(""); + + assert!(result.is_err()); + assert!(result + .err() + .unwrap() + .to_string() + .contains("Received only a metafits file;")); +} + +#[test] +fn test_handle_array_pos() { + let DataAsStrings { + metafits, mut vis, .. + } = get_reduced_1090008640_raw(); + let mut files = vec![metafits]; + files.append(&mut vis); + + let expected = [MWA_LONG_DEG + 1.0, MWA_LAT_DEG + 1.0, MWA_HEIGHT_M + 1.0]; + let args = InputVisArgs { + files: Some(files), + array_position: Some(expected.to_vec()), + ..Default::default() + }; + let params = args.parse("").unwrap(); + + assert_abs_diff_eq!( + params.vis_reader.get_obs_context().array_position, + LatLngHeight { + longitude_rad: expected[0].to_radians(), + latitude_rad: expected[1].to_radians(), + height_metres: expected[2] + } + ); +} + +#[test] +fn test_handle_bad_array_pos() { + let DataAsStrings { + metafits, mut vis, .. + } = get_reduced_1090008640_raw(); + let mut files = vec![metafits]; + files.append(&mut vis); + + let two_elems_when_it_should_be_three = [MWA_LONG_DEG + 1.0, MWA_LAT_DEG + 1.0]; + let args = InputVisArgs { + files: Some(files), + array_position: Some(two_elems_when_it_should_be_three.to_vec()), + ..Default::default() + }; + let result = args.parse(""); + + assert!(result.is_err()); + assert!(result + .err() + .unwrap() + .to_string() + .contains("Array position specified as")); +} + +#[test] +fn test_parse_time_average() { + let DataAsStrings { + metafits, mut vis, .. + } = get_reduced_1090008640_raw(); + let mut files = vec![metafits]; + files.append(&mut vis); + + let mut args = InputVisArgs { + files: Some(files), + time_average: Some("1".to_string()), + ..Default::default() + }; + args.clone().parse("").unwrap(); + + args.time_average = Some("2".to_string()); + args.clone().parse("").unwrap(); + + args.time_average = Some("20".to_string()); + args.clone().parse("").unwrap(); + + // The native time resolution is 2s. 
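+ // A bare number is a multiplicative factor ("2" averages every 2
+ // timesteps), while a unit-annotated value is a target resolution that
+ // must be an integer multiple of the native resolution. So with 2 s data:
+ // "4s" -> factor 2, "20s" -> factor 10, "1.5" -> TimeFactorNotInteger,
+ // and "3s" or "7s" -> TimeResNotMultiple, as asserted below.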
+ args.time_average = Some("2s".to_string()); + args.clone().parse("").unwrap(); + + args.time_average = Some("4s".to_string()); + args.clone().parse("").unwrap(); + + args.time_average = Some("20s".to_string()); + args.clone().parse("").unwrap(); + + args.time_average = Some("1.5".to_string()); + let result = args.clone().parse(""); + assert!(matches!(result.err(), Some(TimeFactorNotInteger))); + + args.time_average = Some("3s".to_string()); + let result = args.clone().parse(""); + assert!(matches!(result.err(), Some(TimeResNotMultiple { .. }))); + + args.time_average = Some("7s".to_string()); + let result = args.parse(""); + assert!(matches!(result.err(), Some(TimeResNotMultiple { .. }))); +} + +#[test] +fn test_parse_freq_average() { + let DataAsStrings { + metafits, mut vis, .. + } = get_reduced_1090008640_raw(); + let mut files = vec![metafits]; + files.append(&mut vis); + + let mut args = InputVisArgs { + files: Some(files), + freq_average: Some("1".to_string()), + ..Default::default() + }; + args.clone().parse("").unwrap(); + + args.freq_average = Some("2".to_string()); + args.clone().parse("").unwrap(); + + args.freq_average = Some("20".to_string()); + args.clone().parse("").unwrap(); + + // The native freq. resolution is 40kHz. + args.freq_average = Some("40kHz".to_string()); + args.clone().parse("").unwrap(); + + args.freq_average = Some("80kHz".to_string()); + args.clone().parse("").unwrap(); + + args.freq_average = Some("960kHz".to_string()); + args.clone().parse("").unwrap(); + + args.freq_average = Some("1.5".to_string()); + let result = args.clone().parse(""); + assert!(matches!(result.err(), Some(FreqFactorNotInteger))); + + args.freq_average = Some("10kHz".to_string()); + let result = args.clone().parse(""); + assert!(matches!(result.err(), Some(FreqResNotMultiple { .. }))); + + args.freq_average = Some("79kHz".to_string()); + let result = args.parse(""); + assert!(matches!(result.err(), Some(FreqResNotMultiple { .. }))); +} + +#[test] +fn test_freq_averaging_works() { + let DataAsStrings { + metafits, mut vis, .. + } = get_reduced_1090008640_raw(); + let mut files = vec![metafits]; + files.append(&mut vis); + + let mut args = InputVisArgs { + files: Some(files), + freq_average: Some("1".to_string()), + ..Default::default() + }; + let default_params = args.clone().parse("").unwrap(); + let num_unflagged_tiles = default_params.get_num_unflagged_tiles(); + let num_unflagged_cross_baselines = (num_unflagged_tiles * (num_unflagged_tiles - 1)) / 2; + let cross_vis_shape = ( + default_params.spw.chanblocks.len(), + num_unflagged_cross_baselines, + ); + let auto_vis_shape = (default_params.spw.chanblocks.len(), num_unflagged_tiles); + + let mut default_crosses_fb = Array2::zeros(cross_vis_shape); + let mut default_cross_weights_fb = Array2::zeros(cross_vis_shape); + let mut default_autos_fb = Array2::zeros(auto_vis_shape); + let mut default_auto_weights_fb = Array2::zeros(auto_vis_shape); + let error = AtomicCell::new(false); + default_params + .read_timeblock( + default_params.timeblocks.first(), + default_crosses_fb.view_mut(), + default_cross_weights_fb.view_mut(), + Some(( + default_autos_fb.view_mut(), + default_auto_weights_fb.view_mut(), + )), + &error, + ) + .unwrap(); + + // 27 unflagged channels, 8128 unflagged baselines. + assert_eq!(default_crosses_fb.dim(), (27, 8128)); + assert_eq!(default_cross_weights_fb.dim(), (27, 8128)); + // 27 unflagged channels, 128 unflagged tiles. 
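+ // Where these shapes come from: the data has 32 fine channels, of which
+ // 0, 1, 16, 30 and 31 are flagged by default, leaving 27; 128 unflagged
+ // tiles give 128 * 127 / 2 = 8128 cross-correlation baselines and 128
+ // auto-correlations.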
+ assert_eq!(default_autos_fb.dim(), (27, 128)); + assert_eq!(default_auto_weights_fb.dim(), (27, 128)); + + args.ignore_input_data_fine_channel_flags = true; + let ref_params = args.clone().parse("").unwrap(); + let num_unflagged_tiles = ref_params.get_num_unflagged_tiles(); + let num_unflagged_cross_baselines = (num_unflagged_tiles * (num_unflagged_tiles - 1)) / 2; + let cross_vis_shape = ( + ref_params.spw.chanblocks.len(), + num_unflagged_cross_baselines, + ); + let auto_vis_shape = (ref_params.spw.chanblocks.len(), num_unflagged_tiles); + + let mut ref_crosses_fb = Array2::zeros(cross_vis_shape); + let mut ref_cross_weights_fb = Array2::zeros(cross_vis_shape); + let mut ref_autos_fb = Array2::zeros(auto_vis_shape); + let mut ref_auto_weights_fb = Array2::zeros(auto_vis_shape); + let error = AtomicCell::new(false); + ref_params + .read_timeblock( + ref_params.timeblocks.first(), + ref_crosses_fb.view_mut(), + ref_cross_weights_fb.view_mut(), + Some((ref_autos_fb.view_mut(), ref_auto_weights_fb.view_mut())), + &error, + ) + .unwrap(); + + // 32 unflagged channels, 8128 unflagged baselines. + assert_eq!(ref_crosses_fb.dim(), (32, 8128)); + assert_eq!(ref_cross_weights_fb.dim(), (32, 8128)); + // 32 unflagged channels, 128 unflagged tiles. + assert_eq!(ref_autos_fb.dim(), (32, 128)); + assert_eq!(ref_auto_weights_fb.dim(), (32, 128)); + + args.ignore_input_data_fine_channel_flags = false; + args.freq_average = Some("2".to_string()); + let av_params = args.clone().parse("").unwrap(); + let num_unflagged_tiles = av_params.get_num_unflagged_tiles(); + let num_unflagged_cross_baselines = (num_unflagged_tiles * (num_unflagged_tiles - 1)) / 2; + let cross_vis_shape = ( + av_params.spw.chanblocks.len(), + num_unflagged_cross_baselines, + ); + let auto_vis_shape = (av_params.spw.chanblocks.len(), num_unflagged_tiles); + + let mut av_crosses_fb = Array2::zeros(cross_vis_shape); + let mut av_cross_weights_fb = Array2::zeros(cross_vis_shape); + let mut av_autos_fb = Array2::zeros(auto_vis_shape); + let mut av_auto_weights_fb = Array2::zeros(auto_vis_shape); + let error = AtomicCell::new(false); + av_params + .read_timeblock( + av_params.timeblocks.first(), + av_crosses_fb.view_mut(), + av_cross_weights_fb.view_mut(), + Some((av_autos_fb.view_mut(), av_auto_weights_fb.view_mut())), + &error, + ) + .unwrap(); + + // 14 unflagged chanblocks, 8128 unflagged baselines. + assert_eq!(av_crosses_fb.dim(), (14, 8128)); + assert_eq!(av_cross_weights_fb.dim(), (14, 8128)); + // 14 unflagged chanblocks, 128 unflagged tiles. + assert_eq!(av_autos_fb.dim(), (14, 128)); + assert_eq!(av_auto_weights_fb.dim(), (14, 128)); + + // Channels 0 1 16 30 31 are flagged by default, so manually average the + // unflagged channels and compare with the averaged arrays. + let flags = [0, 1, 16, 30, 31]; + let mut av_vis = Jones::default(); + let mut weight_sum: f64 = 0.0; + + for i_bl in 0..8128 { + let av_crosses_f = av_crosses_fb.slice(s![.., i_bl]); + let mut av_crosses_iter = av_crosses_f.iter(); + let av_cross_weights_f = av_cross_weights_fb.slice(s![.., i_bl]); + let mut av_cross_weights_iter = av_cross_weights_f.iter(); + let mut i_unflagged_chan = 0; + for i_chan in 0..32 { + if i_chan % 2 == 0 && weight_sum.abs() > 0.0 { + // Compare. 
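+ // Each averaged chanblock should be the weighted mean of its unflagged
+ // channels, av = sum(w_i * v_i) / sum(w_i), with weight sum(w_i); the
+ // sums are accumulated further down, and this branch compares them
+ // against what read_timeblock produced at each chanblock boundary.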
+ av_vis /= weight_sum; + let j = Jones::from(av_crosses_iter.next().unwrap()); + assert_relative_eq!(av_vis, j, max_relative = 1e-7); + assert_relative_eq!( + weight_sum, + *av_cross_weights_iter.next().unwrap() as f64, + max_relative = 1e-7 + ); + + av_vis = Jones::default(); + weight_sum = 0.0; + } + + if flags.contains(&i_chan) { + continue; + } + // Compare unaveraged vis with one another. + assert_abs_diff_eq!( + ref_crosses_fb[(i_chan, i_bl)], + default_crosses_fb[(i_unflagged_chan, i_bl)] + ); + assert_abs_diff_eq!( + ref_cross_weights_fb[(i_chan, i_bl)], + default_cross_weights_fb[(i_unflagged_chan, i_bl)] + ); + i_unflagged_chan += 1; + + let weight = ref_cross_weights_fb[(i_chan, i_bl)] as f64; + // Promote Jones before dividing to keep precision high. + av_vis += Jones::from(ref_crosses_fb[(i_chan, i_bl)]) * weight; + weight_sum += weight; + } + } + + for i_tile in 0..128 { + let av_autos_f = av_autos_fb.slice(s![.., i_tile]); + let mut av_autos_iter = av_autos_f.iter(); + let av_auto_weights_f = av_auto_weights_fb.slice(s![.., i_tile]); + let mut av_auto_weights_iter = av_auto_weights_f.iter(); + let mut i_unflagged_chan = 0; + for i_chan in 0..32 { + if i_chan % 2 == 0 && weight_sum.abs() > 0.0 { + // Compare. + av_vis /= weight_sum; + let j = Jones::from(av_autos_iter.next().unwrap()); + assert_relative_eq!(av_vis, j, max_relative = 1e-7); + assert_relative_eq!( + weight_sum, + *av_auto_weights_iter.next().unwrap() as f64, + max_relative = 1e-7 + ); + + av_vis = Jones::default(); + weight_sum = 0.0; + } + + if flags.contains(&i_chan) { + continue; + } + // Compare unaveraged vis with one another. + assert_abs_diff_eq!( + ref_autos_fb[(i_chan, i_tile)], + default_autos_fb[(i_unflagged_chan, i_tile)] + ); + assert_abs_diff_eq!( + ref_auto_weights_fb[(i_chan, i_tile)], + default_auto_weights_fb[(i_unflagged_chan, i_tile)] + ); + i_unflagged_chan += 1; + + let weight = ref_auto_weights_fb[(i_chan, i_tile)] as f64; + // Promote Jones before dividing to keep precision high. + av_vis += Jones::from(ref_autos_fb[(i_chan, i_tile)]) * weight; + weight_sum += weight; + } + } + + // Do it all again with 3 channel averaging. + args.freq_average = Some("3".to_string()); + let av_params = args.parse("").unwrap(); + let num_unflagged_tiles = av_params.get_num_unflagged_tiles(); + let num_unflagged_cross_baselines = (num_unflagged_tiles * (num_unflagged_tiles - 1)) / 2; + let cross_vis_shape = ( + av_params.spw.chanblocks.len(), + num_unflagged_cross_baselines, + ); + let auto_vis_shape = (av_params.spw.chanblocks.len(), num_unflagged_tiles); + + let mut av_crosses_fb = Array2::zeros(cross_vis_shape); + let mut av_cross_weights_fb = Array2::zeros(cross_vis_shape); + let mut av_autos_fb = Array2::zeros(auto_vis_shape); + let mut av_auto_weights_fb = Array2::zeros(auto_vis_shape); + let error = AtomicCell::new(false); + av_params + .read_timeblock( + av_params.timeblocks.first(), + av_crosses_fb.view_mut(), + av_cross_weights_fb.view_mut(), + Some((av_autos_fb.view_mut(), av_auto_weights_fb.view_mut())), + &error, + ) + .unwrap(); + + // 10 unflagged chanblocks, 8128 unflagged baselines. + assert_eq!(av_crosses_fb.dim(), (10, 8128)); + assert_eq!(av_cross_weights_fb.dim(), (10, 8128)); + // 10 unflagged chanblocks, 128 unflagged tiles. 
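+ // Chanblock counting (assuming chanblocks are aligned to the first fine
+ // channel): a factor of 2 gives 16 blocks, of which {0,1} and {30,31}
+ // are entirely flagged, hence 14 earlier; a factor of 3 gives
+ // ceil(32 / 3) = 11 blocks, of which only {30,31} is entirely flagged,
+ // leaving the 10 asserted below.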
+ assert_eq!(av_autos_fb.dim(), (10, 128)); + assert_eq!(av_auto_weights_fb.dim(), (10, 128)); + + let flags = [0, 1, 16, 30, 31]; + let mut av_vis = Jones::default(); + let mut weight_sum: f64 = 0.0; + + for i_bl in 0..8128 { + let av_crosses_f = av_crosses_fb.slice(s![.., i_bl]); + let mut av_crosses_iter = av_crosses_f.iter(); + let av_cross_weights_f = av_cross_weights_fb.slice(s![.., i_bl]); + let mut av_cross_weights_iter = av_cross_weights_f.iter(); + let mut i_unflagged_chan = 0; + for i_chan in 0..32 { + if i_chan % 3 == 0 && weight_sum.abs() > 0.0 { + // Compare. + av_vis /= weight_sum; + let j = Jones::from(av_crosses_iter.next().unwrap()); + assert_relative_eq!(av_vis, j, max_relative = 1e-7); + assert_relative_eq!( + weight_sum, + *av_cross_weights_iter.next().unwrap() as f64, + max_relative = 1e-7 + ); + + av_vis = Jones::default(); + weight_sum = 0.0; + } + + if flags.contains(&i_chan) { + continue; + } + // Compare unaveraged vis with one another. + assert_abs_diff_eq!( + ref_crosses_fb[(i_chan, i_bl)], + default_crosses_fb[(i_unflagged_chan, i_bl)] + ); + assert_abs_diff_eq!( + ref_cross_weights_fb[(i_chan, i_bl)], + default_cross_weights_fb[(i_unflagged_chan, i_bl)] + ); + i_unflagged_chan += 1; + + let weight = ref_cross_weights_fb[(i_chan, i_bl)] as f64; + // Promote Jones before dividing to keep precision high. + av_vis += Jones::from(ref_crosses_fb[(i_chan, i_bl)]) * weight; + weight_sum += weight; + } + } + + for i_tile in 0..128 { + let av_autos_f = av_autos_fb.slice(s![.., i_tile]); + let mut av_autos_iter = av_autos_f.iter(); + let av_auto_weights_f = av_auto_weights_fb.slice(s![.., i_tile]); + let mut av_auto_weights_iter = av_auto_weights_f.iter(); + let mut i_unflagged_chan = 0; + for i_chan in 0..32 { + if i_chan % 3 == 0 && weight_sum.abs() > 0.0 { + // Compare. + av_vis /= weight_sum; + let j = Jones::from(av_autos_iter.next().unwrap()); + assert_relative_eq!(av_vis, j, max_relative = 1e-7); + assert_relative_eq!( + weight_sum, + *av_auto_weights_iter.next().unwrap() as f64, + max_relative = 1e-7 + ); + + av_vis = Jones::default(); + weight_sum = 0.0; + } + + if flags.contains(&i_chan) { + continue; + } + // Compare unaveraged vis with one another. + assert_abs_diff_eq!( + ref_autos_fb[(i_chan, i_tile)], + default_autos_fb[(i_unflagged_chan, i_tile)] + ); + assert_abs_diff_eq!( + ref_auto_weights_fb[(i_chan, i_tile)], + default_auto_weights_fb[(i_unflagged_chan, i_tile)] + ); + i_unflagged_chan += 1; + + let weight = ref_auto_weights_fb[(i_chan, i_tile)] as f64; + // Promote Jones before dividing to keep precision high. + av_vis += Jones::from(ref_autos_fb[(i_chan, i_tile)]) * weight; + weight_sum += weight; + } + } +} + +#[test] +/// `timesteps_to_timeblocks` now needs the time resolution, because not +/// supplying it was giving incorrect timeblocks when averaging. This test +/// helps to ensure that the expected behaviour works. +fn sparse_timeblocks_with_averaging() { + // First, make some data with enough timesteps. 
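+ // The idea: simulate 20 timesteps at 2 s, then read back only timesteps
+ // 6, 12 and 18 while averaging to 8 s. The selected timesteps are 12 s
+ // apart, so none can share an 8 s timeblock, and 3 timeblocks should
+ // come out of the parse.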
+ let tmp_dir = TempDir::new().unwrap();
+ let DataAsStrings {
+ metafits,
+ vis: _,
+ mwafs: _,
+ srclist,
+ } = get_reduced_1090008640_raw();
+
+ let mut args = VisSimulateArgs {
+ args_file: None,
+ beam_args: BeamArgs {
+ beam_file: None,
+ unity_dipole_gains: true,
+ delays: None,
+ no_beam: true,
+ },
+ modelling_args: ModellingArgs {
+ ..Default::default()
+ },
+ srclist_args: SkyModelWithVetoArgs {
+ source_list: Some(srclist),
+ num_sources: Some(1),
+ ..Default::default()
+ },
+ simulate_args: VisSimulateCliArgs::default(),
+ };
+ args.simulate_args.metafits = Some(PathBuf::from(&metafits));
+ args.simulate_args.num_timesteps = Some(20);
+ args.simulate_args.time_res = Some(2.0);
+ args.simulate_args.num_fine_channels = Some(1);
+ let output = tmp_dir.path().join("20ts.uvfits");
+ args.simulate_args.output_model_files = Some(vec![output.clone()]);
+ args.run(false).unwrap();
+
+ // Now try to read it with sparse timesteps and averaging.
+ let args = InputVisArgs {
+ files: Some(vec![metafits, output.display().to_string()]),
+ timesteps: Some(vec![6, 12, 18]),
+ time_average: Some("8s".to_string()),
+ ..Default::default()
+ };
+ let params = args.parse("").unwrap();
+ assert_eq!(params.timeblocks.len(), 3);
+
+ let output_params = OutputVisArgs {
+ outputs: Some(vec![tmp_dir.path().join("output.uvfits")]),
+ output_vis_time_average: None,
+ output_vis_freq_average: None,
+ }
+ .parse(
+ params.time_res,
+ params.spw.freq_res,
+ &params.timeblocks.mapped_ref(|tb| tb.median),
+ false,
+ "adsf.uvfits",
+ None,
+ )
+ .unwrap();
+ assert_eq!(output_params.output_timeblocks.len(), 3);
+}
diff --git a/src/cli/common/mod.rs b/src/cli/common/mod.rs
new file mode 100644
index 00000000..617511dc
--- /dev/null
+++ b/src/cli/common/mod.rs
@@ -0,0 +1,541 @@
+// This Source Code Form is subject to the terms of the Mozilla Public
+// License, v. 2.0. If a copy of the MPL was not distributed with this
+// file, You can obtain one at http://mozilla.org/MPL/2.0/.
+
+//! Common arguments for command-line interfaces. Here, we abstract many aspects
+//! of `hyperdrive`, e.g. the `di-calibrate` and `vis-subtract` subcommands both
+//! take visibilities as input, so the same vis input arguments are shared
+//! between them.
+
+mod beam;
+mod input_vis;
+mod printers;
+#[cfg(test)]
+mod tests;
+
+pub(super) use beam::BeamArgs;
+pub(super) use input_vis::{InputVisArgs, InputVisArgsError};
+pub(super) use printers::InfoPrinter;
+pub(crate) use printers::{display_warnings, Warn};
+
+use std::{num::NonZeroUsize, path::PathBuf, str::FromStr};
+
+use clap::Parser;
+use hifitime::{Duration, Epoch};
+use itertools::Itertools;
+use log::{trace, Level::Trace};
+use marlu::RADec;
+use serde::{Deserialize, Serialize};
+use strum::IntoEnumIterator;
+use strum_macros::{Display, EnumIter, EnumString};
+use vec1::Vec1;
+
+use super::HyperdriveError;
+use crate::{
+ averaging::{
+ parse_freq_average_factor, parse_time_average_factor, timesteps_to_timeblocks,
+ AverageFactorError,
+ },
+ beam::Beam,
+ constants::{
+ DEFAULT_CUTOFF_DISTANCE, DEFAULT_VETO_THRESHOLD, MWA_HEIGHT_M, MWA_LAT_DEG, MWA_LONG_DEG,
+ },
+ io::{
+ get_single_match_from_glob,
+ write::{can_write_to_file, VisOutputType, VIS_OUTPUT_EXTENSIONS},
+ },
+ model::ModelDevice,
+ params::{ModellingParams, OutputVisParams},
+ srclist::{
+ read::read_source_list_file, veto_sources, ComponentCounts, ReadSourceListError,
+ SourceList, SourceListType, SOURCE_LIST_TYPES_COMMA_SEPARATED,
+ },
+ MODEL_DEVICE,
+};
+
+lazy_static::lazy_static!
{ + pub(super) static ref ARG_FILE_TYPES_COMMA_SEPARATED: String = ArgFileTypes::iter().join(", "); + + pub(super) static ref ARG_FILE_HELP: String = + format!("All arguments may be specified in a file. Any CLI arguments override arguments set in the file. Supported formats: {}", *ARG_FILE_TYPES_COMMA_SEPARATED); + + pub(super) static ref ARRAY_POSITION_HELP: String = + format!("The Earth longitude, latitude, and height of the instrumental array [degrees, degrees, meters]. Default (MWA): ({MWA_LONG_DEG}°, {MWA_LAT_DEG}°, {MWA_HEIGHT_M}m)"); + + pub(super) static ref SOURCE_LIST_TYPE_HELP: String = + format!("The type of sky-model source list. Valid types are: {}. If not specified, all types are attempted", *SOURCE_LIST_TYPES_COMMA_SEPARATED); + + pub(super) static ref SOURCE_DIST_CUTOFF_HELP: String = + format!("Specifies the maximum distance from the phase centre a source can be [degrees]. Default: {DEFAULT_CUTOFF_DISTANCE}"); + + pub(super) static ref VETO_THRESHOLD_HELP: String = + format!("Specifies the minimum Stokes XX+YY a source must have before it gets vetoed [Jy]. Default: {DEFAULT_VETO_THRESHOLD}"); + + pub(super) static ref SOURCE_LIST_INPUT_TYPE_HELP: String = + format!("Specifies the type of the input source list. Currently supported types: {}", *SOURCE_LIST_TYPES_COMMA_SEPARATED); + + pub(super) static ref SOURCE_LIST_OUTPUT_TYPE_HELP: String = + format!("Specifies the type of the output source list. May be required depending on the output filename. Currently supported types: {}", + *SOURCE_LIST_TYPES_COMMA_SEPARATED); + + pub(super) static ref SRCLIST_BY_BEAM_OUTPUT_TYPE_HELP: String = + format!("Specifies the type of the output source list. If not specified, the input source list type is used. Currently supported types: {}", + *SOURCE_LIST_TYPES_COMMA_SEPARATED); +} + +#[derive(Debug, Display, EnumIter, EnumString)] +pub(super) enum ArgFileTypes { + #[strum(serialize = "toml")] + Toml, + #[strum(serialize = "json")] + Json, +} + +macro_rules! unpack_arg_file { + ($arg_file:expr) => ({ + use std::{fs::File, io::Read, str::FromStr}; + + use crate::cli::common::{ArgFileTypes, ARG_FILE_TYPES_COMMA_SEPARATED}; + + debug!("Attempting to parse argument file {}", $arg_file.display()); + + let mut contents = String::new(); + let arg_file_type = $arg_file + .extension() + .and_then(|e| e.to_str()) + .map(|e| e.to_lowercase()) + .and_then(|e| ArgFileTypes::from_str(&e).ok()); + + match arg_file_type { + Some(ArgFileTypes::Toml) => { + debug!("Parsing toml file..."); + let mut fh = File::open(&$arg_file)?; + fh.read_to_string(&mut contents)?; + match toml::from_str(&contents) { + Ok(p) => p, + Err(err) => { + return Err(HyperdriveError::ArgFile(format!( + "Couldn't decode toml structure from {:?}:\n{err}", + $arg_file + ))) + } + } + } + Some(ArgFileTypes::Json) => { + debug!("Parsing json file..."); + let mut fh = File::open(&$arg_file)?; + fh.read_to_string(&mut contents)?; + match serde_json::from_str(&contents) { + Ok(p) => p, + Err(err) => { + return Err(HyperdriveError::ArgFile(format!( + "Couldn't decode json structure from {:?}:\n{err}", + $arg_file + ))) + } + } + } + + _ => { + return Err(HyperdriveError::ArgFile(format!( + "Argument file '{:?}' doesn't have a recognised file extension! Valid extensions are: {}", $arg_file, *ARG_FILE_TYPES_COMMA_SEPARATED) + )) + } + } + }); +} + +/// Arguments to be parsed for visibility outputs. 
Unlike other "arg" structs, +/// this one is not parsed by `clap`; this is to allow the help texts for +/// `hyperdrive` subcommands to better details what the output visibilities +/// represent (e.g. di-calibrate outputs model visibilities, whereas +/// vis-subtract outputs subtracted visibilities; attempting to have one set of +/// help text for both of these vis outputs is not as clear as just having the +/// curated help text specified in each of di-calibrate and vis-subtract). +#[derive(Debug, Clone, Default)] +pub(super) struct OutputVisArgs { + pub(super) outputs: Option>, + pub(super) output_vis_time_average: Option, + pub(super) output_vis_freq_average: Option, +} + +impl OutputVisArgs { + pub(super) fn parse( + self, + input_vis_time_res: Duration, + input_vis_freq_res_hz: f64, + timestamps: &Vec1, + write_smallest_contiguous_band: bool, + default_output_filename: &str, + vis_description: Option<&str>, + ) -> Result { + let OutputVisArgs { + outputs, + output_vis_time_average, + output_vis_freq_average, + } = self; + + let (time_average_factor, freq_average_factor) = { + // Parse and verify user input (specified resolutions must + // evenly divide the input data's resolutions). + let time_factor = parse_time_average_factor( + Some(input_vis_time_res), + output_vis_time_average.as_deref(), + NonZeroUsize::new(1).unwrap(), + ) + .map_err(|e| match e { + AverageFactorError::Zero => HyperdriveError::Generic( + "The output visibility time average factor cannot be 0".to_string(), + ), + AverageFactorError::NotInteger => HyperdriveError::Generic( + "The output visibility time average factor isn't an integer".to_string(), + ), + AverageFactorError::NotIntegerMultiple { out, inp } => HyperdriveError::Generic(format!("The output visibility time resolution isn't a multiple of input data's: {out} seconds vs {inp} seconds")), + AverageFactorError::Parse(e) => HyperdriveError::Generic(format!("Error when parsing the output visibility time average factor: {e}")), + })?; + let freq_factor = + parse_freq_average_factor(Some(input_vis_freq_res_hz), output_vis_freq_average.as_deref(), NonZeroUsize::new(1).unwrap()) + .map_err(|e| match e { + AverageFactorError::Zero => { + HyperdriveError::Generic( + "The output visibility freq. average factor cannot be 0".to_string(), + ) } + AverageFactorError::NotInteger => { + HyperdriveError::Generic( + "The output visibility freq. average factor isn't an integer".to_string(), + ) } + AverageFactorError::NotIntegerMultiple { out, inp } => { + HyperdriveError::Generic(format!("The output visibility freq. resolution isn't a multiple of input data's: {out} seconds vs {inp} seconds")) + } + AverageFactorError::Parse(e) => { + HyperdriveError::Generic(format!("Error when parsing the output visibility freq. average factor: {e}")) + } + })?; + + (time_factor, freq_factor) + }; + + let mut vis_printer = if let Some(vis_description) = vis_description { + InfoPrinter::new(format!("Output {vis_description} vis info").into()) + } else { + InfoPrinter::new("Output vis info".into()) + }; + + let output_files = { + let outputs = outputs.unwrap_or_else(|| vec![PathBuf::from(default_output_filename)]); + let mut valid_outputs = Vec::with_capacity(outputs.len()); + for file in outputs { + // Is the output file type supported? 
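+ // The extension alone selects the writer, e.g. "cal.uvfits" for uvfits
+ // or "cal.ms" for a measurement set (assuming those are the extensions
+ // listed in VIS_OUTPUT_EXTENSIONS); anything unrecognised is rejected
+ // below, and can_write_to_file checks that the path is writable.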
+ let ext = file.extension().and_then(|os_str| os_str.to_str()); + match ext.and_then(|s| VisOutputType::from_str(s).ok()) { + Some(t) => { + can_write_to_file(&file)?; + valid_outputs.push((file, t)); + } + None => { + return Err(HyperdriveError::VisWrite(format!( + "An invalid output format was specified ({}). Supported:\n{}", + ext.unwrap_or(""), + *VIS_OUTPUT_EXTENSIONS + ))) + } + } + } + + Vec1::try_from_vec(valid_outputs).expect("cannot be empty") + }; + + let vis_str = output_files.iter().map(|(pb, _)| pb.display()).join(", "); + if let Some(vis_description) = vis_description { + vis_printer + .push_line(format!("Writing {vis_description} visibilities to: {vis_str}").into()); + } else { + vis_printer.push_line(format!("Writing visibilities to: {vis_str}").into()); + } + + let mut block = vec![]; + if time_average_factor.get() != 1 || freq_average_factor.get() != 1 { + block.push( + format!( + "Time averaging {}x ({}s)", + time_average_factor, + input_vis_time_res.to_seconds() * time_average_factor.get() as f64 + ) + .into(), + ); + + block.push( + format!( + "Freq. averaging {}x ({}kHz)", + freq_average_factor, + input_vis_freq_res_hz * freq_average_factor.get() as f64 / 1000.0 + ) + .into(), + ); + } + vis_printer.push_block(block); + if write_smallest_contiguous_band { + vis_printer.push_line("Writing the smallest possible contiguous band, ignoring any flagged fine channels at the edges of the SPW".into()); + } + vis_printer.display(); + + let timeblocks = + timesteps_to_timeblocks(timestamps, input_vis_time_res, time_average_factor, None); + + Ok(OutputVisParams { + output_files, + output_time_average_factor: time_average_factor, + output_freq_average_factor: freq_average_factor, + output_timeblocks: timeblocks, + write_smallest_contiguous_band, + }) + } +} + +#[derive(Parser, Debug, Clone, Default, Serialize, Deserialize)] +pub(super) struct SkyModelWithVetoArgs { + /// Path to the sky-model source list file. + #[clap(short, long, help_heading = "SKY MODEL")] + pub(super) source_list: Option, + + #[clap(long, help = SOURCE_LIST_TYPE_HELP.as_str(), help_heading = "SKY MODEL")] + pub(super) source_list_type: Option, + + /// The number of sources to use in the source list. The default is to use + /// them all. Example: If 1000 sources are specified here, then the top 1000 + /// sources are used (based on their flux densities after the beam + /// attenuation) within the specified source distance cutoff. 
+ #[clap(short, long, help_heading = "SKY MODEL")] + pub(super) num_sources: Option, + + #[clap(long, help = SOURCE_DIST_CUTOFF_HELP.as_str(), help_heading = "SKY MODEL")] + pub(super) source_dist_cutoff: Option, + + #[clap(long, help = VETO_THRESHOLD_HELP.as_str(), help_heading = "SKY MODEL")] + pub(super) veto_threshold: Option, +} + +impl SkyModelWithVetoArgs { + pub(super) fn merge(self, other: Self) -> Self { + Self { + source_list: self.source_list.or(other.source_list), + source_list_type: self.source_list_type.or(other.source_list_type), + num_sources: self.num_sources.or(other.num_sources), + source_dist_cutoff: self.source_dist_cutoff.or(other.source_dist_cutoff), + veto_threshold: self.veto_threshold.or(other.veto_threshold), + } + } + + pub(super) fn parse( + self, + phase_centre: RADec, + lst_rad: f64, + array_latitude_rad: f64, + veto_freqs_hz: &[f64], + beam: &dyn Beam, + ) -> Result { + let Self { + source_list, + source_list_type, + num_sources, + source_dist_cutoff, + veto_threshold, + } = self; + + let mut printer = InfoPrinter::new("Sky model info".into()); + + // Handle the source list argument. + let sl_pb: PathBuf = match source_list { + None => return Err(ReadSourceListError::NoSourceList), + Some(sl) => { + // If the specified source list file can't be found, treat + // it as a glob and expand it to find a match. + let pb = PathBuf::from(&sl); + if pb.exists() { + pb + } else { + get_single_match_from_glob(&sl)? + } + } + }; + + // Read the source list file. If the type was manually specified, + // use that, otherwise the reading code will try all available + // kinds. + let sl_type_not_specified = source_list_type.is_none(); + let sl_type = source_list_type.and_then(|t| SourceListType::from_str(t.as_ref()).ok()); + let (mut sl, sl_type) = read_source_list_file(sl_pb, sl_type)?; + + let ComponentCounts { + num_points, + num_gaussians, + num_shapelets, + .. + } = sl.get_counts(); + printer.push_block(vec![ + format!("Source list contains {} sources", sl.len()).into(), + format!("({} components, {num_points} points, {num_gaussians} Gaussians, {num_shapelets} shapelets)", num_points + num_gaussians + num_shapelets).into() + ]); + + // If the user didn't specify the source list type, then print out + // what we found. + if sl_type_not_specified { + trace!("Successfully parsed {}-style source list", sl_type); + } + + trace!("Found {} sources in the source list", sl.len()); + // Veto any sources that may be troublesome, and/or cap the total number + // of sources. If the user doesn't specify how many source-list sources + // to use, then all sources are used. + if num_sources == Some(0) || sl.is_empty() { + return Err(ReadSourceListError::NoSources); + } + veto_sources( + &mut sl, + phase_centre, + lst_rad, + array_latitude_rad, + veto_freqs_hz, + beam, + num_sources, + source_dist_cutoff.unwrap_or(DEFAULT_CUTOFF_DISTANCE), + veto_threshold.unwrap_or(DEFAULT_VETO_THRESHOLD), + )?; + if sl.is_empty() { + return Err(ReadSourceListError::NoSourcesAfterVeto); + } + + { + let ComponentCounts { + num_points, + num_gaussians, + num_shapelets, + .. 
+ } = sl.get_counts(); + let num_components = num_points + num_gaussians + num_shapelets; + printer.push_block(vec![ + format!( + "Using {} sources with a total of {num_components} components", + sl.len() + ) + .into(), + format!( + "{num_points} points, {num_gaussians} Gaussians, {num_shapelets} shapelets" + ) + .into(), + ]); + if num_components > 10000 { + "Using more than 10,000 sky model components!".warn(); + } + if log::log_enabled!(Trace) { + trace!("Using sources:"); + let mut v = Vec::with_capacity(5); + for source in sl.keys() { + if v.len() == 5 { + trace!(" {v:?}"); + v.clear(); + } + v.push(source); + } + if !v.is_empty() { + trace!(" {v:?}"); + } + } + } + + printer.display(); + + Ok(sl) + } +} + +#[derive(Parser, Debug, Clone, Copy, Default, Serialize, Deserialize)] +pub(super) struct ModellingArgs { + /// If specified, don't precess the array to J2000. We assume that sky-model + /// sources are specified in the J2000 epoch. + #[clap(long, help_heading = "MODELLING")] + #[serde(default)] + pub(super) no_precession: bool, + + /// Use the CPU for visibility generation. This is deliberately made + /// non-default because using a GPU is much faster. + #[cfg(feature = "cuda")] + #[clap(long, help_heading = "MODELLING")] + #[serde(default)] + pub(super) cpu: bool, +} + +impl ModellingArgs { + pub(super) fn merge(self, other: Self) -> Self { + Self { + no_precession: self.no_precession || other.no_precession, + #[cfg(feature = "cuda")] + cpu: self.cpu || other.cpu, + } + } + + pub(super) fn parse(self) -> ModellingParams { + let ModellingArgs { + no_precession, + #[cfg(feature = "cuda")] + cpu, + } = self; + + #[cfg(feature = "cuda")] + if cpu { + MODEL_DEVICE.store(ModelDevice::Cpu); + } + + let d = MODEL_DEVICE.load(); + let mut printer = InfoPrinter::new("Sky- and beam-modelling info".into()); + let mut block = vec![]; + match d { + ModelDevice::Cpu => { + block.push(format!("Using CPU with {} precision", d.get_precision()).into()); + block.push(crate::model::get_cpu_info().into()); + } + + #[cfg(feature = "cuda")] + ModelDevice::Cuda => { + block.push(format!("Using GPU with {} precision", d.get_precision()).into()); + let (device_info, driver_info) = match crate::cuda::get_device_info() { + Ok(i) => i, + Err(e) => { + // For some reason, despite hyperdrive being compiled + // with the "cuda" feature, we failed to get the device + // info. Maybe there's no CUDA-capable device present. + // Either way, we cannot continue. I'd rather not have + // error handling here because (1) without the "cuda" + // feature, this function will never fail on the CPU + // path, so adding error handling means the caller would + // have to handle a `Result` uselessly and (2) if this + // "petty" display function fails, then we can't use the + // GPU for real work anyway. + eprintln!("Couldn't retrieve CUDA device info for device 0, is a device present? 
{e}");
+ std::process::exit(1);
+ }
+ };
+ block.push(
+ format!(
+ "CUDA device: {} (capability {}, {} MiB)",
+ device_info.name, device_info.capability, device_info.total_global_mem
+ )
+ .into(),
+ );
+ block.push(
+ format!(
+ "CUDA driver: {}, runtime: {}",
+ driver_info.driver_version, driver_info.runtime_version
+ )
+ .into(),
+ );
+ }
+ }
+ printer.push_block(block);
+ printer.display();
+
+ ModellingParams {
+ apply_precession: !no_precession,
+ }
+ }
+}
diff --git a/src/cli/common/printers.rs b/src/cli/common/printers.rs
new file mode 100644
index 00000000..d2c88f70
--- /dev/null
+++ b/src/cli/common/printers.rs
@@ -0,0 +1,141 @@
+// This Source Code Form is subject to the terms of the Mozilla Public
+// License, v. 2.0. If a copy of the MPL was not distributed with this
+// file, You can obtain one at http://mozilla.org/MPL/2.0/.
+
+/// Pretty printers for reporting information.
+use std::{borrow::Cow, sync::Mutex};
+
+const VERTICAL: char = '│';
+const UP_AND_RIGHT: char = '└';
+const VERTICAL_AND_RIGHT: char = '├';
+
+lazy_static::lazy_static! {
+ static ref WARNING_PRINTER: Mutex<WarningPrinter> = Mutex::new(WarningPrinter::new());
+}
+
+pub(crate) struct InfoPrinter {
+ title: Cow<'static, str>,
+ blocks: Vec<Vec<Cow<'static, str>>>,
+}
+
+impl InfoPrinter {
+ pub(crate) fn new(title: Cow<'static, str>) -> Self {
+ Self {
+ title,
+ blocks: vec![],
+ }
+ }
+
+ pub(super) fn overwrite_title(&mut self, title: Cow<'static, str>) {
+ self.title = title;
+ }
+
+ pub(crate) fn push_line(&mut self, line: Cow<'static, str>) {
+ self.blocks.push(vec![line]);
+ }
+
+ pub(crate) fn push_block(&mut self, block: Vec<Cow<'static, str>>) {
+ self.blocks.push(block);
+ }
+
+ pub(crate) fn display(self) {
+ log::info!("{}", console::style(self.title).bold());
+ let num_blocks = self.blocks.len();
+ for (i_block, block) in self.blocks.into_iter().enumerate() {
+ let num_lines = block.len();
+ for (i_line, line) in block.into_iter().enumerate() {
+ let symbol = match (i_line, i_line + 1 == num_lines, i_block + 1 == num_blocks) {
+ (0, false, _) => VERTICAL_AND_RIGHT,
+ (0, _, false) => VERTICAL_AND_RIGHT,
+ (0, true, true) => UP_AND_RIGHT,
+ _ => VERTICAL,
+ };
+ log::info!("{symbol} {line}");
+ }
+ }
+ log::info!("");
+ }
+}
+
+struct WarningPrinter {
+ blocks: Vec<Vec<Cow<'static, str>>>,
+}
+
+impl WarningPrinter {
+ fn new() -> Self {
+ Self { blocks: vec![] }
+ }
+
+ fn push_line(&mut self, line: Cow<'static, str>) {
+ self.blocks.push(vec![line]);
+ }
+
+ fn push_block(&mut self, block: Vec<Cow<'static, str>>) {
+ self.blocks.push(block);
+ }
+
+ fn display(&mut self) {
+ log::debug!("Displaying warnings");
+ if self.blocks.is_empty() {
+ return;
+ }
+
+ log::warn!("{}", console::style("Warnings").bold());
+ let num_blocks = self.blocks.len();
+ for (i_block, block) in self.blocks.iter().enumerate() {
+ let num_lines = block.len();
+ for (i_line, line) in block.iter().enumerate() {
+ let symbol = match (i_line, i_line + 1 == num_lines, i_block + 1 == num_blocks) {
+ (0, false, _) => VERTICAL_AND_RIGHT,
+ (0, _, false) => VERTICAL_AND_RIGHT,
+ (0, true, true) => UP_AND_RIGHT,
+ _ => VERTICAL,
+ };
+ log::warn!("{symbol} {line}");
+ }
+ }
+ log::warn!("");
+ self.blocks.clear();
+ }
+}
+
+pub(crate) trait Warn {
+ fn warn(self);
+}
+
+impl Warn for &'static str {
+ fn warn(self) {
+ WARNING_PRINTER.lock().unwrap().push_line(self.into());
+ }
+}
+
+impl Warn for String {
+ fn warn(self) {
+ WARNING_PRINTER.lock().unwrap().push_line(self.into());
+ }
+}
+
+impl Warn for Cow<'static, str> {
+ fn warn(self) {
+ WARNING_PRINTER.lock().unwrap().push_line(self);
+ }
+}
+
+impl Warn for Vec<Cow<'static, str>>
{ + fn warn(self) { + WARNING_PRINTER.lock().unwrap().push_block(self); + } +} + +impl Warn for [Cow<'static, str>; N] { + fn warn(self) { + WARNING_PRINTER.lock().unwrap().push_block(self.to_vec()); + } +} + +/// Print out any warnings that have been collected as CLI arguments have been +/// parsed. This should only be called once before all arguments have been +/// parsed into parameters. +pub(crate) fn display_warnings() { + WARNING_PRINTER.lock().unwrap().display(); +} diff --git a/src/cli/common/tests.rs b/src/cli/common/tests.rs new file mode 100644 index 00000000..b942f606 --- /dev/null +++ b/src/cli/common/tests.rs @@ -0,0 +1,84 @@ +// This Source Code Form is subject to the terms of the Mozilla Public +// License, v. 2.0. If a copy of the MPL was not distributed with this +// file, You can obtain one at http://mozilla.org/MPL/2.0/. + +//! Tests against command-line interfaces that aren't big enough to go in their +//! own modules. + +use marlu::{constants::MWA_LAT_RAD, RADec}; + +use crate::srclist::ReadSourceListError; + +use super::{BeamArgs, SkyModelWithVetoArgs}; + +#[test] +fn all_sources_vetoed_causes_error() { + let beam = BeamArgs { + no_beam: true, + ..Default::default() + } + .parse(128, None, None, None) + .expect("no problems setting up a NoBeam"); + + let source_list = Some( + "test_files/1090008640/srclist_pumav3_EoR0aegean_EoR1pietro+ForA_1090008640_peel100.txt" + .to_string(), + ); + + // First, verify that vetoing all sources < 100 Jy leaves nothing behind. + let result = SkyModelWithVetoArgs { + source_list: source_list.clone(), + veto_threshold: Some(100.0), + ..Default::default() + } + .parse( + RADec::from_degrees(0.0, -30.0), + 0.0, + MWA_LAT_RAD, + &[150e6], + &*beam, + ); + assert!(result.is_err()); + assert!(matches!( + result.unwrap_err(), + ReadSourceListError::NoSourcesAfterVeto + )); + + // Deliberately set the number of sources to 0. + let result = SkyModelWithVetoArgs { + source_list: source_list.clone(), + num_sources: Some(0), + ..Default::default() + } + .parse( + RADec::from_degrees(0.0, -30.0), + 0.0, + MWA_LAT_RAD, + &[150e6], + &*beam, + ); + assert!(result.is_err()); + assert!(matches!( + result.unwrap_err(), + ReadSourceListError::NoSources + )); + + // Set the source dist cutoff to something not useful. + let result = SkyModelWithVetoArgs { + source_list: source_list.clone(), + source_dist_cutoff: Some(0.01), + ..Default::default() + } + .parse( + RADec::from_degrees(0.0, -30.0), + 0.0, + MWA_LAT_RAD, + &[150e6], + &*beam, + ); + assert!(result.is_err()); + assert!(matches!( + result.unwrap_err(), + ReadSourceListError::NoSourcesAfterVeto + )); +} diff --git a/src/cli/di_calibrate/mod.rs b/src/cli/di_calibrate/mod.rs index 2514b924..e9095e11 100644 --- a/src/cli/di_calibrate/mod.rs +++ b/src/cli/di_calibrate/mod.rs @@ -2,76 +2,57 @@ // License, v. 2.0. If a copy of the MPL was not distributed with this // file, You can obtain one at http://mozilla.org/MPL/2.0/. -// This Source Code Form is subject to the terms of the Mozilla Public -// License, v. 2.0. If a copy of the MPL was not distributed with this -// file, You can obtain one at http://mozilla.org/MPL/2.0/. +//! Parse calibration arguments into parameters. -//! Handling of calibration arguments. -//! -//! Strategy: Users give arguments to hyperdrive (handled by [calibrate::args]). -//! hyperdrive turns arguments into parameters (handled by [calibrate::params]). -//! Using this paradigm, the code to handle arguments and parameters (and -//! associated errors) can be neatly split. 
- -mod error; -mod params; #[cfg(test)] -pub(crate) mod tests; - -pub(crate) use error::DiCalArgsError; -pub(crate) use params::DiCalParams; +mod tests; -use std::{fs::File, io::Read, path::PathBuf, str::FromStr}; +use std::{num::NonZeroUsize, path::PathBuf, str::FromStr}; use clap::Parser; use itertools::Itertools; -use log::{debug, info, trace}; +use log::{debug, info, log_enabled, trace, Level::Debug}; +use marlu::{ + pos::{precession::precess_time, xyz::xyzs_to_cross_uvws}, + LatLngHeight, XyzGeodetic, +}; +use rayon::prelude::*; use serde::{Deserialize, Serialize}; -use strum::IntoEnumIterator; -use strum_macros::{Display, EnumIter, EnumString}; +use vec1::{vec1, Vec1}; +use super::common::{ + display_warnings, BeamArgs, InfoPrinter, InputVisArgs, ModellingArgs, OutputVisArgs, + SkyModelWithVetoArgs, Warn, ARG_FILE_HELP, +}; use crate::{ - di_calibrate::DiCalibrateError, - help_texts::*, - io::write::VIS_OUTPUT_EXTENSIONS, + averaging::{parse_time_average_factor, timesteps_to_timeblocks, AverageFactorError}, + io::write::{can_write_to_file, VIS_OUTPUT_EXTENSIONS}, + params::{DiCalParams, ModellingParams}, solutions::{self, CalSolutionType, CalibrationSolutions, CAL_SOLUTION_EXTENSIONS}, - unit_parsing::WAVELENGTH_FORMATS, + unit_parsing::{parse_wavelength, WavelengthUnit, WAVELENGTH_FORMATS}, HyperdriveError, }; -#[derive(Debug, Display, EnumIter, EnumString)] -enum ArgFileTypes { - #[strum(serialize = "toml")] - Toml, - #[strum(serialize = "json")] - Json, -} - // The default minimum baseline cutoff. -pub(crate) const DEFAULT_UVW_MIN: &str = "50λ"; +const DEFAULT_UVW_MIN: &str = "50λ"; -/// The maximum number of times to iterate when performing "MitchCal" in +/// The maximum number of times to iterate when performing calibration in /// direction-independent calibration. -pub(crate) const DEFAULT_MAX_ITERATIONS: u32 = 50; +const DEFAULT_MAX_ITERATIONS: u32 = 50; -/// The threshold to satisfy convergence when performing "MitchCal" in +/// The threshold to satisfy convergence when performing calibration in /// direction-independent calibration. -pub(crate) const DEFAULT_STOP_THRESHOLD: f64 = 1e-8; +const DEFAULT_STOP_THRESHOLD: f64 = 1e-8; -/// The minimum threshold to satisfy convergence when performing "MitchCal" in +/// The minimum threshold to satisfy convergence when performing calibration in /// direction-independent calibration. Reaching this threshold counts as /// "converged", but it's not as good as the stop threshold. -pub(crate) const DEFAULT_MIN_THRESHOLD: f64 = 1e-4; +const DEFAULT_MIN_THRESHOLD: f64 = 1e-4; -pub(crate) const DEFAULT_OUTPUT_SOLUTIONS_FILENAME: &str = "hyperdrive_solutions.fits"; +const DEFAULT_OUTPUT_SOLUTIONS_FILENAME: &str = "hyperdrive_solutions.fits"; lazy_static::lazy_static! { - static ref ARG_FILE_TYPES_COMMA_SEPARATED: String = ArgFileTypes::iter().join(", "); - - static ref ARG_FILE_HELP: String = - format!("All of the arguments to di-calibrate may be specified in a file. Any CLI arguments override parameters set in the file. Supported formats: {}", *ARG_FILE_TYPES_COMMA_SEPARATED); - - static ref OUTPUTS_HELP: String = + static ref DI_SOLS_OUTPUTS_HELP: String = format!("Paths to the output calibration solution files. Supported formats: {}. Default: {}", *CAL_SOLUTION_EXTENSIONS, DEFAULT_OUTPUT_SOLUTIONS_FILENAME); static ref MODEL_FILENAME_HELP: String = @@ -84,44 +65,46 @@ lazy_static::lazy_static! { format!("The maximum UVW length to use. This value must have a unit annotated. Allowed units: {}. 
No default.", *WAVELENGTH_FORMATS); static ref MAX_ITERATIONS_HELP: String = - format!("The maximum number of times to iterate when performing \"MitchCal\". Default: {DEFAULT_MAX_ITERATIONS}"); + format!("The maximum number of times to iterate during calibration. Default: {DEFAULT_MAX_ITERATIONS}"); static ref STOP_THRESHOLD_HELP: String = - format!("The threshold at which we stop iterating when performing \"MitchCal\". Default: {DEFAULT_STOP_THRESHOLD:e}"); + format!("The threshold at which we stop iterating during calibration. Default: {DEFAULT_STOP_THRESHOLD:e}"); static ref MIN_THRESHOLD_HELP: String = - format!("The minimum threshold to satisfy convergence when performing \"MitchCal\". Even when this threshold is exceeded, iteration will continue until max iterations or the stop threshold is reached. Default: {DEFAULT_MIN_THRESHOLD:e}"); + format!("The minimum threshold to satisfy convergence during calibration. Even when this threshold is exceeded, iteration will continue until max iterations or the stop threshold is reached. Default: {DEFAULT_MIN_THRESHOLD:e}"); } #[derive(Parser, Debug, Clone, Default, Serialize, Deserialize)] -pub struct DiCalArgs { - #[clap(name = "ARGUMENTS_FILE", help = ARG_FILE_HELP.as_str(), parse(from_os_str))] - pub args_file: Option, +struct DiCalCliArgs { + #[clap(short='o', long="outputs", multiple_values(true), help = DI_SOLS_OUTPUTS_HELP.as_str(), help_heading = "OUTPUT FILES")] + solutions: Option>, - /// Paths to input data files to be calibrated. These can include a metafits - /// file, gpubox files, mwaf files, a measurement set and/or uvfits files. - #[clap(short, long, multiple_values(true), help_heading = "INPUT FILES")] - pub data: Option>, + /// The number of timesteps to average together during calibration. Also + /// supports a target time resolution (e.g. 8s). If this is 0, then all data + /// are averaged together. Default: 0. e.g. If this variable is 4, then we + /// produce calibration solutions in timeblocks with up to 4 timesteps each. + /// If the variable is instead 4s, then each timeblock contains up to 4s + /// worth of data. + #[clap(short, long, help_heading = "CALIBRATION")] + timesteps_per_timeblock: Option, - /// Path to the sky-model source list file. - #[clap(short, long, help_heading = "INPUT FILES")] - pub source_list: Option, + #[clap(long, help = UVW_MIN_HELP.as_str(), help_heading = "CALIBRATION")] + uvw_min: Option, - #[clap(long, help = SOURCE_LIST_TYPE_HELP.as_str(), help_heading = "INPUT FILES")] - pub source_list_type: Option, + #[clap(long, help = UVW_MAX_HELP.as_str(), help_heading = "CALIBRATION")] + uvw_max: Option, - #[clap(long, help = MS_DATA_COL_NAME_HELP, help_heading = "INPUT FILES")] - pub ms_data_column_name: Option, + #[clap(long, help = MAX_ITERATIONS_HELP.as_str(), help_heading = "CALIBRATION")] + max_iterations: Option, - /// Use a DUT1 value of 0 seconds rather than what is in the input data. 
- #[clap(long, help_heading = "INPUT FILES")] - pub ignore_dut1: bool, + #[clap(long, help = STOP_THRESHOLD_HELP.as_str(), help_heading = "CALIBRATION")] + stop_threshold: Option, - #[clap(short, long, multiple_values(true), help = OUTPUTS_HELP.as_str(), help_heading = "OUTPUT FILES")] - pub outputs: Option>, + #[clap(long, help = MIN_THRESHOLD_HELP.as_str(), help_heading = "CALIBRATION")] + min_threshold: Option, - #[clap(short, long, multiple_values(true), help = MODEL_FILENAME_HELP.as_str(), help_heading = "OUTPUT FILES")] - pub model_filenames: Option>, + #[clap(long, multiple_values(true), help = MODEL_FILENAME_HELP.as_str(), help_heading = "OUTPUT FILES")] + model_filenames: Option>, /// When writing out model visibilities, average this many timesteps /// together. Also supports a target time resolution (e.g. 8s). The value @@ -132,7 +115,7 @@ pub struct DiCalArgs { /// instead 4s, then 8 model timesteps are averaged together before writing /// the data out. #[clap(long, help_heading = "OUTPUT FILES")] - pub output_model_time_average: Option, + output_model_time_average: Option, /// When writing out model visibilities, average this many fine freq. /// channels together. Also supports a target freq. resolution (e.g. 80kHz). @@ -144,151 +127,49 @@ pub struct DiCalArgs { /// data out. If the variable is instead 80kHz, then 4 model fine freq. /// channels are averaged together before writing the data out. #[clap(long, help_heading = "OUTPUT FILES")] - pub output_model_freq_average: Option, - - /// The number of sources to use in the source list. The default is to use - /// them all. Example: If 1000 sources are specified here, then the top 1000 - /// sources are used (based on their flux densities after the beam - /// attenuation) within the specified source distance cutoff. - #[clap(short, long, help_heading = "SKY-MODEL SOURCES")] - pub num_sources: Option, - - #[clap(long, help = SOURCE_DIST_CUTOFF_HELP.as_str(), help_heading = "SKY-MODEL SOURCES")] - pub source_dist_cutoff: Option, - - #[clap(long, help = VETO_THRESHOLD_HELP.as_str(), help_heading = "SKY-MODEL SOURCES")] - pub veto_threshold: Option, - - /// The path to the HDF5 MWA FEE beam file. If not specified, this must be - /// provided by the MWA_BEAM_FILE environment variable. - #[clap(long, help_heading = "BEAM")] - pub beam_file: Option, - - /// Pretend that all MWA dipoles are alive and well, ignoring whatever is in - /// the metafits file. - #[clap(long, help_heading = "BEAM")] - pub unity_dipole_gains: bool, - - #[clap(long, multiple_values(true), help = DIPOLE_DELAYS_HELP.as_str(), help_heading = "BEAM")] - pub delays: Option>, - - /// Don't apply a beam response when generating a sky model. The default is - /// to use the FEE beam. - #[clap(long, help_heading = "BEAM")] - pub no_beam: bool, - - /// The number of timesteps to average together during calibration. Also - /// supports a target time resolution (e.g. 8s). If this is 0, then all data - /// are averaged together. Default: 0. e.g. If this variable is 4, then we - /// produce calibration solutions in timeblocks with up to 4 timesteps each. - /// If the variable is instead 4s, then each timeblock contains up to 4s - /// worth of data. - #[clap(short, long, help_heading = "CALIBRATION")] - pub timesteps_per_timeblock: Option, - - /// The number of fine-frequency channels to average together before - /// calibration. If this is 0, then all data is averaged together. Default: - /// 1. e.g. 
If the input data is in 20kHz resolution and this variable was - /// 2, then we average 40kHz worth of data into a chanblock before - /// calibration. If the variable is instead 40kHz, then each chanblock - /// contains up to 40kHz worth of data. - #[clap(short, long, help_heading = "CALIBRATION")] - pub freq_average_factor: Option, - - /// The timesteps to use from the input data. The timesteps will be - /// ascendingly sorted for calibration. No duplicates are allowed. The - /// default is to use all unflagged timesteps. - #[clap(long, multiple_values(true), help_heading = "CALIBRATION")] - pub timesteps: Option>, - - /// Use all timesteps in the data, including flagged ones. The default is to - /// use all unflagged timesteps. - #[clap(long, conflicts_with("timesteps"), help_heading = "CALIBRATION")] - pub use_all_timesteps: bool, - - #[clap(long, help = UVW_MIN_HELP.as_str(), help_heading = "CALIBRATION")] - pub uvw_min: Option, - - #[clap(long, help = UVW_MAX_HELP.as_str(), help_heading = "CALIBRATION")] - pub uvw_max: Option, - - #[clap(long, help = MAX_ITERATIONS_HELP.as_str(), help_heading = "CALIBRATION")] - pub max_iterations: Option, - - #[clap(long, help = STOP_THRESHOLD_HELP.as_str(), help_heading = "CALIBRATION")] - pub stop_thresh: Option, - - #[clap(long, help = MIN_THRESHOLD_HELP.as_str(), help_heading = "CALIBRATION")] - pub min_thresh: Option, + output_model_freq_average: Option, + + /// When writing out model visibilities, rather than writing out the entire + /// input bandwidth, write out only the smallest contiguous band. e.g. + /// Typical 40 kHz MWA data has 768 channels, but the first 2 and last 2 + /// channels are usually flagged. Turning this option on means that 764 + /// channels would be written out instead of 768. Note that other flagged + /// channels in the band are unaffected, because the data written out must + /// be contiguous. + #[clap(long, help_heading = "OUTPUT FILES")] + #[serde(default)] + output_smallest_contiguous_band: bool, +} - #[clap( - long, help = ARRAY_POSITION_HELP.as_str(), help_heading = "CALIBRATION", - number_of_values = 3, - allow_hyphen_values = true, - value_names = &["LONG_DEG", "LAT_DEG", "HEIGHT_M"] - )] - pub array_position: Option>, - - /// If specified, don't precess the array to J2000. We assume that sky-model - /// sources are specified in the J2000 epoch. - #[clap(long, help_heading = "CALIBRATION")] - pub no_precession: bool, - - #[cfg(feature = "cuda")] - /// Use the CPU for visibility generation. This is deliberately made - /// non-default because using a GPU is much faster. - #[clap(long, help_heading = "CALIBRATION")] - pub cpu: bool, - - /// Additional tiles to be flagged. These values correspond to either the - /// values in the "Antenna" column of HDU 2 in the metafits file (e.g. 0 3 - /// 127), or the "TileName" (e.g. Tile011). - #[clap(long, multiple_values(true), help_heading = "FLAGGING")] - pub tile_flags: Option>, - - /// If specified, pretend that all tiles are unflagged in the input data. - #[clap(long, help_heading = "FLAGGING")] - pub ignore_input_data_tile_flags: bool, - - /// If specified, pretend all fine channels in the input data are unflagged. - #[clap(long, help_heading = "FLAGGING")] - pub ignore_input_data_fine_channel_flags: bool, - - /// The fine channels to be flagged in each coarse channel. e.g. 0 1 16 30 - /// 31 are typical for 40 kHz data. 
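For readers skimming the new output_smallest_contiguous_band option documented above, here is a self-contained sketch of the idea: keep everything from the first to the last unflagged channel, inclusive; interior flagged channels stay, because the output must remain contiguous. The helper name and signature are illustrative only, not the crate's actual implementation.

// Hypothetical helper illustrating the "smallest contiguous band" idea.
fn smallest_contiguous_band(
    num_chans: usize,
    flagged: &[usize],
) -> Option<std::ops::RangeInclusive<usize>> {
    let is_flagged = |c: &usize| flagged.contains(c);
    let first = (0..num_chans).find(|c| !is_flagged(c))?;
    let last = (0..num_chans).rev().find(|c| !is_flagged(c))?;
    Some(first..=last)
}

fn main() {
    // 768 channels with the first two and last two flagged -> 764 remain.
    let band = smallest_contiguous_band(768, &[0, 1, 766, 767]).unwrap();
    assert_eq!(band, 2..=765);
    assert_eq!(band.end() - band.start() + 1, 764);
}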
If this is not specified, it defaults to - /// flagging 80 kHz (or as close to this as possible) at the edges, as well - /// as the centre channel for non-MWAX data. - #[clap(long, multiple_values(true), help_heading = "FLAGGING")] - pub fine_chan_flags_per_coarse_chan: Option>, - - /// The fine channels to be flagged across the whole observation band. e.g. - /// 0 767 are the first and last fine channels for 40 kHz data. - #[clap(long, multiple_values(true), help_heading = "FLAGGING")] - pub fine_chan_flags: Option>, - - #[clap(long, help = PFB_FLAVOUR_HELP.as_str(), help_heading = "RAW MWA DATA")] - pub pfb_flavour: Option, - - /// When reading in raw MWA data, don't apply digital gains. - #[clap(long, help_heading = "RAW MWA DATA")] - pub no_digital_gains: bool, - - /// When reading in raw MWA data, don't apply cable length corrections. Note - /// that some data may have already had the correction applied before it was - /// written. - #[clap(long, help_heading = "RAW MWA DATA")] - pub no_cable_length_correction: bool, - - /// When reading in raw MWA data, don't apply geometric corrections. Note - /// that some data may have already had the correction applied before it was - /// written. - #[clap(long, help_heading = "RAW MWA DATA")] - pub no_geometric_correction: bool, - - /// When reading in visibilities and generating sky-model visibilities, - /// don't draw progress bars. - #[clap(long, help_heading = "USER INTERFACE")] - pub no_progress_bars: bool, +#[derive(Parser, Debug, Clone, Default, Serialize, Deserialize)] +pub(super) struct DiCalArgs { + #[clap(name = "ARGUMENTS_FILE", help = ARG_FILE_HELP.as_str(), parse(from_os_str))] + args_file: Option, + + #[clap(flatten)] + #[serde(rename = "data")] + #[serde(default)] + data_args: InputVisArgs, + + #[clap(flatten)] + #[serde(rename = "sky-model")] + #[serde(default)] + srclist_args: SkyModelWithVetoArgs, + + #[clap(flatten)] + #[serde(rename = "model")] + #[serde(default)] + model_args: ModellingArgs, + + #[clap(flatten)] + #[serde(rename = "beam")] + #[serde(default)] + beam_args: BeamArgs, + + #[clap(flatten)] + #[serde(rename = "di-calibration")] + #[serde(default)] + calibration_args: DiCalCliArgs, } impl DiCalArgs { @@ -301,206 +182,457 @@ impl DiCalArgs { /// /// This function should only ever merge arguments, and not try to make /// sense of them. - pub(crate) fn merge(self) -> Result { + pub(super) fn merge(self) -> Result { + debug!("Merging command-line arguments with the argument file"); + let cli_args = self; if let Some(arg_file) = cli_args.args_file { - // Read in the file arguments. 
- let file_args: DiCalArgs = { - debug!( - "Attempting to parse argument file {} ...", - arg_file.display() - ); - - let mut contents = String::new(); - let file_args_extension = arg_file - .extension() - .and_then(|e| e.to_str()) - .map(|e| e.to_lowercase()) - .and_then(|e| ArgFileTypes::from_str(&e).ok()); - match file_args_extension { - Some(ArgFileTypes::Toml) => { - debug!("Parsing toml file..."); - let mut fh = File::open(&arg_file)?; - fh.read_to_string(&mut contents)?; - match toml::from_str(&contents) { - Ok(p) => p, - Err(e) => { - return Err(DiCalArgsError::TomlDecode { - file: arg_file.display().to_string(), - err: e.to_string(), - }) - } - } - } - Some(ArgFileTypes::Json) => { - debug!("Parsing json file..."); - let mut fh = File::open(&arg_file)?; - fh.read_to_string(&mut contents)?; - match serde_json::from_str(&contents) { - Ok(p) => p, - Err(e) => { - return Err(DiCalArgsError::JsonDecode { - file: arg_file.display().to_string(), - err: e.to_string(), - }) - } - } - } - - _ => { - return Err(DiCalArgsError::UnrecognisedArgFileExt( - arg_file.display().to_string(), - )) - } - } - }; - - // Ensure all of the file args are accounted for by pattern - // matching. + // Read in the file arguments. Ensure all of the file args are + // accounted for by pattern matching. let DiCalArgs { args_file: _, - data, - source_list, - source_list_type, - ms_data_column_name, - ignore_dut1, - outputs, - model_filenames, - output_model_time_average, - output_model_freq_average, - num_sources, - source_dist_cutoff, - veto_threshold, - beam_file, - unity_dipole_gains, - delays, - no_beam, - timesteps_per_timeblock, - freq_average_factor, - timesteps, - use_all_timesteps, - uvw_min, - uvw_max, - max_iterations, - stop_thresh, - min_thresh, - array_position, - no_precession, - #[cfg(feature = "cuda")] - cpu, - tile_flags, - ignore_input_data_tile_flags, - ignore_input_data_fine_channel_flags, - fine_chan_flags_per_coarse_chan, - fine_chan_flags, - pfb_flavour, - no_digital_gains, - no_cable_length_correction, - no_geometric_correction, - no_progress_bars, - } = file_args; + data_args, + srclist_args, + model_args, + beam_args, + calibration_args, + } = unpack_arg_file!(arg_file); + // Merge all the arguments, preferring the CLI args when available. 
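The unpack_arg_file! call above reads and deserialises the argument file. Going by the #[serde(rename = ...)] attributes on DiCalArgs, such a file presumably groups settings into sections named data, sky-model, model, beam and di-calibration. Below is a hypothetical sketch of that layout for two of the sections (the others would follow the same pattern); the [data] keys are stand-ins because InputVisArgs' fields are not part of this diff, and the sketch assumes the serde (with derive) and toml crates.

// Hypothetical layout only; the real structs are InputVisArgs, DiCalCliArgs, etc.
use serde::Deserialize;

#[derive(Debug, Default, Deserialize)]
#[serde(default)]
struct ExampleArgFile {
    #[serde(rename = "data")]
    data: ExampleData,
    #[serde(rename = "di-calibration")]
    di_calibration: ExampleDiCal,
}

#[derive(Debug, Default, Deserialize)]
#[serde(default)]
struct ExampleData {
    // Stand-in field; the real input-data keys are not shown in this diff.
    files: Vec<String>,
}

#[derive(Debug, Default, Deserialize)]
#[serde(default)]
struct ExampleDiCal {
    uvw_min: Option<String>,
    max_iterations: Option<u32>,
}

fn main() -> Result<(), toml::de::Error> {
    let text = r#"
        [data]
        files = ["obs.metafits", "obs.ms"]

        [di-calibration]
        uvw_min = "50λ"
        max_iterations = 50
    "#;
    let parsed: ExampleArgFile = toml::from_str(text)?;
    println!("{parsed:#?}");
    Ok(())
}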
Ok(DiCalArgs { args_file: None, - data: cli_args.data.or(data), - source_list: cli_args.source_list.or(source_list), - source_list_type: cli_args.source_list_type.or(source_list_type), - ms_data_column_name: cli_args.ms_data_column_name.or(ms_data_column_name), - ignore_dut1: cli_args.ignore_dut1 || ignore_dut1, - outputs: cli_args.outputs.or(outputs), - model_filenames: cli_args.model_filenames.or(model_filenames), - output_model_time_average: cli_args - .output_model_time_average - .or(output_model_time_average), - output_model_freq_average: cli_args - .output_model_freq_average - .or(output_model_freq_average), - num_sources: cli_args.num_sources.or(num_sources), - source_dist_cutoff: cli_args.source_dist_cutoff.or(source_dist_cutoff), - veto_threshold: cli_args.veto_threshold.or(veto_threshold), - beam_file: cli_args.beam_file.or(beam_file), - unity_dipole_gains: cli_args.unity_dipole_gains || unity_dipole_gains, - delays: cli_args.delays.or(delays), - no_beam: cli_args.no_beam || no_beam, - timesteps_per_timeblock: cli_args - .timesteps_per_timeblock - .or(timesteps_per_timeblock), - freq_average_factor: cli_args.freq_average_factor.or(freq_average_factor), - timesteps: cli_args.timesteps.or(timesteps), - use_all_timesteps: cli_args.use_all_timesteps || use_all_timesteps, - uvw_min: cli_args.uvw_min.or(uvw_min), - uvw_max: cli_args.uvw_max.or(uvw_max), - max_iterations: cli_args.max_iterations.or(max_iterations), - stop_thresh: cli_args.stop_thresh.or(stop_thresh), - min_thresh: cli_args.min_thresh.or(min_thresh), - array_position: cli_args.array_position.or(array_position), - no_precession: cli_args.no_precession || no_precession, - #[cfg(feature = "cuda")] - cpu: cli_args.cpu || cpu, - tile_flags: cli_args.tile_flags.or(tile_flags), - ignore_input_data_tile_flags: cli_args.ignore_input_data_tile_flags - || ignore_input_data_tile_flags, - ignore_input_data_fine_channel_flags: cli_args.ignore_input_data_fine_channel_flags - || ignore_input_data_fine_channel_flags, - fine_chan_flags_per_coarse_chan: cli_args - .fine_chan_flags_per_coarse_chan - .or(fine_chan_flags_per_coarse_chan), - fine_chan_flags: cli_args.fine_chan_flags.or(fine_chan_flags), - pfb_flavour: cli_args.pfb_flavour.or(pfb_flavour), - no_digital_gains: cli_args.no_digital_gains || no_digital_gains, - no_cable_length_correction: cli_args.no_cable_length_correction - || no_cable_length_correction, - no_geometric_correction: cli_args.no_geometric_correction - || no_geometric_correction, - no_progress_bars: cli_args.no_progress_bars || no_progress_bars, + data_args: cli_args.data_args.merge(data_args), + srclist_args: cli_args.srclist_args.merge(srclist_args), + model_args: cli_args.model_args.merge(model_args), + beam_args: cli_args.beam_args.merge(beam_args), + calibration_args: cli_args.calibration_args.merge(calibration_args), }) } else { Ok(cli_args) } } - pub(crate) fn into_params(self) -> Result { - DiCalParams::new(self) - } + /// Parse the arguments into parameters ready for calibration. 
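To make the precedence concrete, here is a minimal standalone sketch of the "CLI overrides argument file" rule used in the merge above, mirroring the per-field pattern of DiCalCliArgs::merge further down in this diff. The struct and its fields are illustrative stand-ins for the real argument types.

#[derive(Debug, Default)]
struct CliArgs {
    uvw_min: Option<String>,
    max_iterations: Option<u32>,
    output_smallest_contiguous_band: bool,
}

impl CliArgs {
    // `Option`s prefer `self` (the command line); booleans are OR-ed, so a
    // flag set in either place stays set.
    fn merge(self, other: Self) -> Self {
        Self {
            uvw_min: self.uvw_min.or(other.uvw_min),
            max_iterations: self.max_iterations.or(other.max_iterations),
            output_smallest_contiguous_band: self.output_smallest_contiguous_band
                || other.output_smallest_contiguous_band,
        }
    }
}

fn main() {
    let cli = CliArgs {
        uvw_min: Some("30λ".to_string()),
        ..Default::default()
    };
    let file = CliArgs {
        uvw_min: Some("50λ".to_string()),
        max_iterations: Some(100),
        ..Default::default()
    };
    let merged = cli.merge(file);
    assert_eq!(merged.uvw_min.as_deref(), Some("30λ")); // CLI wins
    assert_eq!(merged.max_iterations, Some(100)); // falls back to the file
}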
+ fn parse(self) -> Result { + debug!("{:#?}", self); + + let DiCalArgs { + args_file: _, + data_args, + srclist_args, + model_args, + beam_args, + calibration_args, + } = self; + + let input_vis_params = data_args.parse("DI calibrating")?; + let obs_context = input_vis_params.get_obs_context(); + let total_num_tiles = input_vis_params.get_total_num_tiles(); + + let beam = beam_args.parse( + total_num_tiles, + obs_context.dipole_delays.clone(), + obs_context.dipole_gains.clone(), + Some(obs_context.input_data_type), + )?; + let modelling_params @ ModellingParams { apply_precession } = model_args.parse(); + + let DiCalCliArgs { + timesteps_per_timeblock, + uvw_min, + uvw_max, + max_iterations, + stop_threshold, + min_threshold, + solutions, + model_filenames, + output_model_time_average, + output_model_freq_average, + output_smallest_contiguous_band, + } = calibration_args; + + let LatLngHeight { + longitude_rad, + latitude_rad, + height_metres: _, + } = obs_context.array_position; + let precession_info = precess_time( + longitude_rad, + latitude_rad, + obs_context.phase_centre, + // obs_context.timestamps[*timesteps_to_use.first()], + input_vis_params.timeblocks.first().median, + input_vis_params.dut1, + ); + let (lst_rad, latitude_rad) = if apply_precession { + ( + precession_info.lmst_j2000, + precession_info.array_latitude_j2000, + ) + } else { + (precession_info.lmst, latitude_rad) + }; - pub fn run(self, dry_run: bool) -> Result, HyperdriveError> { - let args = if self.args_file.is_some() { - trace!("Merging command-line arguments with the argument file"); - self.merge().map_err(DiCalibrateError::from)? + let source_list = srclist_args.parse( + obs_context.phase_centre, + lst_rad, + latitude_rad, + &obs_context.get_veto_freqs(), + &*beam, + )?; + + // Set up the calibration timeblocks. + let time_average_factor = parse_time_average_factor( + Some(input_vis_params.time_res), + timesteps_per_timeblock.as_deref(), + NonZeroUsize::new( + input_vis_params.timeblocks.last().timesteps.last() + - input_vis_params.timeblocks.first().timesteps.first() + + 1, + ) + .expect("is not 0"), + ) + .map_err(|e| match e { + AverageFactorError::Zero => DiCalArgsError::CalTimeFactorZero, + AverageFactorError::NotInteger => DiCalArgsError::CalTimeFactorNotInteger, + AverageFactorError::NotIntegerMultiple { out, inp } => { + DiCalArgsError::CalTimeResNotMultiple { out, inp } + } + AverageFactorError::Parse(e) => DiCalArgsError::ParseCalTimeAverageFactor(e), + })?; + let all_selected_timestamps = Vec1::try_from_vec( + input_vis_params + .timeblocks + .iter() + .flat_map(|t| &t.timestamps) + .copied() + .collect(), + ) + .expect("cannot be empty"); + let cal_timeblocks = timesteps_to_timeblocks( + &all_selected_timestamps, + input_vis_params.time_res, + time_average_factor, + None, + ); + + let mut cal_printer = InfoPrinter::new("DI calibration set up".into()); + // I'm quite bored right now. 
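As rough intuition for the time averaging set up above: a much-simplified, hypothetical stand-in for parse_time_average_factor, showing how a plain count ("4") or a target time resolution ("8s") can resolve to an integer number of timesteps per timeblock. The real function also handles the "0 means average everything" case, general unit parsing, and the error variants mapped above.

// Hypothetical simplification; not the crate's parse_time_average_factor.
fn average_factor(spec: &str, input_time_res_s: f64) -> Option<usize> {
    if let Some(stripped) = spec.strip_suffix('s') {
        // A target time resolution, e.g. "8s" with 2 s data -> 4 timesteps.
        let target: f64 = stripped.trim().parse().ok()?;
        let factor = target / input_time_res_s;
        (factor > 0.0 && factor.fract() == 0.0).then(|| factor as usize)
    } else {
        // A plain number of timesteps, e.g. "4".
        spec.trim().parse().ok()
    }
}

fn main() {
    assert_eq!(average_factor("4", 2.0), Some(4));
    assert_eq!(average_factor("8s", 2.0), Some(4));
    assert_eq!(average_factor("3s", 2.0), None); // not an integer multiple
}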
+ let timeblock_plural = if input_vis_params.timeblocks.len() > 1 { + "timeblocks" } else { - self + "timeblock" + }; + let chanblock_plural = if input_vis_params.spw.chanblocks.len() > 1 { + "chanblocks" + } else { + "chanblock" + }; + cal_printer.push_block(vec![ + format!( + "{} calibration {timeblock_plural}, {} calibration {chanblock_plural}", + cal_timeblocks.len(), + input_vis_params.spw.chanblocks.len() + ) + .into(), + format!("{time_average_factor} timesteps per timeblock").into(), + // format!("{freq_average_factor} channels per chanblock").into(), // TODO: Not yet implemented + ]); + + // Set baseline weights from UVW cuts. Use a lambda from the centroid + // frequency if UVW cutoffs are specified as wavelengths. + let freq_centroid = obs_context + .fine_chan_freqs + .iter() + .map(|&u| u as f64) + .sum::() + / obs_context.fine_chan_freqs.len() as f64; + let lambda = marlu::constants::VEL_C / freq_centroid; + let (uvw_min, uvw_min_metres) = { + let (quantity, unit) = parse_wavelength(uvw_min.as_deref().unwrap_or(DEFAULT_UVW_MIN)) + .map_err(DiCalArgsError::ParseUvwMin)?; + match unit { + WavelengthUnit::M => ((quantity, unit), quantity), + WavelengthUnit::L => ((quantity, unit), quantity * lambda), + } + }; + let (uvw_max, uvw_max_metres) = match uvw_max { + None => ((f64::INFINITY, WavelengthUnit::M), f64::INFINITY), + Some(s) => { + let (quantity, unit) = parse_wavelength(&s).map_err(DiCalArgsError::ParseUvwMax)?; + match unit { + WavelengthUnit::M => ((quantity, unit), quantity), + WavelengthUnit::L => ((quantity, unit), quantity * lambda), + } + } + }; + + let unflagged_tile_xyzs: Vec = obs_context + .tile_xyzs + .par_iter() + .enumerate() + .filter(|(tile_index, _)| { + !input_vis_params + .tile_baseline_flags + .flagged_tiles + .contains(tile_index) + }) + .map(|(_, xyz)| *xyz) + .collect(); + + let (baseline_weights, num_flagged_baselines) = { + let mut baseline_weights = Vec1::try_from_vec(vec![ + 1.0; + input_vis_params + .tile_baseline_flags + .unflagged_cross_baseline_to_tile_map + .len() + ]) + .expect("not possible to have no unflagged tiles here"); + let uvws = xyzs_to_cross_uvws( + &unflagged_tile_xyzs, + obs_context.phase_centre.to_hadec(lst_rad), + ); + assert_eq!(baseline_weights.len(), uvws.len()); + let uvw_min = uvw_min_metres.powi(2); + let uvw_max = uvw_max_metres.powi(2); + let mut num_flagged_baselines = 0; + for (uvw, baseline_weight) in uvws.into_iter().zip(baseline_weights.iter_mut()) { + let uvw_length = uvw.u.powi(2) + uvw.v.powi(2) + uvw.w.powi(2); + if uvw_length < uvw_min || uvw_length > uvw_max { + *baseline_weight = 0.0; + num_flagged_baselines += 1; + } + } + (baseline_weights, num_flagged_baselines) + }; + if num_flagged_baselines == baseline_weights.len() { + return Err(DiCalArgsError::AllBaselinesFlaggedFromUvwCutoffs.into()); + } + + let mut block = vec![]; + block.push( + format!( + "Calibrating with {} of {} baselines", + baseline_weights.len() - num_flagged_baselines, + baseline_weights.len() + ) + .into(), + ); + match (uvw_min, uvw_min.0.is_infinite()) { + // Again, bored. 
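A quick numerical check on the UVW cutoffs above: a cutoff given in wavelengths is converted to metres as quantity * (c / centroid frequency), as in the lambda arithmetic just shown. The sketch below is not the crate's parse_wavelength, and the 150 MHz centroid is an arbitrary example value.

// Standalone sketch of the wavelength-to-metres conversion used for UVW cutoffs.
const VEL_C: f64 = 299_792_458.0; // speed of light, m/s

fn cutoff_in_metres(quantity: f64, is_wavelengths: bool, freq_centroid_hz: f64) -> f64 {
    if is_wavelengths {
        let lambda = VEL_C / freq_centroid_hz;
        quantity * lambda
    } else {
        quantity
    }
}

fn main() {
    // e.g. the default minimum cutoff of 50λ at a 150 MHz centroid.
    let min_m = cutoff_in_metres(50.0, true, 150e6);
    println!("50λ ≈ {min_m:.1} m"); // ≈ 99.9 m
}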
+ (_, true) => block.push("Minimum UVW cutoff: ∞".into()), + ((quantity, WavelengthUnit::M), _) => { + block.push(format!("Minimum UVW cutoff: {quantity}m").into()) + } + ((quantity, WavelengthUnit::L), _) => block.push( + format!( + "Minimum UVW cutoff: {quantity}λ ({:.3}m)", + quantity * lambda + ) + .into(), + ), + } + match (uvw_max, uvw_max.0.is_infinite()) { + (_, true) => block.push("Maximum UVW cutoff: ∞".into()), + ((quantity, WavelengthUnit::M), _) => { + block.push(format!("Maximum UVW cutoff: {quantity}m").into()) + } + ((quantity, WavelengthUnit::L), _) => block.push( + format!( + "Maximum UVW cutoff: {quantity}λ ({:.3}m)", + quantity * lambda + ) + .into(), + ), + } + // Report extra info if we need to use our own lambda (the user + // specified wavelengths). + if matches!(uvw_min.1, WavelengthUnit::L) || matches!(uvw_max.1, WavelengthUnit::L) { + block.push( + format!( + "(Used obs. centroid frequency {} MHz to convert lambdas to metres)", + freq_centroid / 1e6 + ) + .into(), + ); + } + cal_printer.push_block(block); + + let mut unflagged_fine_chan_freqs = vec![]; + let flagged_fine_chans = &input_vis_params.spw.flagged_chan_indices; + for (i_chan, &freq) in obs_context.fine_chan_freqs.iter().enumerate() { + if !flagged_fine_chans.contains(&(i_chan as u16)) { + unflagged_fine_chan_freqs.push(freq as f64); + } + } + if log_enabled!(Debug) { + let unflagged_fine_chans: Vec<_> = (0..obs_context.fine_chan_freqs.len()) + .filter(|i_chan| !flagged_fine_chans.contains(&(*i_chan as u16))) + .collect(); + match unflagged_fine_chans.as_slice() { + [] => (), + [f] => debug!("Only unflagged fine-channel: {}", f), + [f_0, .., f_n] => { + debug!("First unflagged fine-channel: {}", f_0); + debug!("Last unflagged fine-channel: {}", f_n); + } + } + + let fine_chan_flags_vec = flagged_fine_chans.iter().sorted().collect::>(); + debug!("Flagged fine-channels: {:?}", fine_chan_flags_vec); + } + // There should never be any no unflagged channels, because this + // should've been handled by the input-vis-reading code. + assert!(!unflagged_fine_chan_freqs.is_empty()); + + // Make sure the calibration thresholds are sensible. + let mut stop_threshold = stop_threshold.unwrap_or(DEFAULT_STOP_THRESHOLD); + let min_threshold = min_threshold.unwrap_or(DEFAULT_MIN_THRESHOLD); + if stop_threshold > min_threshold { + format!("Specified stop threshold ({:e}) is bigger than the min. threshold ({:e}); capping stop threshold.", stop_threshold, min_threshold).warn(); + stop_threshold = min_threshold; + } + let max_iterations = max_iterations.unwrap_or(DEFAULT_MAX_ITERATIONS); + + cal_printer.push_block(vec![ + "Chanblocks will stop iterating".into(), + format!( + "- when the iteration difference is less than {:e} (stop threshold)", + stop_threshold + ) + .into(), + format!("- or after {} iterations.", max_iterations).into(), + format!( + "Chanblocks with an iteration diff. less than {:e} are considered converged (min. threshold)", + min_threshold + ) + .into(), + ]); + + let output_solution_files = { + match solutions { + // Defaults. + None => { + let pb = PathBuf::from(DEFAULT_OUTPUT_SOLUTIONS_FILENAME); + let sol_type = pb + .extension() + .and_then(|os_str| os_str.to_str()) + .and_then(|s| CalSolutionType::from_str(s).ok()) + // Tests should pick up a bad default filename. + .expect("DEFAULT_OUTPUT_SOLUTIONS_FILENAME has an unhandled extension!"); + vec1![(pb, sol_type)] + } + Some(outputs) => { + let mut cal_sols = vec![]; + for file in outputs { + // Is the output file type supported? 
+ let ext = file.extension().and_then(|os_str| os_str.to_str()); + match ext.and_then(|s| CalSolutionType::from_str(s).ok()) { + Some(sol_type) => { + trace!("{} is a solution output", file.display()); + can_write_to_file(&file) + .map_err(|e| HyperdriveError::Generic(e.to_string()))?; + cal_sols.push((file, sol_type)); + } + None => { + return Err(DiCalArgsError::CalibrationOutputFile { + ext: ext.unwrap_or("").to_string(), + } + .into()) + } + } + } + Vec1::try_from_vec(cal_sols).expect("cannot fail") + } + } }; + if output_solution_files.is_empty() { + return Err(DiCalArgsError::NoOutput.into()); + } - debug!("{:#?}", &args); - trace!("Converting arguments into calibration parameters"); - let parameters = args.into_params()?; + cal_printer.push_line( + format!( + "Writing calibration solutions to: {}", + output_solution_files + .iter() + .map(|(pb, _)| pb.display()) + .join(", ") + ) + .into(), + ); + + // Parse the output model vis args like normal output vis args, to + // re-use existing code (we only make the args distinct to make it clear + // that these visibilities are not calibrated, just the model vis). + let output_model_vis_params = match model_filenames { + None => None, + Some(model_filenames) => { + let output_vis_params = OutputVisArgs { + outputs: Some(model_filenames), + output_vis_time_average: output_model_time_average, + output_vis_freq_average: output_model_freq_average, + } + .parse( + input_vis_params.time_res, + input_vis_params.spw.freq_res, + &input_vis_params.timeblocks.mapped_ref(|tb| tb.median), + output_smallest_contiguous_band, + "hyp_model.uvfits", // not actually used + Some("model"), + )?; + + Some(output_vis_params) + } + }; + + cal_printer.display(); + display_warnings(); + + Ok(DiCalParams { + input_vis_params, + beam, + source_list, + cal_timeblocks, + uvw_min: uvw_min_metres, + uvw_max: uvw_max_metres, + freq_centroid, + baseline_weights, + max_iterations, + stop_threshold, + min_threshold, + output_solution_files, + output_model_vis_params, + modelling_params, + }) + } + + pub(super) fn run( + self, + dry_run: bool, + ) -> Result, HyperdriveError> { + debug!("Converting arguments into parameters"); + trace!("{:#?}", self); + let params = self.parse()?; if dry_run { info!("Dry run -- exiting now."); return Ok(None); } - let sols = parameters.calibrate()?; + let sols = params.run()?; // Write out the solutions. 
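The extension matching above decides which writer each solutions file goes to. A self-contained sketch of that idea follows; the real code goes through CalSolutionType::from_str and the supported extensions come from CAL_SOLUTION_EXTENSIONS, so the two extensions used here are assumptions based on the Fits/Bin variants and the hyperdrive_solutions.fits default.

// Illustrative only; stands in for CalSolutionType::from_str on the extension.
use std::path::Path;

#[derive(Debug, PartialEq)]
enum SolutionFormat {
    Fits, // written via solutions::hyperdrive in this diff
    Bin,  // written via solutions::ao in this diff
}

fn format_from_path(path: &Path) -> Option<SolutionFormat> {
    match path.extension()?.to_str()? {
        "fits" => Some(SolutionFormat::Fits),
        "bin" => Some(SolutionFormat::Bin),
        _ => None,
    }
}

fn main() {
    assert_eq!(
        format_from_path(Path::new("hyperdrive_solutions.fits")),
        Some(SolutionFormat::Fits)
    );
    assert_eq!(format_from_path(Path::new("sols.bin")), Some(SolutionFormat::Bin));
    assert_eq!(format_from_path(Path::new("sols.txt")), None);
}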
-        if parameters.output_solutions_filenames.len() == 1 {
-            let (sol_type, file) = &parameters.output_solutions_filenames[0];
+        let num_solution_files = params.output_solution_files.len();
+        for (i, (file, sol_type)) in params.output_solution_files.into_iter().enumerate() {
             match sol_type {
-                CalSolutionType::Fits => solutions::hyperdrive::write(&sols, file)?,
-                CalSolutionType::Bin => solutions::ao::write(&sols, file)?,
+                CalSolutionType::Fits => solutions::hyperdrive::write(&sols, &file)?,
+                CalSolutionType::Bin => solutions::ao::write(&sols, &file)?,
             }
-            info!("Calibration solutions written to {}", file.display());
-        } else {
-            for (i, (sol_type, file)) in parameters
-                .output_solutions_filenames
-                .into_iter()
-                .enumerate()
-            {
-                match sol_type {
-                    CalSolutionType::Fits => solutions::hyperdrive::write(&sols, &file)?,
-                    CalSolutionType::Bin => solutions::ao::write(&sols, &file)?,
-                }
+            if num_solution_files == 1 {
+                info!("Calibration solutions written to {}", file.display());
+            } else {
                 if i == 0 {
                     info!("Calibration solutions written to:");
                 }
@@ -511,3 +643,75 @@ impl DiCalArgs {
         Ok(Some(sols))
     }
 }
+
+/// Errors associated with DI calibration arguments.
+#[derive(thiserror::Error, Debug)]
+pub(super) enum DiCalArgsError {
+    #[error("No calibration output was specified. There must be at least one calibration solution file.")]
+    NoOutput,
+
+    #[error(
+        "All baselines were flagged due to UVW cutoffs. Try adjusting the UVW min and/or max."
+    )]
+    AllBaselinesFlaggedFromUvwCutoffs,
+
+    #[error("Cannot write calibration solutions to a file type '{ext}'.\nSupported formats are: {}", *crate::solutions::CAL_SOLUTION_EXTENSIONS)]
+    CalibrationOutputFile { ext: String },
+
+    #[error("Error when parsing time average factor: {0}")]
+    ParseCalTimeAverageFactor(crate::unit_parsing::UnitParseError),
+
+    #[error("Calibration time average factor isn't an integer")]
+    CalTimeFactorNotInteger,
+
+    #[error("Calibration time resolution isn't a multiple of input data's: {out} seconds vs {inp} seconds")]
+    CalTimeResNotMultiple { out: f64, inp: f64 },
+
+    #[error("Calibration time average factor cannot be 0")]
+    CalTimeFactorZero,
+
+    // #[error("Error when parsing freq. average factor: {0}")]
+    // ParseCalFreqAverageFactor(crate::unit_parsing::UnitParseError),
+
+    // #[error("Calibration freq. average factor isn't an integer")]
+    // CalFreqFactorNotInteger,
+
+    // #[error("Calibration freq. resolution isn't a multiple of input data's: {out} Hz vs {inp} Hz")]
+    // CalFreqResNotMultiple { out: f64, inp: f64 },
+
+    // #[error("Calibration freq.
average factor cannot be 0")] + // CalFreqFactorZero, + #[error("Error when parsing minimum UVW cutoff: {0}")] + ParseUvwMin(crate::unit_parsing::UnitParseError), + + #[error("Error when parsing maximum UVW cutoff: {0}")] + ParseUvwMax(crate::unit_parsing::UnitParseError), + + #[error(transparent)] + IO(#[from] std::io::Error), +} + +impl DiCalCliArgs { + fn merge(self, other: Self) -> Self { + Self { + timesteps_per_timeblock: self + .timesteps_per_timeblock + .or(other.timesteps_per_timeblock), + uvw_min: self.uvw_min.or(other.uvw_min), + uvw_max: self.uvw_max.or(other.uvw_max), + max_iterations: self.max_iterations.or(other.max_iterations), + stop_threshold: self.stop_threshold.or(other.stop_threshold), + min_threshold: self.min_threshold.or(other.min_threshold), + solutions: self.solutions.or(other.solutions), + model_filenames: self.model_filenames.or(other.model_filenames), + output_model_time_average: self + .output_model_time_average + .or(other.output_model_time_average), + output_model_freq_average: self + .output_model_freq_average + .or(other.output_model_freq_average), + output_smallest_contiguous_band: self.output_smallest_contiguous_band + || other.output_smallest_contiguous_band, + } + } +} diff --git a/src/cli/di_calibrate/params.rs b/src/cli/di_calibrate/params.rs deleted file mode 100644 index 29f2e9fd..00000000 --- a/src/cli/di_calibrate/params.rs +++ /dev/null @@ -1,1236 +0,0 @@ -// This Source Code Form is subject to the terms of the Mozilla Public -// License, v. 2.0. If a copy of the MPL was not distributed with this -// file, You can obtain one at http://mozilla.org/MPL/2.0/. - -//! Parameters required for DI calibration. - -use std::{ - collections::HashSet, - path::{Path, PathBuf}, - str::FromStr, -}; - -use hifitime::Duration; -use itertools::Itertools; -use log::{debug, log_enabled, trace, warn, Level::Debug}; -use marlu::{ - pos::{precession::precess_time, xyz::xyzs_to_cross_uvws}, - Jones, LatLngHeight, XyzGeodetic, -}; -use ndarray::prelude::*; -use rayon::prelude::*; -use vec1::Vec1; - -use super::{DiCalArgs, DiCalArgsError}; -use crate::{ - averaging::{ - channels_to_chanblocks, parse_freq_average_factor, parse_time_average_factor, - timesteps_to_timeblocks, AverageFactorError, Fence, Timeblock, - }, - beam::{create_fee_beam_object, create_no_beam_object, Beam, Delays}, - constants::{DEFAULT_CUTOFF_DISTANCE, DEFAULT_VETO_THRESHOLD}, - context::ObsContext, - di_calibrate::{calibrate_timeblocks, get_cal_vis, CalVis}, - filenames::InputDataTypes, - io::{ - get_single_match_from_glob, - read::{ - MsReader, RawDataCorrections, RawDataReader, UvfitsReader, VisInputType, VisRead, - VisReadError, - }, - write::{can_write_to_file, VisOutputType}, - }, - math::TileBaselineFlags, - messages, - model::ModellerInfo, - solutions::{CalSolutionType, CalibrationSolutions}, - srclist::{read::read_source_list_file, veto_sources, SourceList, SourceListType}, - unit_parsing::{parse_wavelength, WavelengthUnit}, -}; - -/// Parameters needed to perform calibration. -pub(crate) struct DiCalParams { - /// Interface to the MWA data, and metadata on the input data. - pub(crate) input_data: Box, - - /// If the input data is raw MWA data, these are the corrections being - /// applied as the visibilities are read. - // TODO: Populate these if reading from a MS or uvfits - this can inform us - // what corrections were used when forming those visibilities. - pub(crate) raw_data_corrections: Option, - - /// Beam object. - pub(crate) beam: Box, - - /// The sky-model source list. 
- pub(crate) source_list: SourceList, - - /// The minimum UVW cutoff used in calibration \[metres\]. - pub(crate) uvw_min: f64, - - /// The maximum UVW cutoff used in calibration \[metres\]. - pub(crate) uvw_max: f64, - - /// The centroid frequency of the observation used to convert UVW cutoffs - /// specified in lambdas to metres \[Hz\]. - pub(crate) freq_centroid: f64, - - /// Multiplicative factors to apply to unflagged baselines. These are mostly - /// all 1.0, but flagged baselines (perhaps due to a UVW cutoff) have values - /// of 0.0. - pub(crate) baseline_weights: Vec1, - - /// Blocks of timesteps used for calibration. Each timeblock contains - /// indices of the input data to average together during calibration. Each - /// timeblock may have a different number of timesteps; the number of blocks - /// and their lengths depends on which input data timesteps are being used - /// as well as the `time_average_factor` (i.e. the number of timesteps to - /// average during calibration; by default we average all timesteps). - /// - /// Simple examples: If we are averaging all data over time to form - /// calibration solutions, there will only be one timeblock, and that block - /// will contain all input data timestep indices. On the other hand, if - /// `time_average_factor` is 1, then there are as many timeblocks as there - /// are timesteps, and each block contains 1 timestep index. - /// - /// A more complicated example: If we are using input data timesteps 10, 11, - /// 12 and 15 with a `time_average_factor` of 4, then there will be 2 - /// timeblocks, even though there are only 4 timesteps. This is because - /// timestep 10 and 15 can't occupy the same timeblock with the "length" is - /// 4. So the first timeblock contains 10, 11 and 12, whereas the second - /// contains only 15. - pub(crate) timeblocks: Vec1, - - /// The timestep indices into the input data to be used for calibration. - pub(crate) timesteps: Vec1, - - /// The number of frequency samples to average together during calibration. - /// - /// e.g. If the input data is in 40kHz resolution and this variable was 2, - /// then we average 80kHz worth of data together during calibration. - pub(crate) freq_average_factor: usize, - - /// Spectral windows, or, groups of contiguous-bands of channels to be - /// calibrated. Multiple [Fence]s can represent a "picket fence" - /// observation. Each [Fence] is composed of chanblocks, and the unflagged - /// chanblocks are calibrated. Each chanblock may represent of multiple - /// channels, depending on `freq_average_factor`; when visibilities are read - /// from `input_data`, the channels are averaged according to - /// `freq_average_factor`. If no frequency channels are flagged, then these - /// chanblocks will represent all frequency channels. However, it's likely - /// at least some channels are flagged, so the `flagged_chanblock_indices` - /// in every [Fence] may be needed. - pub(crate) fences: Vec1, - - /// The frequencies of each of the observation's unflagged fine channels - /// \[Hz\]. - pub(crate) unflagged_fine_chan_freqs: Vec, - - /// The fine channels to be flagged across the entire observation. e.g. For - /// a 40 kHz observation, there are 768 fine channels, and this could - /// contain 0 and 767. - pub(crate) flagged_fine_chans: HashSet, - - /// Information on flagged tiles, baselines and mapping between indices. - pub(crate) tile_baseline_flags: TileBaselineFlags, - - /// The unflagged [XyzGeodetic] coordinates of each tile \[metres\]. 
This - /// does not change over time; it is determined only by the telescope's tile - /// layout. - pub(crate) unflagged_tile_xyzs: Vec, - - /// The Earth position of the array. This is populated by user input or the input data. - pub(crate) array_position: LatLngHeight, - - /// The UT1 - UTC offset. If this is 0, effectively UT1 == UTC, which is a - /// wrong assumption by up to 0.9s. We assume the this value does not change - /// over the timestamps used in this `DiCalParams`. - /// - /// Note that this need not be the same DUT1 in the input data's - /// [`ObsContext`]; the user may choose to suppress that DUT1 or supply - /// their own. - pub(crate) dut1: Duration, - - /// Should the array be precessed back to J2000? - pub(crate) apply_precession: bool, - - /// The maximum number of times to iterate when performing "MitchCal". - pub(crate) max_iterations: u32, - - /// The threshold at which we stop convergence when performing "MitchCal". - /// This is smaller than `min_threshold`. - pub(crate) stop_threshold: f64, - - /// The minimum threshold to satisfy convergence when performing "MitchCal". - /// Reaching this threshold counts as "converged", but it's not as good as - /// the stop threshold. This is bigger than `stop_threshold`. - pub(crate) min_threshold: f64, - - /// The paths to the files where the calibration solutions are written. The - /// same solutions are written to each file here, but the format may be - /// different (indicated by the file extension). Supported formats are - /// detailed by [super::solutions::CalSolutionType]. - pub(crate) output_solutions_filenames: Vec<(CalSolutionType, PathBuf)>, - - /// The optional sky-model visibilities files. If specified, model - /// visibilities will be written out before calibration. - pub(crate) model_files: Option>, - - /// The number of calibrated time samples to average together before writing - /// out calibrated visibilities. - pub(crate) output_model_time_average_factor: usize, - - /// The number of calibrated frequencies samples to average together before - /// writing out calibrated visibilities. - pub(crate) output_model_freq_average_factor: usize, - - /// When reading in visibilities and generating sky-model visibilities, - /// don't draw progress bars. - pub(crate) no_progress_bars: bool, - - /// Information on the sky-modelling device (CPU or CUDA-capable device). - pub(crate) modeller_info: ModellerInfo, -} - -impl DiCalParams { - /// Create a new params struct from arguments. - /// - /// If the time or frequency resolution aren't specified, they default to - /// the observation's native resolution. - /// - /// Source list vetoing is performed in this function, using the specified - /// number of sources and/or the veto threshold. 
- pub(crate) fn new( - DiCalArgs { - args_file: _, - data, - source_list, - source_list_type, - ms_data_column_name, - ignore_dut1, - outputs, - model_filenames, - output_model_time_average, - output_model_freq_average, - num_sources, - source_dist_cutoff, - veto_threshold, - beam_file, - unity_dipole_gains, - delays, - no_beam, - timesteps_per_timeblock, - freq_average_factor, - timesteps, - use_all_timesteps, - uvw_min, - uvw_max, - max_iterations, - stop_thresh, - min_thresh, - array_position, - no_precession, - #[cfg(feature = "cuda")] - cpu, - tile_flags, - ignore_input_data_tile_flags, - ignore_input_data_fine_channel_flags, - fine_chan_flags_per_coarse_chan, - fine_chan_flags, - pfb_flavour, - no_digital_gains, - no_cable_length_correction, - no_geometric_correction, - no_progress_bars, - }: DiCalArgs, - ) -> Result { - // If we're going to use a GPU for modelling, get the device info so we - // can ensure a CUDA-capable device is available, and so we can report - // it to the user later. - #[cfg(feature = "cuda")] - let modeller_info = if cpu { - ModellerInfo::Cpu - } else { - let (device_info, driver_info) = crate::cuda::get_device_info()?; - ModellerInfo::Cuda { - device_info, - driver_info, - } - }; - #[cfg(not(feature = "cuda"))] - let modeller_info = ModellerInfo::Cpu; - - // If the user supplied the array position, unpack it here. - let array_position = match array_position { - Some(pos) => { - if pos.len() != 3 { - return Err(DiCalArgsError::BadArrayPosition { pos }); - } - Some(LatLngHeight { - longitude_rad: pos[0].to_radians(), - latitude_rad: pos[1].to_radians(), - height_metres: pos[2], - }) - } - None => None, - }; - - // Handle input data. We expect one of three possibilities: - // - gpubox files, a metafits file (and maybe mwaf files), - // - a measurement set (and maybe a metafits file), or - // - uvfits files. - // If none or multiple of these possibilities are met, then we must fail. - let input_data_types = match data { - Some(strings) => InputDataTypes::new(&strings)?, - None => return Err(DiCalArgsError::NoInputData), - }; - let (input_data, raw_data_corrections): (Box, Option) = - match ( - input_data_types.metafits, - input_data_types.gpuboxes, - input_data_types.mwafs, - input_data_types.ms, - input_data_types.uvfits, - ) { - // Valid input for reading raw data. - (Some(meta), Some(gpuboxes), mwafs, None, None) => { - // Ensure that there's only one metafits. - let meta = if meta.len() > 1 { - return Err(DiCalArgsError::MultipleMetafits(meta)); - } else { - meta.first() - }; - - debug!("gpubox files: {:?}", &gpuboxes); - debug!("mwaf files: {:?}", &mwafs); - - let corrections = RawDataCorrections::new( - pfb_flavour.as_deref(), - !no_digital_gains, - !no_cable_length_correction, - !no_geometric_correction, - )?; - let input_data = RawDataReader::new( - meta, - &gpuboxes, - mwafs.as_deref(), - corrections, - array_position, - )?; - - messages::InputFileDetails::Raw { - obsid: input_data.mwalib_context.metafits_context.obs_id, - gpubox_count: gpuboxes.len(), - metafits_file_name: meta.display().to_string(), - mwaf: input_data.get_flags(), - raw_data_corrections: corrections, - } - .print("DI calibrating"); - - (Box::new(input_data), Some(corrections)) - } - - // Valid input for reading a measurement set. - (meta, None, None, Some(ms), None) => { - // Only one MS is supported at the moment. 
- let ms: PathBuf = if ms.len() > 1 { - return Err(DiCalArgsError::MultipleMeasurementSets(ms)); - } else { - ms.first().clone() - }; - - // Ensure that there's only one metafits. - let meta: Option<&Path> = match meta.as_ref() { - None => None, - Some(m) => { - if m.len() > 1 { - return Err(DiCalArgsError::MultipleMetafits(m.clone())); - } else { - Some(m.first().as_path()) - } - } - }; - - let input_data = - MsReader::new(ms.clone(), ms_data_column_name, meta, array_position) - .map_err(VisReadError::from)?; - - messages::InputFileDetails::MeasurementSet { - obsid: input_data.get_obs_context().obsid, - file_name: ms.display().to_string(), - metafits_file_name: meta.map(|m| m.display().to_string()), - } - .print("DI calibrating"); - - (Box::new(input_data), None) - } - - // Valid input for reading uvfits files. - (meta, None, None, None, Some(uvfits)) => { - // Only one uvfits is supported at the moment. - let uvfits: PathBuf = if uvfits.len() > 1 { - return Err(DiCalArgsError::MultipleUvfits(uvfits)); - } else { - uvfits.first().clone() - }; - - // Ensure that there's only one metafits. - let meta: Option<&Path> = match meta.as_ref() { - None => None, - Some(m) => { - if m.len() > 1 { - return Err(DiCalArgsError::MultipleMetafits(m.clone())); - } else { - Some(m.first()) - } - } - }; - - let input_data = UvfitsReader::new(uvfits.clone(), meta, array_position) - .map_err(VisReadError::from)?; - - messages::InputFileDetails::UvfitsFile { - obsid: input_data.get_obs_context().obsid, - file_name: uvfits.display().to_string(), - metafits_file_name: meta.map(|m| m.display().to_string()), - } - .print("DI calibrating"); - - (Box::new(input_data), None) - } - - // The following matches are for invalid combinations of input - // files. Make an error message for the user. 
- (Some(_), _, None, None, None) => { - let msg = "Received only a metafits file; a uvfits file, a measurement set or gpubox files are required."; - return Err(DiCalArgsError::InvalidDataInput(msg)); - } - (Some(_), _, Some(_), None, None) => { - let msg = - "Received only a metafits file and mwaf files; gpubox files are required."; - return Err(DiCalArgsError::InvalidDataInput(msg)); - } - (None, Some(_), _, None, None) => { - let msg = "Received gpuboxes without a metafits file; this is not supported."; - return Err(DiCalArgsError::InvalidDataInput(msg)); - } - (None, None, Some(_), None, None) => { - let msg = "Received mwaf files without gpuboxes and a metafits file; this is not supported."; - return Err(DiCalArgsError::InvalidDataInput(msg)); - } - (_, Some(_), _, Some(_), None) => { - let msg = "Received gpuboxes and measurement set files; this is not supported."; - return Err(DiCalArgsError::InvalidDataInput(msg)); - } - (_, Some(_), _, None, Some(_)) => { - let msg = "Received gpuboxes and uvfits files; this is not supported."; - return Err(DiCalArgsError::InvalidDataInput(msg)); - } - (_, _, _, Some(_), Some(_)) => { - let msg = "Received uvfits and measurement set files; this is not supported."; - return Err(DiCalArgsError::InvalidDataInput(msg)); - } - (_, _, Some(_), Some(_), _) => { - let msg = "Received mwafs and measurement set files; this is not supported."; - return Err(DiCalArgsError::InvalidDataInput(msg)); - } - (_, _, Some(_), _, Some(_)) => { - let msg = "Received mwafs and uvfits files; this is not supported."; - return Err(DiCalArgsError::InvalidDataInput(msg)); - } - (None, None, None, None, None) => return Err(DiCalArgsError::NoInputData), - }; - - let obs_context = input_data.get_obs_context(); - - let array_position = obs_context.array_position; - let dut1 = if ignore_dut1 { None } else { obs_context.dut1 }; - - let timesteps_to_use = { - match (use_all_timesteps, timesteps) { - (true, _) => obs_context.all_timesteps.clone(), - (false, None) => Vec1::try_from_vec(obs_context.unflagged_timesteps.clone()) - .map_err(|_| DiCalArgsError::NoTimesteps)?, - (false, Some(mut ts)) => { - // Make sure there are no duplicates. - let timesteps_hashset: HashSet<&usize> = ts.iter().collect(); - if timesteps_hashset.len() != ts.len() { - return Err(DiCalArgsError::DuplicateTimesteps); - } - - // Ensure that all specified timesteps are actually available. - for t in &ts { - if !(0..obs_context.timestamps.len()).contains(t) { - return Err(DiCalArgsError::UnavailableTimestep { - got: *t, - last: obs_context.timestamps.len() - 1, - }); - } - } - - ts.sort_unstable(); - Vec1::try_from_vec(ts).map_err(|_| DiCalArgsError::NoTimesteps)? - } - } - }; - - let precession_info = precess_time( - array_position.longitude_rad, - array_position.latitude_rad, - obs_context.phase_centre, - obs_context.timestamps[*timesteps_to_use.first()], - dut1.unwrap_or_else(|| Duration::from_seconds(0.0)), - ); - let (lmst, latitude) = if no_precession { - (precession_info.lmst, array_position.latitude_rad) - } else { - ( - precession_info.lmst_j2000, - precession_info.array_latitude_j2000, - ) - }; - - // The length of the tile XYZ collection is the total number of tiles in - // the array, even if some tiles are flagged. - let total_num_tiles = obs_context.get_total_num_tiles(); - - // Assign the tile flags. 
- let flagged_tiles = - obs_context.get_tile_flags(ignore_input_data_tile_flags, tile_flags.as_deref())?; - let num_unflagged_tiles = total_num_tiles - flagged_tiles.len(); - if log_enabled!(Debug) { - obs_context.print_debug_tile_statuses(); - } - if num_unflagged_tiles == 0 { - return Err(DiCalArgsError::NoTiles); - } - messages::ArrayDetails { - array_position: Some(array_position), - array_latitude_j2000: if no_precession { - None - } else { - Some(precession_info.array_latitude_j2000) - }, - total_num_tiles, - num_unflagged_tiles, - flagged_tiles: &flagged_tiles - .iter() - .sorted() - .map(|&i| (obs_context.tile_names[i].as_str(), i)) - .collect::>(), - } - .print(); - - let dipole_delays = match delays { - // We have user-provided delays; check that they're are sensible, - // regardless of whether we actually need them. - Some(d) => { - if d.len() != 16 || d.iter().any(|&v| v > 32) { - return Err(DiCalArgsError::BadDelays); - } - Some(Delays::Partial(d)) - } - - // No delays were provided; use whatever was in the input data. - None => obs_context.dipole_delays.as_ref().cloned(), - }; - - let beam: Box = if no_beam { - create_no_beam_object(total_num_tiles) - } else { - let mut dipole_delays = dipole_delays.ok_or(DiCalArgsError::NoDelays)?; - let dipole_gains = if unity_dipole_gains { - None - } else { - // If we don't have dipole gains from the input data, then - // we issue a warning that we must assume no dead dipoles. - if obs_context.dipole_gains.is_none() { - match input_data.get_input_data_type() { - VisInputType::MeasurementSet => { - warn!("Measurement sets cannot supply dead dipole information."); - warn!("Without a metafits file, we must assume all dipoles are alive."); - warn!("This will make beam Jones matrices inaccurate in sky-model generation."); - } - VisInputType::Uvfits => { - warn!("uvfits files cannot supply dead dipole information."); - warn!("Without a metafits file, we must assume all dipoles are alive."); - warn!("This will make beam Jones matrices inaccurate in sky-model generation."); - } - VisInputType::Raw => unreachable!(), - } - } - obs_context.dipole_gains.clone() - }; - if dipole_gains.is_none() { - // If we don't have dipole gains, we must assume all dipoles are - // "alive". But, if any dipole delays are 32, then the beam code - // will still ignore those dipoles. So use ideal dipole delays - // for all tiles. - let ideal_delays = dipole_delays.get_ideal_delays(); - - // Warn the user if they wanted unity dipole gains but the ideal - // dipole delays contain 32. - if unity_dipole_gains && ideal_delays.iter().any(|&v| v == 32) { - warn!( - "Some ideal dipole delays are 32; these dipoles will not have unity gains" - ); - } - dipole_delays.set_to_ideal_delays(); - } - - create_fee_beam_object(beam_file, total_num_tiles, dipole_delays, dipole_gains)? - }; - let beam_file = beam.get_beam_file(); - debug!("Beam file: {beam_file:?}"); - - // Set up frequency information. Determine all of the fine-channel flags. - let mut flagged_fine_chans: HashSet = match fine_chan_flags { - Some(flags) => flags.into_iter().collect(), - None => HashSet::new(), - }; - if !ignore_input_data_fine_channel_flags { - flagged_fine_chans.extend(obs_context.flagged_fine_chans.iter()); - } - // Assign the per-coarse-channel fine-channel flags. - let fine_chan_flags_per_coarse_chan: HashSet = { - let mut out_flags = HashSet::new(); - // Handle user flags. 
- if let Some(fine_chan_flags_per_coarse_chan) = fine_chan_flags_per_coarse_chan { - out_flags.extend(fine_chan_flags_per_coarse_chan.into_iter()); - } - // Handle input data flags. - if let (false, Some(flags)) = ( - ignore_input_data_fine_channel_flags, - obs_context.flagged_fine_chans_per_coarse_chan.as_ref(), - ) { - out_flags.extend(flags.iter()); - } - out_flags - }; - // Take the per-coarse-channel flags and put them in the fine channel - // flags. - match ( - obs_context.mwa_coarse_chan_nums.as_ref(), - obs_context.num_fine_chans_per_coarse_chan.map(|n| n.get()), - ) { - (Some(mwa_coarse_chan_nums), Some(num_fine_chans_per_coarse_chan)) => { - for (i_cc, _) in mwa_coarse_chan_nums.iter().enumerate() { - for f in &fine_chan_flags_per_coarse_chan { - flagged_fine_chans.insert(f + num_fine_chans_per_coarse_chan * i_cc); - } - } - } - - // We can't do anything without the number of fine channels per - // coarse channel. - (_, None) => { - warn!("Flags per coarse channel were specified, but no information on how many fine channels per coarse channel is available; flags are being ignored."); - } - - // If we don't have MWA coarse channel numbers but we do have - // per-coarse-channel flags, warn the user. - (None, _) => { - if !fine_chan_flags_per_coarse_chan.is_empty() { - warn!("Flags per coarse channel were specified, but no MWA coarse channel information is available; flags are being ignored."); - } - } - } - let mut unflagged_fine_chan_freqs = vec![]; - for (i_chan, &freq) in obs_context.fine_chan_freqs.iter().enumerate() { - if !flagged_fine_chans.contains(&i_chan) { - unflagged_fine_chan_freqs.push(freq as f64); - } - } - if log_enabled!(Debug) { - let unflagged_fine_chans: Vec<_> = (0..obs_context.fine_chan_freqs.len()) - .filter(|i_chan| !flagged_fine_chans.contains(i_chan)) - .collect(); - match unflagged_fine_chans.as_slice() { - [] => (), - [f] => debug!("Only unflagged fine-channel: {}", f), - [f_0, .., f_n] => { - debug!("First unflagged fine-channel: {}", f_0); - debug!("Last unflagged fine-channel: {}", f_n); - } - } - - let fine_chan_flags_vec = flagged_fine_chans.iter().sorted().collect::>(); - debug!("Flagged fine-channels: {:?}", fine_chan_flags_vec); - } - if unflagged_fine_chan_freqs.is_empty() { - return Err(DiCalArgsError::NoChannels); - } - - messages::ObservationDetails { - dipole_delays: beam.get_ideal_dipole_delays(), - beam_file, - num_tiles_with_dead_dipoles: if unity_dipole_gains { - None - } else { - obs_context.dipole_gains.as_ref().map(|array| { - array - .outer_iter() - .filter(|tile_dipole_gains| { - tile_dipole_gains.iter().any(|g| g.abs() < f64::EPSILON) - }) - .count() - }) - }, - phase_centre: obs_context.phase_centre, - pointing_centre: obs_context.pointing_centre, - dut1, - lmst: Some(precession_info.lmst), - lmst_j2000: if no_precession { - Some(precession_info.lmst_j2000) - } else { - None - }, - available_timesteps: Some(&obs_context.all_timesteps), - unflagged_timesteps: Some(&obs_context.unflagged_timesteps), - using_timesteps: Some(×teps_to_use), - first_timestamp: Some(obs_context.timestamps[*timesteps_to_use.first()]), - last_timestamp: if timesteps_to_use.len() > 1 { - Some(obs_context.timestamps[*timesteps_to_use.last()]) - } else { - None - }, - time_res: obs_context.time_res, - total_num_channels: obs_context.fine_chan_freqs.len(), - num_unflagged_channels: Some(unflagged_fine_chan_freqs.len()), - flagged_chans_per_coarse_chan: obs_context - .flagged_fine_chans_per_coarse_chan - .as_deref(), - first_freq_hz: 
Some(*obs_context.fine_chan_freqs.first() as f64), - last_freq_hz: Some(*obs_context.fine_chan_freqs.last() as f64), - first_unflagged_freq_hz: unflagged_fine_chan_freqs.first().copied(), - last_unflagged_freq_hz: unflagged_fine_chan_freqs.last().copied(), - freq_res_hz: obs_context.freq_res, - } - .print(); - - // Validate calibration solution outputs. - let output_solutions_filenames = { - match outputs { - // Defaults. - None => { - let pb = - PathBuf::from(crate::cli::di_calibrate::DEFAULT_OUTPUT_SOLUTIONS_FILENAME); - let sol_type = pb - .extension() - .and_then(|os_str| os_str.to_str()) - .and_then(|s| CalSolutionType::from_str(s).ok()) - // Tests should pick up a bad default filename. - .expect("DEFAULT_OUTPUT_SOLUTIONS_FILENAME has an unhandled extension!"); - vec![(sol_type, pb)] - } - Some(outputs) => { - let mut cal_sols = vec![]; - for file in outputs { - // Is the output file type supported? - let ext = file.extension().and_then(|os_str| os_str.to_str()); - match ext.and_then(|s| CalSolutionType::from_str(s).ok()) { - Some(sol_type) => { - trace!("{} is a solution output", file.display()); - can_write_to_file(&file)?; - cal_sols.push((sol_type, file)); - } - None => { - return Err(DiCalArgsError::CalibrationOutputFile { - ext: ext.unwrap_or("").to_string(), - }) - } - } - } - cal_sols - } - } - }; - if output_solutions_filenames.is_empty() { - return Err(DiCalArgsError::NoOutput); - } - - // Handle the output model files, if specified. - let model_files = if let Some(model_files) = model_filenames { - let mut valid_model_files = Vec::with_capacity(model_files.len()); - for file in model_files { - // Is the output file type supported? - let ext = file.extension().and_then(|os_str| os_str.to_str()); - match ext.and_then(|s| VisOutputType::from_str(s).ok()) { - Some(t) => { - can_write_to_file(&file)?; - valid_model_files.push((file, t)); - } - None => { - return Err(DiCalArgsError::VisFileType { - ext: ext.unwrap_or("").to_string(), - }) - } - } - } - Vec1::try_from_vec(valid_model_files).ok() - } else { - None - }; - - // Set up the timeblocks. - let time_average_factor = parse_time_average_factor( - obs_context.time_res, - timesteps_per_timeblock.as_deref(), - *timesteps_to_use.last() - *timesteps_to_use.first() + 1, - ) - .map_err(|e| match e { - AverageFactorError::Zero => DiCalArgsError::CalTimeFactorZero, - AverageFactorError::NotInteger => DiCalArgsError::CalTimeFactorNotInteger, - AverageFactorError::NotIntegerMultiple { out, inp } => { - DiCalArgsError::CalTimeResNotMultiple { out, inp } - } - AverageFactorError::Parse(e) => DiCalArgsError::ParseCalTimeAverageFactor(e), - })?; - // Check that the factor is not too big. - let time_average_factor = if time_average_factor > timesteps_to_use.len() { - warn!( - "Cannot average {} timesteps during calibration; only {} are being used. Capping.", - time_average_factor, - timesteps_to_use.len() - ); - timesteps_to_use.len() - } else { - time_average_factor - }; - - let timeblocks = timesteps_to_timeblocks( - &obs_context.timestamps, - time_average_factor, - ×teps_to_use, - ); - - // Set up the chanblocks. 
- let freq_average_factor = - parse_freq_average_factor(obs_context.freq_res, freq_average_factor.as_deref(), 1) - .map_err(|e| match e { - AverageFactorError::Zero => DiCalArgsError::CalFreqFactorZero, - AverageFactorError::NotInteger => DiCalArgsError::CalFreqFactorNotInteger, - AverageFactorError::NotIntegerMultiple { out, inp } => { - DiCalArgsError::CalFreqResNotMultiple { out, inp } - } - AverageFactorError::Parse(e) => DiCalArgsError::ParseCalFreqAverageFactor(e), - })?; - // Check that the factor is not too big. - let freq_average_factor = if freq_average_factor > unflagged_fine_chan_freqs.len() { - warn!( - "Cannot average {} channels; only {} are being used. Capping.", - freq_average_factor, - unflagged_fine_chan_freqs.len() - ); - unflagged_fine_chan_freqs.len() - } else { - freq_average_factor - }; - - let fences = channels_to_chanblocks( - &obs_context.fine_chan_freqs, - obs_context.freq_res, - freq_average_factor, - &flagged_fine_chans, - ); - // There must be at least one chanblock for calibration. - match fences.as_slice() { - // No fences is the same as no chanblocks. - [] => return Err(DiCalArgsError::NoChannels), - [f] => { - // Check that the chanblocks aren't all flagged. - if f.chanblocks.is_empty() { - return Err(DiCalArgsError::NoChannels); - } - } - [f, ..] => { - // Check that the chanblocks aren't all flagged. - if f.chanblocks.is_empty() { - return Err(DiCalArgsError::NoChannels); - } - // TODO: Allow picket fence. - eprintln!("\"Picket fence\" data detected. hyperdrive does not support this right now -- exiting."); - eprintln!("See for more info: https://MWATelescope.github.io/mwa_hyperdrive/defs/mwa/picket_fence.html"); - std::process::exit(1); - } - } - let fences = Vec1::try_from_vec(fences).map_err(|_| DiCalArgsError::NoChannels)?; - - let tile_index_maps = TileBaselineFlags::new(total_num_tiles, flagged_tiles); - let flagged_tiles = &tile_index_maps.flagged_tiles; - - let unflagged_tile_xyzs: Vec = obs_context - .tile_xyzs - .par_iter() - .enumerate() - .filter(|(tile_index, _)| !flagged_tiles.contains(tile_index)) - .map(|(_, xyz)| *xyz) - .collect(); - - // Set baseline weights from UVW cuts. Use a lambda from the centroid - // frequency if UVW cutoffs are specified as wavelengths. 
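// A minimal sketch of the idea behind the baseline weighting below, with
// illustrative names (not hyperdrive's code): take the centroid of the
// fine-channel frequencies, convert a wavelength-specified cutoff to metres
// via lambda = c / f, and zero the weight of any baseline whose |UVW| falls
// outside [min, max].
const SKETCH_VEL_C: f64 = 299_792_458.0; // speed of light, m/s

fn sketch_uvw_cut_weights(
    uvws_m: &[(f64, f64, f64)],
    freqs_hz: &[f64],
    min_wavelengths: f64,
    max_metres: f64,
) -> Vec<f64> {
    let centroid_hz = freqs_hz.iter().sum::<f64>() / freqs_hz.len() as f64;
    let lambda = SKETCH_VEL_C / centroid_hz;
    let min_m2 = (min_wavelengths * lambda).powi(2);
    let max_m2 = max_metres.powi(2);
    uvws_m
        .iter()
        .map(|&(u, v, w)| {
            // Compare squared lengths to avoid a square root per baseline.
            let len2 = u * u + v * v + w * w;
            if len2 < min_m2 || len2 > max_m2 { 0.0 } else { 1.0 }
        })
        .collect()
}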
- let freq_centroid = obs_context - .fine_chan_freqs - .iter() - .map(|&u| u as f64) - .sum::() - / obs_context.fine_chan_freqs.len() as f64; - let lambda = marlu::constants::VEL_C / freq_centroid; - let (uvw_min, uvw_min_metres) = { - let (quantity, unit) = parse_wavelength( - uvw_min - .as_deref() - .unwrap_or(crate::cli::di_calibrate::DEFAULT_UVW_MIN), - ) - .map_err(DiCalArgsError::ParseUvwMin)?; - match unit { - WavelengthUnit::M => ((quantity, unit), quantity), - WavelengthUnit::L => ((quantity, unit), quantity * lambda), - } - }; - let (uvw_max, uvw_max_metres) = match uvw_max { - None => ((f64::INFINITY, WavelengthUnit::M), f64::INFINITY), - Some(s) => { - let (quantity, unit) = parse_wavelength(&s).map_err(DiCalArgsError::ParseUvwMax)?; - match unit { - WavelengthUnit::M => ((quantity, unit), quantity), - WavelengthUnit::L => ((quantity, unit), quantity * lambda), - } - } - }; - - let (baseline_weights, num_flagged_baselines) = { - let mut baseline_weights = Vec1::try_from_vec(vec![ - 1.0; - tile_index_maps - .unflagged_cross_baseline_to_tile_map - .len() - ]) - .map_err(|_| DiCalArgsError::NoTiles)?; - let uvws = xyzs_to_cross_uvws( - &unflagged_tile_xyzs, - obs_context.phase_centre.to_hadec(lmst), - ); - assert_eq!(baseline_weights.len(), uvws.len()); - let uvw_min = uvw_min_metres.powi(2); - let uvw_max = uvw_max_metres.powi(2); - let mut num_flagged_baselines = 0; - for (uvw, baseline_weight) in uvws.into_iter().zip(baseline_weights.iter_mut()) { - let uvw_length = uvw.u.powi(2) + uvw.v.powi(2) + uvw.w.powi(2); - if uvw_length < uvw_min || uvw_length > uvw_max { - *baseline_weight = 0.0; - num_flagged_baselines += 1; - } - } - (baseline_weights, num_flagged_baselines) - }; - if num_flagged_baselines == baseline_weights.len() { - return Err(DiCalArgsError::AllBaselinesFlaggedFromUvwCutoffs); - } - - // Make sure the calibration thresholds are sensible. - let mut stop_threshold = - stop_thresh.unwrap_or(crate::cli::di_calibrate::DEFAULT_STOP_THRESHOLD); - let min_threshold = min_thresh.unwrap_or(crate::cli::di_calibrate::DEFAULT_MIN_THRESHOLD); - if stop_threshold > min_threshold { - warn!("Specified stop threshold ({}) is bigger than the min. threshold ({}); capping stop threshold.", stop_threshold, min_threshold); - stop_threshold = min_threshold; - } - let max_iterations = - max_iterations.unwrap_or(crate::cli::di_calibrate::DEFAULT_MAX_ITERATIONS); - - messages::CalibrationDetails { - timesteps_per_timeblock: time_average_factor, - channels_per_chanblock: freq_average_factor, - num_timeblocks: timeblocks.len(), - num_chanblocks: fences.first().chanblocks.len(), - uvw_min, - uvw_max, - num_calibration_baselines: baseline_weights.len() - num_flagged_baselines, - total_num_baselines: baseline_weights.len(), - lambda, - freq_centroid, - min_threshold, - stop_threshold, - max_iterations, - } - .print(); - - let mut source_list: SourceList = { - // Handle the source list argument. - let sl_pb: PathBuf = match source_list { - None => return Err(DiCalArgsError::NoSourceList), - Some(sl) => { - // If the specified source list file can't be found, treat - // it as a glob and expand it to find a match. - let pb = PathBuf::from(&sl); - if pb.exists() { - pb - } else { - get_single_match_from_glob(&sl)? - } - } - }; - - // Read the source list file. If the type was manually specified, - // use that, otherwise the reading code will try all available - // kinds. 
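// A minimal sketch of the path-or-glob fallback used just above: if the
// literal path doesn't exist, treat the string as a glob and demand exactly
// one match. Hypothetical helper, not hyperdrive's; assumes the `glob` crate
// is available.
fn sketch_single_match(spec: &str) -> Result<std::path::PathBuf, String> {
    let pb = std::path::PathBuf::from(spec);
    if pb.exists() {
        return Ok(pb);
    }
    let matches: Vec<std::path::PathBuf> = glob::glob(spec)
        .map_err(|e| e.to_string())?
        .filter_map(Result::ok)
        .collect();
    match matches.as_slice() {
        [only] => Ok(only.clone()),
        [] => Err(format!("'{spec}' doesn't exist and matched no files as a glob")),
        _ => Err(format!("'{spec}' matched more than one file")),
    }
}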
- let sl_type_specified = source_list_type.is_none(); - let sl_type = source_list_type.and_then(|t| SourceListType::from_str(t.as_ref()).ok()); - let (sl, sl_type) = match crate::misc::expensive_op( - || read_source_list_file(sl_pb, sl_type), - "Still reading source list file", - ) { - Ok((sl, sl_type)) => (sl, sl_type), - Err(e) => return Err(DiCalArgsError::from(e)), - }; - - // If the user didn't specify the source list type, then print out - // what we found. - if sl_type_specified { - trace!("Successfully parsed {}-style source list", sl_type); - } - - sl - }; - trace!("Found {} sources in the source list", source_list.len()); - // Veto any sources that may be troublesome, and/or cap the total number - // of sources. If the user doesn't specify how many source-list sources - // to use, then all sources are used. - if num_sources == Some(0) || source_list.is_empty() { - return Err(DiCalArgsError::NoSources); - } - veto_sources( - &mut source_list, - obs_context.phase_centre, - lmst, - latitude, - &obs_context.get_veto_freqs(), - &*beam, - num_sources, - source_dist_cutoff.unwrap_or(DEFAULT_CUTOFF_DISTANCE), - veto_threshold.unwrap_or(DEFAULT_VETO_THRESHOLD), - )?; - if source_list.is_empty() { - return Err(DiCalArgsError::NoSourcesAfterVeto); - } - - messages::SkyModelDetails { - source_list: &source_list, - } - .print(); - - messages::print_modeller_info(&modeller_info); - - // Handle output visibility arguments. - let (output_model_time_average_factor, output_model_freq_average_factor) = match model_files - .as_ref() - { - None => { - // If we're not writing out model visibilities but arguments - // are set for them, issue warnings. - match (output_model_time_average, output_model_freq_average) { - (None, None) => (), - (time, freq) => { - warn!("Not writing out model visibilities, but"); - if time.is_some() { - warn!(" output_model_time_average is set"); - } - if freq.is_some() { - warn!(" output_model_freq_average is set"); - } - } - } - // We're not writing a file; it doesn't matter what these values - // are. - (1, 1) - } - Some(ms) => { - // Parse and verify user input (specified resolutions must - // evenly divide the input data's resolutions). - let time_factor = parse_time_average_factor( - obs_context.time_res, - output_model_time_average.as_deref(), - 1, - ) - .map_err(|e| match e { - AverageFactorError::Zero => DiCalArgsError::OutputVisTimeAverageFactorZero, - AverageFactorError::NotInteger => DiCalArgsError::OutputVisTimeFactorNotInteger, - AverageFactorError::NotIntegerMultiple { out, inp } => { - DiCalArgsError::OutputVisTimeResNotMultiple { out, inp } - } - AverageFactorError::Parse(e) => { - DiCalArgsError::ParseOutputVisTimeAverageFactor(e) - } - })?; - let freq_factor = parse_freq_average_factor( - obs_context.freq_res.map(|f| f * freq_average_factor as f64), - output_model_freq_average.as_deref(), - 1, - ) - .map_err(|e| match e { - AverageFactorError::Zero => DiCalArgsError::OutputVisFreqAverageFactorZero, - AverageFactorError::NotInteger => DiCalArgsError::OutputVisFreqFactorNotInteger, - AverageFactorError::NotIntegerMultiple { out, inp } => { - DiCalArgsError::OutputVisFreqResNotMultiple { out, inp } - } - AverageFactorError::Parse(e) => { - DiCalArgsError::ParseOutputVisFreqAverageFactor(e) - } - })?; - - // Test that we can write to the output files. 
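// A minimal stand-in for the can_write_to_file checks below, which exist to
// fail early -- before any expensive work -- if an output path isn't
// writable. Hypothetical helper, not hyperdrive's; note it leaves an empty
// file behind if the path didn't already exist.
fn sketch_can_write(path: &std::path::Path) -> std::io::Result<()> {
    // Append mode creates the file if missing but never truncates existing
    // contents, which is enough to prove the location is writable.
    std::fs::OpenOptions::new().append(true).create(true).open(path)?;
    Ok(())
}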
- for m in ms { - can_write_to_file(&m.0)?; - } - - (time_factor, freq_factor) - } - }; - { - messages::OutputFileDetails { - output_solutions: &output_solutions_filenames - .iter() - .map(|p| p.1.clone()) - .collect::>(), - vis_type: "model", - output_vis: model_files.as_ref(), - input_vis_time_res: obs_context.time_res, - input_vis_freq_res: obs_context.freq_res, - output_vis_time_average_factor: output_model_time_average_factor, - output_vis_freq_average_factor: output_model_freq_average_factor, - } - .print(); - } - - Ok(DiCalParams { - input_data, - raw_data_corrections, - beam, - source_list, - uvw_min: uvw_min_metres, - uvw_max: uvw_max_metres, - freq_centroid, - baseline_weights, - timeblocks, - timesteps: timesteps_to_use, - freq_average_factor, - fences, - unflagged_fine_chan_freqs, - flagged_fine_chans, - tile_baseline_flags: tile_index_maps, - unflagged_tile_xyzs, - array_position, - dut1: dut1.unwrap_or_else(|| Duration::from_seconds(0.0)), - apply_precession: !no_precession, - max_iterations, - stop_threshold, - min_threshold, - output_solutions_filenames, - model_files, - output_model_time_average_factor, - output_model_freq_average_factor, - no_progress_bars, - modeller_info, - }) - } - - /// Get read-only access to the [ObsContext]. This reflects the state of the - /// observation in the data. - pub(crate) fn get_obs_context(&self) -> &ObsContext { - self.input_data.get_obs_context() - } - - /// Get the total number of tiles in the observation, i.e. flagged and - /// unflagged. - pub(crate) fn get_total_num_tiles(&self) -> usize { - self.get_obs_context().get_total_num_tiles() - } - - /// Get the number of unflagged tiles to be used (may not match what is in - /// the observation data). - pub(crate) fn get_num_unflagged_tiles(&self) -> usize { - self.get_total_num_tiles() - self.tile_baseline_flags.flagged_tiles.len() - } - - /// The number of calibration timesteps. - pub(crate) fn get_num_timesteps(&self) -> usize { - self.timeblocks - .iter() - .fold(0, |acc, tb| acc + tb.range.len()) - } - - /// The number of unflagged cross-correlation baselines. - pub(crate) fn get_num_unflagged_cross_baselines(&self) -> usize { - let n = self.unflagged_tile_xyzs.len(); - (n * (n - 1)) / 2 - } - - pub(crate) fn read_crosses( - &self, - vis: ArrayViewMut2>, - weights: ArrayViewMut2, - timestep: usize, - ) -> Result<(), VisReadError> { - self.input_data.read_crosses( - vis, - weights, - timestep, - &self.tile_baseline_flags, - &self.flagged_fine_chans, - ) - } - - /// Use the [`DiCalParams`] to perform calibration and obtain solutions. - pub(crate) fn calibrate(&self) -> Result { - // TODO: Fix. - if self.freq_average_factor > 1 { - panic!("Frequency averaging isn't working right now. Sorry!"); - } - - let CalVis { - vis_data_tfb, - vis_weights_tfb, - vis_model_tfb, - pols, - } = get_cal_vis(self, !self.no_progress_bars)?; - assert_eq!(vis_weights_tfb.len_of(Axis(2)), self.baseline_weights.len()); - - // The shape of the array containing output Jones matrices. - let num_timeblocks = self.timeblocks.len(); - let num_chanblocks = self.fences.first().chanblocks.len(); - let num_unflagged_tiles = self.get_num_unflagged_tiles(); - - if log_enabled!(Debug) { - let shape = (num_timeblocks, num_unflagged_tiles, num_chanblocks); - debug!( - "Shape of DI Jones matrices array: ({} timeblocks, {} tiles, {} chanblocks; {} MiB)", - shape.0, - shape.1, - shape.2, - shape.0 * shape.1 * shape.2 * std::mem::size_of::>() - // 1024 * 1024 == 1 MiB. 
- / 1024 / 1024 - ); - } - - let (sols, results) = calibrate_timeblocks( - vis_data_tfb.view(), - vis_model_tfb.view(), - &self.timeblocks, - // TODO: Picket fences. - &self.fences.first().chanblocks, - self.max_iterations, - self.stop_threshold, - self.min_threshold, - pols, - !self.no_progress_bars, - true, - ); - - // "Complete" the solutions. - let sols = sols.into_cal_sols(self, Some(results.map(|r| r.max_precision))); - - Ok(sols) - } -} diff --git a/src/cli/di_calibrate/tests.rs b/src/cli/di_calibrate/tests.rs index 8d709cf5..4522a44b 100644 --- a/src/cli/di_calibrate/tests.rs +++ b/src/cli/di_calibrate/tests.rs @@ -4,65 +4,108 @@ //! Tests against calibration parameters and converting arguments to parameters. -use approx::assert_abs_diff_eq; +use std::{collections::HashSet, fs::File, io::Write, path::PathBuf}; + +use approx::{assert_abs_diff_eq, assert_abs_diff_ne}; +use clap::Parser; use marlu::{ constants::{MWA_HEIGHT_M, MWA_LAT_DEG, MWA_LONG_DEG}, - LatLngHeight, + Jones, LatLngHeight, }; +use ndarray::prelude::*; +use serial_test::serial; +use tempfile::{tempdir, TempDir}; -use super::DiCalArgsError::{ - BadArrayPosition, BadDelays, CalFreqFactorNotInteger, CalFreqResNotMultiple, - CalTimeFactorNotInteger, CalTimeResNotMultiple, CalibrationOutputFile, InvalidDataInput, - MultipleMeasurementSets, MultipleMetafits, MultipleUvfits, NoInputData, -}; +use super::{DiCalArgs, DiCalCliArgs}; use crate::{ - beam::BeamType, - tests::reduced_obsids::{ - get_reduced_1090008640, get_reduced_1090008640_ms, get_reduced_1090008640_uvfits, + cli::{ + common::{BeamArgs, InputVisArgs, SkyModelWithVetoArgs}, + vis_simulate::VisSimulateArgs, + }, + io::read::{ + fits::{fits_get_col, fits_get_required_key, fits_open, fits_open_hdu}, + MsReader, VisRead, }, + math::TileBaselineFlags, + params::CalVis, + tests::{ + get_reduced_1090008640_ms, get_reduced_1090008640_raw, get_reduced_1090008640_uvfits, + DataAsStrings, + }, + CalibrationSolutions, HyperdriveError, }; +fn get_reduced_1090008640(use_fee_beam: bool, include_mwaf: bool) -> DiCalArgs { + let DataAsStrings { + metafits, + mut vis, + mut mwafs, + srclist, + } = get_reduced_1090008640_raw(); + let mut files = vec![metafits]; + files.append(&mut vis); + if include_mwaf { + files.append(&mut mwafs); + } + + DiCalArgs { + args_file: None, + data_args: InputVisArgs { + files: Some(files), + ..Default::default() + }, + srclist_args: SkyModelWithVetoArgs { + source_list: Some(srclist), + ..Default::default() + }, + beam_args: BeamArgs { + no_beam: !use_fee_beam, + ..Default::default() + }, + ..Default::default() + } +} + #[test] fn test_new_params_defaults() { let args = get_reduced_1090008640(false, true); - let params = args.into_params().unwrap(); - let obs_context = params.get_obs_context(); + let params = args.parse().unwrap(); + let input_vis_params = ¶ms.input_vis_params; + let obs_context = input_vis_params.get_obs_context(); + let total_num_tiles = obs_context.get_total_num_tiles(); + let num_unflagged_tiles = + total_num_tiles - input_vis_params.tile_baseline_flags.flagged_tiles.len(); // The default time resolution should be 2.0s, as per the metafits. assert_abs_diff_eq!(obs_context.time_res.unwrap().to_seconds(), 2.0); // The default freq resolution should be 40kHz, as per the metafits. assert_abs_diff_eq!(obs_context.freq_res.unwrap(), 40e3); // No tiles are flagged in the input data, and no additional flags were // supplied. 
- assert_eq!( - obs_context.get_total_num_tiles(), - obs_context.get_num_unflagged_tiles() - ); - assert_eq!(params.tile_baseline_flags.flagged_tiles.len(), 0); + assert_eq!(total_num_tiles, num_unflagged_tiles); + assert_eq!(input_vis_params.tile_baseline_flags.flagged_tiles.len(), 0); // By default there are 5 flagged channels per coarse channel. We only have - // one coarse channel here so we expect 27/32 channels. Also no picket fence - // shenanigans. - assert_eq!(params.fences.len(), 1); - assert_eq!(params.fences[0].chanblocks.len(), 27); + // one coarse channel here so we expect 27/32 channels. + assert_eq!(input_vis_params.spw.chanblocks.len(), 27); } #[test] fn test_new_params_no_input_flags() { let mut args = get_reduced_1090008640(false, true); - args.ignore_input_data_tile_flags = true; - args.ignore_input_data_fine_channel_flags = true; - let params = args.into_params().unwrap(); - let obs_context = params.get_obs_context(); + args.data_args.ignore_input_data_tile_flags = true; + args.data_args.ignore_input_data_fine_channel_flags = true; + let params = args.parse().unwrap(); + let input_vis_params = ¶ms.input_vis_params; + let obs_context = input_vis_params.get_obs_context(); + let total_num_tiles = obs_context.get_total_num_tiles(); + let num_unflagged_tiles = + total_num_tiles - input_vis_params.tile_baseline_flags.flagged_tiles.len(); assert_abs_diff_eq!(obs_context.time_res.unwrap().to_seconds(), 2.0); assert_abs_diff_eq!(obs_context.freq_res.unwrap(), 40e3); - assert_eq!( - obs_context.get_total_num_tiles(), - obs_context.get_num_unflagged_tiles(), - ); - assert_eq!(params.tile_baseline_flags.flagged_tiles.len(), 0); + assert_eq!(total_num_tiles, num_unflagged_tiles); + assert_eq!(input_vis_params.tile_baseline_flags.flagged_tiles.len(), 0); - assert_eq!(params.fences.len(), 1); - assert_eq!(params.fences[0].chanblocks.len(), 32); + assert_eq!(input_vis_params.spw.chanblocks.len(), 32); } #[test] @@ -70,26 +113,26 @@ fn test_new_params_time_averaging() { // The native time resolution is 2.0s. let mut args = get_reduced_1090008640(false, true); // 1 is a valid time average factor. - args.timesteps_per_timeblock = Some("1".to_string()); - let result = args.into_params(); + args.calibration_args.timesteps_per_timeblock = Some("1".to_string()); + let result = args.parse(); assert!(result.is_ok()); let mut args = get_reduced_1090008640(false, true); // 2 is a valid time average factor. - args.timesteps_per_timeblock = Some("2".to_string()); - let result = args.into_params(); + args.calibration_args.timesteps_per_timeblock = Some("2".to_string()); + let result = args.parse(); assert!(result.is_ok()); let mut args = get_reduced_1090008640(false, true); // 4.0s should be a multiple of 2.0s - args.timesteps_per_timeblock = Some("4.0s".to_string()); - let result = args.into_params(); + args.calibration_args.timesteps_per_timeblock = Some("4.0s".to_string()); + let result = args.parse(); assert!(result.is_ok()); let mut args = get_reduced_1090008640(false, true); // 8.0s should be a multiple of 2.0s - args.timesteps_per_timeblock = Some("8.0s".to_string()); - let result = args.into_params(); + args.calibration_args.timesteps_per_timeblock = Some("8.0s".to_string()); + let result = args.parse(); assert!(result.is_ok()); } @@ -98,24 +141,32 @@ fn test_new_params_time_averaging_fail() { // The native time resolution is 2.0s. let mut args = get_reduced_1090008640(false, true); // 1.5 is an invalid time average factor. 
- args.timesteps_per_timeblock = Some("1.5".to_string()); - let result = args.into_params(); + args.calibration_args.timesteps_per_timeblock = Some("1.5".to_string()); + let result = args.parse(); assert!(result.is_err()); - assert!(matches!(result, Err(CalTimeFactorNotInteger))); + assert!(result + .err() + .unwrap() + .to_string() + .contains("Calibration time average factor isn't an integer")); let mut args = get_reduced_1090008640(false, true); // 2.01s is not a multiple of 2.0s - args.timesteps_per_timeblock = Some("2.01s".to_string()); - let result = args.into_params(); + args.calibration_args.timesteps_per_timeblock = Some("2.01s".to_string()); + let result = args.parse(); assert!(result.is_err()); - assert!(matches!(result, Err(CalTimeResNotMultiple { .. }))); + assert!(result.err().unwrap().to_string().contains( + "Calibration time resolution isn't a multiple of input data's: 2.01 seconds vs 2 seconds" + )); let mut args = get_reduced_1090008640(false, true); // 3.0s is not a multiple of 2.0s - args.timesteps_per_timeblock = Some("3.0s".to_string()); - let result = args.into_params(); + args.calibration_args.timesteps_per_timeblock = Some("3.0s".to_string()); + let result = args.parse(); assert!(result.is_err()); - assert!(matches!(result, Err(CalTimeResNotMultiple { .. }))); + assert!(result.err().unwrap().to_string().contains( + "Calibration time resolution isn't a multiple of input data's: 3 seconds vs 2 seconds" + )); } #[test] @@ -123,90 +174,57 @@ fn test_new_params_freq_averaging() { // The native freq. resolution is 40kHz. let mut args = get_reduced_1090008640(false, true); // 3 is a valid freq average factor. - args.freq_average_factor = Some("3".to_string()); - let result = args.into_params(); + args.data_args.freq_average = Some("3".to_string()); + let result = args.parse(); assert!(result.is_ok()); let mut args = get_reduced_1090008640(false, true); // 80kHz should be a multiple of 40kHz - args.freq_average_factor = Some("80kHz".to_string()); - let result = args.into_params(); + args.data_args.freq_average = Some("80kHz".to_string()); + let result = args.parse(); assert!(result.is_ok()); let mut args = get_reduced_1090008640(false, true); // 200kHz should be a multiple of 40kHz - args.freq_average_factor = Some("200kHz".to_string()); - let result = args.into_params(); + args.data_args.freq_average = Some("200kHz".to_string()); + let result = args.parse(); assert!(result.is_ok()); } -#[test] -fn test_new_params_freq_averaging_fail() { - // The native freq. resolution is 40kHz. - let mut args = get_reduced_1090008640(false, true); - // 1.5 is an invalid freq average factor. - args.freq_average_factor = Some("1.5".to_string()); - let result = args.into_params(); - assert!(result.is_err()); - assert!(matches!(result, Err(CalFreqFactorNotInteger))); - - let mut args = get_reduced_1090008640(false, true); - // 10kHz is not a multiple of 40kHz - args.freq_average_factor = Some("10kHz".to_string()); - let result = args.into_params(); - assert!(result.is_err()); - assert!(matches!(result, Err(CalFreqResNotMultiple { .. }))); - - let mut args = get_reduced_1090008640(false, true); - // 79kHz is not a multiple of 40kHz - args.freq_average_factor = Some("79kHz".to_string()); - let result = args.into_params(); - assert!(result.is_err()); - assert!(matches!(result, Err(CalFreqResNotMultiple { .. }))); -} - #[test] fn test_new_params_tile_flags() { // 1090008640 has no flagged tiles in its metafits. 
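// The 7750 asserted below is the number of cross-correlation baselines
// between unflagged tiles: this observation has 128 tiles, 3 are flagged
// here, and 125 * 124 / 2 = 7750. A tiny illustrative helper (not part of
// hyperdrive):
fn sketch_num_cross_baselines(total_tiles: usize, num_flagged: usize) -> usize {
    let n = total_tiles - num_flagged;
    n * (n - 1) / 2
}
// sketch_num_cross_baselines(128, 3) == 7750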
let mut args = get_reduced_1090008640(false, true); // Manually flag antennas 1, 2 and 3. - args.tile_flags = Some(vec!["1".to_string(), "2".to_string(), "3".to_string()]); - let params = match args.into_params() { - Ok(p) => p, - Err(e) => panic!("{}", e), - }; - assert_eq!(params.tile_baseline_flags.flagged_tiles.len(), 3); - assert!(params.tile_baseline_flags.flagged_tiles.contains(&1)); - assert!(params.tile_baseline_flags.flagged_tiles.contains(&2)); - assert!(params.tile_baseline_flags.flagged_tiles.contains(&3)); + args.data_args.tile_flags = Some(vec!["1".to_string(), "2".to_string(), "3".to_string()]); + let params = args.parse().unwrap(); + let input_vis_params = ¶ms.input_vis_params; + let tile_baseline_flags = &input_vis_params.tile_baseline_flags; + assert_eq!(tile_baseline_flags.flagged_tiles.len(), 3); + assert!(tile_baseline_flags.flagged_tiles.contains(&1)); + assert!(tile_baseline_flags.flagged_tiles.contains(&2)); + assert!(tile_baseline_flags.flagged_tiles.contains(&3)); assert_eq!( - params - .tile_baseline_flags + tile_baseline_flags .tile_to_unflagged_cross_baseline_map .len(), 7750 ); assert_eq!( - params - .tile_baseline_flags - .tile_to_unflagged_cross_baseline_map[&(0, 4)], + tile_baseline_flags.tile_to_unflagged_cross_baseline_map[&(0, 4)], 0 ); assert_eq!( - params - .tile_baseline_flags - .tile_to_unflagged_cross_baseline_map[&(0, 5)], + tile_baseline_flags.tile_to_unflagged_cross_baseline_map[&(0, 5)], 1 ); assert_eq!( - params - .tile_baseline_flags - .tile_to_unflagged_cross_baseline_map[&(0, 6)], + tile_baseline_flags.tile_to_unflagged_cross_baseline_map[&(0, 6)], 2 ); assert_eq!( - params + input_vis_params .tile_baseline_flags .tile_to_unflagged_cross_baseline_map[&(0, 7)], 3 @@ -214,187 +232,804 @@ fn test_new_params_tile_flags() { } #[test] -fn test_handle_delays() { +fn test_handle_invalid_output() { let mut args = get_reduced_1090008640(false, true); - args.no_beam = false; - // only 3 delays instead of 16 expected - args.delays = Some((0..3).collect::>()); - let result = args.clone().into_params(); + args.calibration_args.solutions = Some(vec!["invalid.out".into()]); + let result = args.parse(); assert!(result.is_err()); - assert!(matches!(result, Err(BadDelays))); - - // delays > 32 - args.delays = Some((20..36).collect::>()); - let result = args.clone().into_params(); + assert!(result + .err() + .unwrap() + .to_string() + .contains("Cannot write calibration solutions to a file type 'out'")); +} - assert!(result.is_err()); - assert!(matches!(result, Err(BadDelays))); - - let delays = (0..16).collect::>(); - args.delays = Some(delays.clone()); - let result = args.into_params(); - - assert!(result.is_ok(), "result={:?} not Ok", result.err().unwrap()); - - let fee_beam = result.unwrap().beam; - assert_eq!(fee_beam.get_beam_type(), BeamType::FEE); - let beam_delays = fee_beam - .get_dipole_delays() - .expect("expected some delays to be provided from the FEE beam!"); - // Each row of the delays should be the same as the 16 input values. 
- for row in beam_delays.outer_iter() { - assert_eq!(row.as_slice().unwrap(), delays); +#[track_caller] +fn test_args_with_arg_file(args: &DiCalArgs) { + let temp_dir = tempdir().expect("Couldn't make tempdir"); + for filename in ["calibrate.toml", "calibrate.json"] { + let arg_file = temp_dir.path().join(filename); + let mut f = File::create(&arg_file).expect("couldn't make file"); + let ser = match filename.split('.').last() { + Some("toml") => { + toml::to_string_pretty(&args).expect("couldn't serialise DiCalArgs as toml") + } + Some("json") => { + serde_json::to_string_pretty(&args).expect("couldn't serialise DiCalArgs as json") + } + _ => unreachable!(), + }; + eprintln!("{ser}"); + write!(&mut f, "{ser}").unwrap(); + // I don't know why, but the first argument ("di-calibrate" here) is + // necessary, and the result is the same for any string! + let parsed_args = DiCalArgs::parse_from(["di-calibrate", &arg_file.display().to_string()]) + .merge() + .unwrap(); + assert!(parsed_args.data_args.files.is_some()); + let result = parsed_args.run(true).expect("args happily ingested"); + // No solutions returned because we did a dry run. + assert!(result.is_none()); } } #[test] -fn test_unity_dipole_gains() { - let mut args = get_reduced_1090008640(false, true); - args.no_beam = false; - let params = args.clone().into_params().unwrap(); - - let fee_beam = params.beam; - assert_eq!(fee_beam.get_beam_type(), BeamType::FEE); - let beam_gains = fee_beam.get_dipole_gains(); - - // Because there are dead dipoles in the metafits, we expect some of the - // gains to not be 1. - assert!(!beam_gains.iter().all(|g| (*g - 1.0).abs() < f64::EPSILON)); - - // Now ignore dead dipoles. - args.unity_dipole_gains = true; - let params = args.into_params().unwrap(); - - let fee_beam = params.beam; - assert_eq!(fee_beam.get_beam_type(), BeamType::FEE); - let beam_gains = fee_beam.get_dipole_gains(); - - // Now we expect all gains to be 1s, as we're ignoring dead dipoles. - assert!(beam_gains.iter().all(|g| (*g - 1.0).abs() < f64::EPSILON)); - // Verify that there are no dead dipoles in the delays. 
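// Context for the assertion below: in MWA metafits data a dipole delay of 32
// marks a dead dipole (usable beamformer delays are 0..=31), so "no dead
// dipoles" is the same as no delay being 32. An illustrative check:
fn sketch_has_dead_dipole(delays: &[u32]) -> bool {
    delays.iter().any(|&d| d == 32)
}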
- assert!(fee_beam - .get_dipole_delays() - .unwrap() - .iter() - .all(|d| *d != 32)); +fn arg_file_absolute_paths() { + let args = get_reduced_1090008640(false, true); + test_args_with_arg_file(&args); } #[test] -fn test_handle_no_input() { +fn arg_file_absolute_globs() { let mut args = get_reduced_1090008640(false, true); - args.data = None; - let result = args.into_params(); - - assert!(result.is_err()); - assert!(matches!(result, Err(NoInputData))); + let first = PathBuf::from(&args.data_args.files.unwrap()[0]); + let parent = first.parent().unwrap(); + args.data_args.files = Some(vec![ + format!("{}/*.metafits", parent.display()), + format!("{}/*gpubox*", parent.display()), + format!("{}/*.mwaf", parent.display()), + ]); + test_args_with_arg_file(&args); } #[test] -fn test_handle_multiple_metafits() { - // when reading raw +fn arg_file_relative_globs() { let mut args = get_reduced_1090008640(false, true); - args.data - .as_mut() - .unwrap() - .push("test_files/1090008640_WODEN/1090008640.metafits".into()); - let result = args.into_params(); - - assert!(result.is_err()); - assert!(matches!(result, Err(MultipleMetafits(_)))); + args.data_args.files = Some(vec![ + "test_files/1090008640/*.metafits".to_string(), + "test_files/1090008640/*gpubox*".to_string(), + "test_files/1090008640/*.mwaf".to_string(), + ]); + args.srclist_args.source_list = Some("test_files/1090008640/*srclist*_100.yaml".to_string()); + test_args_with_arg_file(&args); +} - // when reading ms - let mut args = get_reduced_1090008640_ms(); - args.data - .as_mut() - .unwrap() - .push("test_files/1090008640_WODEN/1090008640.metafits".into()); - let result = args.into_params(); +#[test] +#[serial] +fn test_1090008640_di_calibrate_writes_solutions() { + let tmp_dir = TempDir::new().expect("couldn't make tmp dir"); + let DataAsStrings { + metafits, + vis, + srclist, + .. + } = get_reduced_1090008640_raw(); + let gpufits = &vis[0]; + let sols = tmp_dir.path().join("sols.fits"); + let cal_model = tmp_dir.path().join("hyp_model.uvfits"); - assert!(result.is_err()); - assert!(matches!(result, Err(MultipleMetafits(_)))); + #[rustfmt::skip] + let cal_args = DiCalArgs::parse_from([ + "di-calibrate", + "--data", &metafits, gpufits, + "--source-list", &srclist, + "--outputs", &format!("{}", sols.display()), + "--model-filenames", &format!("{}", cal_model.display()), + ]); - // when reading uvfits - let mut args = get_reduced_1090008640_uvfits(); - args.data - .as_mut() - .unwrap() - .push("test_files/1090008640_WODEN/1090008640.metafits".into()); - let result = args.into_params(); + // Run di-cal and check that it succeeds + let result = cal_args.run(false); + assert!(result.is_ok(), "result={:?} not ok", result.err().unwrap()); - assert!(result.is_err()); - assert!(matches!(result, Err(MultipleMetafits(_)))); + // check solutions file has been created, is readable + assert!(sols.exists(), "sols file not written"); + let sol_data = CalibrationSolutions::read_solutions_from_ext(sols, metafits.into()).unwrap(); + assert_eq!(sol_data.obsid, Some(1090008640)); } #[test] -fn test_handle_multiple_ms() { - let mut args = get_reduced_1090008640_ms(); - args.data - .as_mut() - .unwrap() - .push("test_files/1090008640/1090008640.ms".into()); - let result = args.into_params(); +fn test_1090008640_di_calibrate_uses_array_position() { + let tmp_dir = TempDir::new().expect("couldn't make tmp dir"); + let DataAsStrings { + metafits, + vis, + srclist, + .. 
+ } = get_reduced_1090008640_raw(); + let gpufits = &vis[0]; + let sols = tmp_dir.path().join("sols.fits"); + let cal_model = tmp_dir.path().join("hyp_model.uvfits"); - assert!(result.is_err()); - assert!(matches!(result, Err(MultipleMeasurementSets(_)))); + // with non-default array position + let exp_lat_deg = MWA_LAT_DEG - 1.; + let exp_long_deg = MWA_LONG_DEG - 1.; + let exp_height_m = MWA_HEIGHT_M - 1.; + + #[rustfmt::skip] + let cal_args = DiCalArgs::parse_from([ + "di-calibrate", + "--data", &metafits, gpufits, + "--source-list", &srclist, + "--outputs", &format!("{}", sols.display()), + "--model-filenames", &format!("{}", cal_model.display()), + "--array-position", + &format!("{exp_long_deg}"), + &format!("{exp_lat_deg}"), + &format!("{exp_height_m}"), + ]); + + let pos = cal_args.data_args.array_position.unwrap(); + + assert_abs_diff_eq!(pos[0], exp_long_deg); + assert_abs_diff_eq!(pos[1], exp_lat_deg); + assert_abs_diff_eq!(pos[2], exp_height_m); } #[test] -fn test_handle_multiple_uvfits() { - let mut args = get_reduced_1090008640_uvfits(); - args.data - .as_mut() - .unwrap() - .push("test_files/1090008640/1090008640.uvfits".into()); - let result = args.into_params(); +fn test_1090008640_di_calibrate_array_pos_requires_3_args() { + let tmp_dir = TempDir::new().expect("couldn't make tmp dir"); + let DataAsStrings { + metafits, + vis, + srclist, + .. + } = get_reduced_1090008640_raw(); + let gpufits = &vis[0]; + let sols = tmp_dir.path().join("sols.fits"); + let cal_model = tmp_dir.path().join("hyp_model.uvfits"); + + // no height specified + let exp_lat_deg = MWA_LAT_DEG - 1.; + let exp_long_deg = MWA_LONG_DEG - 1.; + + #[rustfmt::skip] + let result = DiCalArgs::try_parse_from([ + "di-calibrate", + "--data", &metafits, gpufits, + "--source-list", &srclist, + "--outputs", &format!("{}", sols.display()), + "--model-filenames", &format!("{}", cal_model.display()), + "--array-position", + &format!("{exp_long_deg}"), + &format!("{exp_lat_deg}"), + ]); assert!(result.is_err()); - assert!(matches!(result, Err(MultipleUvfits(_)))); + assert!(matches!( + result.err().unwrap().kind(), + clap::ErrorKind::WrongNumberOfValues + )); } #[test] -fn test_handle_only_metafits() { - let mut args = get_reduced_1090008640(false, true); - args.data = Some(vec!["test_files/1090008640/1090008640.metafits".into()]); - let result = args.into_params(); +/// Generate a model with "vis-simulate" (in uvfits), then feed it to +/// "di-calibrate" and write out the model used for calibration (as uvfits). The +/// visibilities should be exactly the same. +fn test_1090008640_calibrate_model_uvfits() { + let num_timesteps = 2; + let num_chans = 10; - assert!(result.is_err()); - assert!(matches!(result, Err(InvalidDataInput(_)))); + let temp_dir = TempDir::new().expect("couldn't make tmp dir"); + let model = temp_dir.path().join("model.uvfits"); + let DataAsStrings { + metafits, srclist, .. + } = get_reduced_1090008640_raw(); + #[rustfmt::skip] + let sim_args = VisSimulateArgs::parse_from([ + "vis-simulate", + "--metafits", &metafits, + "--source-list", &srclist, + "--output-model-files", &format!("{}", model.display()), + "--num-timesteps", &format!("{num_timesteps}"), + "--num-fine-channels", &format!("{num_chans}"), + "--veto-threshold", "0.0", // Don't complicate things with vetoing + // The array position is needed because, if not specified, it's read + // slightly different out of the uvfits. 
+ "--array-position", "116.67081523611111", "-26.703319405555554", "377.827", + ]); + + // Run vis-simulate and check that it succeeds + let result = sim_args.run(false); + assert!(result.is_ok(), "result={:?} not ok", result.err().unwrap()); + + let sols = temp_dir.path().join("sols.fits"); + let cal_model = temp_dir.path().join("cal_model.uvfits"); + + #[rustfmt::skip] + let cal_args = DiCalArgs::parse_from([ + "di-calibrate", + "--data", &format!("{}", model.display()), &metafits, + "--source-list", &srclist, + "--outputs", &format!("{}", sols.display()), + "--model-filenames", &format!("{}", cal_model.display()), + "--veto-threshold", "0.0", // Don't complicate things with vetoing + "--array-position", "116.67081523611111", "-26.703319405555554", "377.827", + ]); + + // Run di-cal and check that it succeeds + let result = cal_args.parse().unwrap().run(); + assert!(result.is_ok(), "result={:?} not ok", result.err().unwrap()); + let sols = result.unwrap(); + + let mut uvfits_m = fits_open(&model).unwrap(); + let hdu_m = fits_open_hdu(&mut uvfits_m, 0).unwrap(); + let gcount_m: String = fits_get_required_key(&mut uvfits_m, &hdu_m, "GCOUNT").unwrap(); + let pcount_m: String = fits_get_required_key(&mut uvfits_m, &hdu_m, "PCOUNT").unwrap(); + let floats_per_pol_m: String = fits_get_required_key(&mut uvfits_m, &hdu_m, "NAXIS2").unwrap(); + let num_pols_m: String = fits_get_required_key(&mut uvfits_m, &hdu_m, "NAXIS3").unwrap(); + let num_fine_freq_chans_m: String = + fits_get_required_key(&mut uvfits_m, &hdu_m, "NAXIS4").unwrap(); + let jd_zero_m: String = fits_get_required_key(&mut uvfits_m, &hdu_m, "PZERO5").unwrap(); + let ptype4_m: String = fits_get_required_key(&mut uvfits_m, &hdu_m, "PTYPE4").unwrap(); + + let mut uvfits_c = fits_open(&cal_model).unwrap(); + let hdu_c = fits_open_hdu(&mut uvfits_c, 0).unwrap(); + let gcount_c: String = fits_get_required_key(&mut uvfits_c, &hdu_c, "GCOUNT").unwrap(); + let pcount_c: String = fits_get_required_key(&mut uvfits_c, &hdu_c, "PCOUNT").unwrap(); + let floats_per_pol_c: String = fits_get_required_key(&mut uvfits_c, &hdu_c, "NAXIS2").unwrap(); + let num_pols_c: String = fits_get_required_key(&mut uvfits_c, &hdu_c, "NAXIS3").unwrap(); + let num_fine_freq_chans_c: String = + fits_get_required_key(&mut uvfits_c, &hdu_c, "NAXIS4").unwrap(); + let jd_zero_c: String = fits_get_required_key(&mut uvfits_c, &hdu_c, "PZERO5").unwrap(); + let ptype4_c: String = fits_get_required_key(&mut uvfits_c, &hdu_c, "PTYPE4").unwrap(); + + assert_eq!(gcount_m, gcount_c); + assert_eq!(pcount_m, pcount_c); + assert_eq!(floats_per_pol_m, floats_per_pol_c); + assert_eq!(num_pols_m, num_pols_c); + assert_eq!(num_fine_freq_chans_m, num_fine_freq_chans_c); + assert_eq!(jd_zero_m, jd_zero_c); + assert_eq!(ptype4_m, ptype4_c); + + let hdu_m = fits_open_hdu(&mut uvfits_m, 1).unwrap(); + let tile_names_m: Vec = fits_get_col(&mut uvfits_m, &hdu_m, "ANNAME").unwrap(); + let hdu_c = fits_open_hdu(&mut uvfits_c, 1).unwrap(); + let tile_names_c: Vec = fits_get_col(&mut uvfits_c, &hdu_c, "ANNAME").unwrap(); + for (tile_m, tile_c) in tile_names_m.into_iter().zip(tile_names_c.into_iter()) { + assert_eq!(tile_m, tile_c); + } + + // Test visibility values. 
+ fits_open_hdu(&mut uvfits_m, 0).unwrap(); + let mut group_params_m = Array1::zeros(5); + let mut vis_m = Array1::zeros(10 * 4 * 3); + fits_open_hdu(&mut uvfits_c, 0).unwrap(); + let mut group_params_c = group_params_m.clone(); + let mut vis_c = vis_m.clone(); + + let mut status = 0; + for i_row in 0..gcount_m.parse::().unwrap() { + unsafe { + // ffggpe = fits_read_grppar_flt + fitsio_sys::ffggpe( + uvfits_m.as_raw(), /* I - FITS file pointer */ + 1 + i_row, /* I - group to read (1 = 1st group) */ + 1, /* I - first vector element to read (1 = 1st) */ + group_params_m.len() as i64, /* I - number of values to read */ + group_params_m.as_mut_ptr(), /* O - array of values that are returned */ + &mut status, /* IO - error status */ + ); + assert_eq!(status, 0, "Status wasn't 0"); + assert_abs_diff_ne!(group_params_m, group_params_c); + // ffggpe = fits_read_grppar_flt + fitsio_sys::ffggpe( + uvfits_c.as_raw(), /* I - FITS file pointer */ + 1 + i_row, /* I - group to read (1 = 1st group) */ + 1, /* I - first vector element to read (1 = 1st) */ + group_params_c.len() as i64, /* I - number of values to read */ + group_params_c.as_mut_ptr(), /* O - array of values that are returned */ + &mut status, /* IO - error status */ + ); + assert_eq!(status, 0, "Status wasn't 0"); + assert_abs_diff_eq!(group_params_m, group_params_c); + + // ffgpve = fits_read_img_flt + fitsio_sys::ffgpve( + uvfits_m.as_raw(), /* I - FITS file pointer */ + 1 + i_row, /* I - group to read (1 = 1st group) */ + 1, /* I - first vector element to read (1 = 1st) */ + vis_m.len() as i64, /* I - number of values to read */ + 0.0, /* I - value for undefined pixels */ + vis_m.as_mut_ptr(), /* O - array of values that are returned */ + &mut 0, /* O - set to 1 if any values are null; else 0 */ + &mut status, /* IO - error status */ + ); + assert_abs_diff_ne!(vis_m, vis_c); + // ffgpve = fits_read_img_flt + fitsio_sys::ffgpve( + uvfits_c.as_raw(), /* I - FITS file pointer */ + 1 + i_row, /* I - group to read (1 = 1st group) */ + 1, /* I - first vector element to read (1 = 1st) */ + vis_c.len() as i64, /* I - number of values to read */ + 0.0, /* I - value for undefined pixels */ + vis_c.as_mut_ptr(), /* O - array of values that are returned */ + &mut 0, /* O - set to 1 if any values are null; else 0 */ + &mut status, /* IO - error status */ + ); + assert_eq!(status, 0, "Status wasn't 0"); + assert_abs_diff_eq!(vis_m, vis_c); + }; + } + + // Inspect the solutions; they should all be close to identity. + assert_abs_diff_eq!( + sols.di_jones, + Array3::from_elem(sols.di_jones.dim(), Jones::identity()), + epsilon = 1e-15 + ); } #[test] -fn test_handle_invalid_output() { - let mut args = get_reduced_1090008640(false, true); - args.outputs = Some(vec!["invalid.out".into()]); - let result = args.into_params(); +#[serial] +/// Generate a model with "vis-simulate" (in a measurement set), then feed it to +/// "di-calibrate" and write out the model used for calibration (into a +/// measurement set). The visibilities should be exactly the same. +fn test_1090008640_calibrate_model_ms() { + let num_timesteps = 2; + let num_chans = 10; - assert!(result.is_err()); - assert!(matches!(result, Err(CalibrationOutputFile { .. 
}))); + let temp_dir = TempDir::new().expect("couldn't make tmp dir"); + let model = temp_dir.path().join("model.ms"); + let DataAsStrings { + metafits, + vis: _, + mwafs: _, + srclist, + } = get_reduced_1090008640_raw(); + + // Non-default array position + let lat_deg = MWA_LAT_DEG - 1.; + let long_deg = MWA_LONG_DEG - 1.; + let height_m = MWA_HEIGHT_M - 1.; + + #[rustfmt::skip] + let sim_args = VisSimulateArgs::parse_from([ + "vis-simulate", + "--metafits", &metafits, + "--source-list", &srclist, + "--output-model-files", &format!("{}", model.display()), + "--num-timesteps", &format!("{num_timesteps}"), + "--num-fine-channels", &format!("{num_chans}"), + "--array-position", + &format!("{long_deg}"), + &format!("{lat_deg}"), + &format!("{height_m}"), + ]); + + // Run vis-simulate and check that it succeeds + let result = sim_args.run(false); + assert!(result.is_ok(), "result={:?} not ok", result.err().unwrap()); + + let sols = temp_dir.path().join("sols.fits"); + let cal_model = temp_dir.path().join("cal_model.ms"); + #[rustfmt::skip] + let cal_args = DiCalArgs::parse_from([ + "di-calibrate", + "--data", &format!("{}", model.display()), &metafits, + "--source-list", &srclist, + "--outputs", &format!("{}", sols.display()), + "--model-filenames", &format!("{}", cal_model.display()), + "--array-position", + &format!("{long_deg}"), + &format!("{lat_deg}"), + &format!("{height_m}"), + ]); + + // Run di-cal and check that it succeeds + let result = cal_args.parse().unwrap().run(); + assert!(result.is_ok(), "result={:?} not ok", result.err().unwrap()); + let sols = result.unwrap(); + + let array_pos = LatLngHeight::mwa(); + let ms_m = MsReader::new( + model, + None, + Some(&PathBuf::from(&metafits)), + Some(array_pos), + ) + .unwrap(); + let ctx_m = ms_m.get_obs_context(); + let ms_c = MsReader::new( + cal_model, + None, + Some(&PathBuf::from(metafits)), + Some(array_pos), + ) + .unwrap(); + let ctx_c = ms_c.get_obs_context(); + assert_eq!(ctx_m.all_timesteps, ctx_c.all_timesteps); + assert_eq!(ctx_m.all_timesteps.len(), num_timesteps); + assert_eq!(ctx_m.timestamps, ctx_c.timestamps); + assert_eq!(ctx_m.fine_chan_freqs, ctx_c.fine_chan_freqs); + let m_flags = ctx_m.flagged_tiles.iter().copied().collect(); + let c_flags = &ctx_c.flagged_tiles; + for m in &m_flags { + assert!(c_flags.contains(m)); + } + assert_eq!(ctx_m.tile_xyzs, ctx_c.tile_xyzs); + assert_eq!(ctx_m.flagged_fine_chans, ctx_c.flagged_fine_chans); + + let flagged_fine_chans_set: HashSet = ctx_m.flagged_fine_chans.iter().copied().collect(); + let tile_baseline_flags = TileBaselineFlags::new(ctx_m.tile_xyzs.len(), m_flags); + let max_baseline_idx = tile_baseline_flags + .tile_to_unflagged_cross_baseline_map + .values() + .max() + .unwrap(); + let data_shape = ( + ctx_m.fine_chan_freqs.len() - ctx_m.flagged_fine_chans.len(), + max_baseline_idx + 1, + ); + let mut vis_m = Array2::>::zeros(data_shape); + let mut vis_c = Array2::>::zeros(data_shape); + let mut weight_m = Array2::::zeros(data_shape); + let mut weight_c = Array2::::zeros(data_shape); + + for ×tep in &ctx_m.all_timesteps { + ms_m.read_crosses( + vis_m.view_mut(), + weight_m.view_mut(), + timestep, + &tile_baseline_flags, + &flagged_fine_chans_set, + ) + .unwrap(); + ms_c.read_crosses( + vis_c.view_mut(), + weight_c.view_mut(), + timestep, + &tile_baseline_flags, + &flagged_fine_chans_set, + ) + .unwrap(); + + // Unlike the equivalent uvfits test, we have to use an epsilon here. 
+ // This is due to the MS antenna positions being in geocentric + // coordinates and not geodetic like uvfits; in the process of + // converting from geocentric to geodetic, small float errors are + // introduced. If a metafits' positions are used instead, the results + // are *exactly* the same, but we should trust the MS's positions, so + // these errors must remain. + #[cfg(feature = "cuda-single")] + assert_abs_diff_eq!(vis_m, vis_c, epsilon = 2e-4); + #[cfg(not(feature = "cuda-single"))] + assert_abs_diff_eq!(vis_m, vis_c, epsilon = 4e-6); + assert_abs_diff_eq!(weight_m, weight_c); + } + + // Inspect the solutions; they should all be close to identity. + #[cfg(feature = "cuda-single")] + let epsilon = 6e-8; + #[cfg(not(feature = "cuda-single"))] + let epsilon = 2e-9; + assert_abs_diff_eq!( + sols.di_jones, + Array3::from_elem(sols.di_jones.dim(), Jones::identity()), + epsilon = epsilon + ); } #[test] -fn test_handle_array_pos() { - let mut args = get_reduced_1090008640(false, true); - let expected = vec![MWA_LONG_DEG + 1.0, MWA_LAT_DEG + 1.0, MWA_HEIGHT_M + 1.0]; - args.array_position = Some(expected.clone()); - let result = args.into_params().unwrap(); +/// Generate a model with "vis-simulate" (in uvfits), then feed it to +/// "di-calibrate", testing the solution timeblocks that come out. +fn test_cal_timeblocks() { + let num_timesteps = 3; + let num_chans = 5; + + let temp_dir = TempDir::new().expect("couldn't make tmp dir"); + let model = temp_dir.path().join("model.uvfits"); + let DataAsStrings { + metafits, srclist, .. + } = get_reduced_1090008640_raw(); + #[rustfmt::skip] + let sim_args = VisSimulateArgs::parse_from([ + "vis-simulate", + "--metafits", &metafits, + "--source-list", &srclist, + "--output-model-files", &format!("{}", model.display()), + "--num-timesteps", &format!("{num_timesteps}"), + "--num-fine-channels", &format!("{num_chans}"), + "--veto-threshold", "0.0", // Don't complicate things with vetoing + // The array position is needed because, if not specified, it's read + // slightly different out of the uvfits. + "--array-position", "116.67081523611111", "-26.703319405555554", "377.827", + ]); + sim_args.run(false).unwrap(); + + let sols_file = temp_dir.path().join("sols.fits"); + #[rustfmt::skip] + let cal_args = DiCalArgs::parse_from([ + "di-calibrate", + "--data", &format!("{}", model.display()), &metafits, + "--source-list", &srclist, + "--outputs", &format!("{}", sols_file.display()), + "--veto-threshold", "0.0", // Don't complicate things with vetoing + "--array-position", "116.67081523611111", "-26.703319405555554", "377.827", + ]); + let sols = cal_args.run(false).unwrap().unwrap(); + let num_cal_timeblocks = sols.di_jones.len_of(Axis(0)); + // We didn't specify anything with calibration timeblocks, so this should be + // 1 (all input data timesteps are used at once in calibration). 
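// The timeblock counts asserted in this test follow from a rounded-up
// division of the used timesteps by the averaging factor: with nothing
// specified all 3 timesteps form 1 block, and with
// "--timesteps-per-timeblock 2" they form ceil(3 / 2) = 2 blocks.
// Illustrative only:
fn sketch_num_timeblocks(num_timesteps: usize, timesteps_per_timeblock: usize) -> usize {
    (num_timesteps + timesteps_per_timeblock - 1) / timesteps_per_timeblock
}
// sketch_num_timeblocks(3, 3) == 1; sketch_num_timeblocks(3, 2) == 2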
+ assert_eq!(num_cal_timeblocks, 1); + #[cfg(not(feature = "cuda-single"))] + let eps = 0.0; // I am amazed + #[cfg(feature = "cuda-single")] + let eps = 2e-8; assert_abs_diff_eq!( - result.array_position, - LatLngHeight { - longitude_rad: expected[0].to_radians(), - latitude_rad: expected[1].to_radians(), - height_metres: expected[2] + sols.di_jones, + Array3::from_elem(sols.di_jones.dim(), Jones::identity()), + epsilon = eps + ); + + #[rustfmt::skip] + let cal_args = DiCalArgs::parse_from([ + "di-calibrate", + "--data", &format!("{}", model.display()), &metafits, + "--source-list", &srclist, + "--outputs", &format!("{}", sols_file.display()), + "--timesteps-per-timeblock", "2", + "--veto-threshold", "0.0", // Don't complicate things with vetoing + "--array-position", "116.67081523611111", "-26.703319405555554", "377.827", + ]); + let sols = cal_args.run(false).unwrap().unwrap(); + let num_cal_timeblocks = sols.di_jones.len_of(Axis(0)); + // 3 / 2 = 1.5 = 2 rounded up + assert_eq!(num_cal_timeblocks, 2); + #[cfg(not(feature = "cuda-single"))] + let eps = 0.0; + #[cfg(feature = "cuda-single")] + let eps = 4e-8; + assert_abs_diff_eq!( + sols.di_jones, + Array3::from_elem(sols.di_jones.dim(), Jones::identity()), + epsilon = eps + ); +} + +#[test] +fn test_flagging_all_uvw_lengths_causes_error() { + let mut args = get_reduced_1090008640(false, false); + args.calibration_args.uvw_min = Some("3000L".to_string()); + let error = args.parse().err().unwrap(); + assert!(matches!(error, HyperdriveError::DiCalibrate(_))); + match &error { + HyperdriveError::DiCalibrate(s) => { + let s = s.as_str(); + assert!(s == "All baselines were flagged due to UVW cutoffs. Try adjusting the UVW min and/or max."); } + _ => unreachable!(), + } + + let mut args = get_reduced_1090008640(false, false); + args.calibration_args.uvw_min = Some("0L".to_string()); + args.calibration_args.uvw_max = Some("1L".to_string()); + let error = args.parse().err().unwrap(); + assert!(matches!(error, HyperdriveError::DiCalibrate(_))); + match &error { + HyperdriveError::DiCalibrate(s) => { + let s = s.as_str(); + assert!(s == "All baselines were flagged due to UVW cutoffs. Try adjusting the UVW min and/or max."); + } + _ => unreachable!(), + } +} + +/// Given calibration parameters and visibilities, this function tests that +/// everything matches an expected quality. The values may change over time but +/// they should be consistent with whatever tests use this test code. +fn test_1090008640_quality( + params: crate::params::DiCalParams, + vis_data: ArrayView3>, + vis_model: ArrayView3>, +) { + let (_, cal_results) = crate::di_calibrate::calibrate_timeblocks( + vis_data, + vis_model, + ¶ms.input_vis_params.timeblocks, + ¶ms.input_vis_params.spw.chanblocks, + 50, + 1e-8, + 1e-4, + crate::context::Polarisations::default(), + false, ); + + // Only one timeblock. + assert_eq!(cal_results.dim().0, 1); + + let mut count_50 = 0; + let mut count_42 = 0; + let mut chanblocks_42 = vec![]; + let mut fewest_iterations = u32::MAX; + for cal_result in cal_results { + match cal_result.num_iterations { + 50 => { + count_50 += 1; + fewest_iterations = fewest_iterations.min(cal_result.num_iterations); + } + 42 => { + count_42 += 1; + chanblocks_42.push(cal_result.chanblock.unwrap()); + fewest_iterations = fewest_iterations.min(cal_result.num_iterations); + } + 0 => panic!("0 iterations? 
Something is wrong."), + _ => { + if cal_result.num_iterations % 2 == 1 { + panic!("An odd number of iterations shouldn't be possible; at the time of writing, only even numbers are allowed."); + } + fewest_iterations = fewest_iterations.min(cal_result.num_iterations); + } + } + + assert!( + cal_result.converged, + "Chanblock {} did not converge", + cal_result.chanblock.unwrap() + ); + assert_eq!(cal_result.num_failed, 0); + assert!(cal_result.max_precision < 1e8); + } + + let expected_count_50 = 14; + let expected_count_42 = 1; + let expected_chanblocks_42 = vec![13]; + let expected_fewest_iterations = 40; + if count_50 != expected_count_50 + || count_42 != expected_count_42 + || chanblocks_42 != expected_chanblocks_42 + || fewest_iterations != expected_fewest_iterations + { + panic!( + r#" +Calibration quality has changed. This test expects: + {expected_count_50} chanblocks with 50 iterations (got {count_50}), + {expected_count_42} chanblocks with 42 iterations (got {count_42}), + chanblocks {expected_chanblocks_42:?} to need 42 iterations (got {chanblocks_42:?}), and + no chanblocks to finish in less than {expected_fewest_iterations} iterations (got {fewest_iterations}). +"# + ); + } } #[test] -fn test_handle_bad_array_pos() { - let mut args = get_reduced_1090008640(false, true); - let expected = vec![MWA_LONG_DEG + 1.0, MWA_LAT_DEG + 1.0]; - args.array_position = Some(expected); - let result = args.into_params(); - assert!(result.is_err()); - assert!(matches!(result.err().unwrap(), BadArrayPosition { .. })) +fn test_1090008640_calibration_quality_raw() { + let temp_dir = tempdir().expect("Couldn't make temp dir"); + + let DataAsStrings { + metafits, + mut vis, + mwafs: _, + srclist, + } = get_reduced_1090008640_raw(); + let args = DiCalArgs { + data_args: InputVisArgs { + files: Some(vec![metafits, vis.swap_remove(0)]), + // To be consistent with other data quality tests, add these flags. + fine_chan_flags: Some(vec![0, 1, 2, 16, 30, 31]), + pfb_flavour: Some("none".to_string()), + ..Default::default() + }, + srclist_args: SkyModelWithVetoArgs { + source_list: Some(srclist), + ..Default::default() + }, + beam_args: BeamArgs { + no_beam: true, + ..Default::default() + }, + calibration_args: DiCalCliArgs { + solutions: Some(vec![temp_dir.path().join("hyp_sols.fits")]), + ..Default::default() + }, + ..Default::default() + }; + + let params = args.parse().unwrap(); + let CalVis { + vis_data, + vis_model, + .. + } = params + .get_cal_vis() + .expect("Couldn't read data and generate a model"); + test_1090008640_quality(params, vis_data.view(), vis_model.view()); +} + +#[test] +#[serial] +fn test_1090008640_calibration_quality_ms() { + let temp_dir = tempdir().expect("Couldn't make temp dir"); + + let DataAsStrings { + metafits, + mut vis, + mwafs: _, + srclist, + } = get_reduced_1090008640_ms(); + let args = DiCalArgs { + data_args: InputVisArgs { + files: Some(vec![metafits, vis.swap_remove(0)]), + // To be consistent with other data quality tests, add these flags. + fine_chan_flags: Some(vec![0, 1, 2, 16, 30, 31]), + ..Default::default() + }, + srclist_args: SkyModelWithVetoArgs { + source_list: Some(srclist), + ..Default::default() + }, + beam_args: BeamArgs { + no_beam: true, + ..Default::default() + }, + calibration_args: DiCalCliArgs { + solutions: Some(vec![temp_dir.path().join("hyp_sols.fits")]), + ..Default::default() + }, + ..Default::default() + }; + + let params = args.parse().unwrap(); + let CalVis { + vis_data, + vis_model, + .. 
+ } = params + .get_cal_vis() + .expect("Couldn't read data and generate a model"); + test_1090008640_quality(params, vis_data.view(), vis_model.view()); +} + +#[test] +fn test_1090008640_calibration_quality_uvfits() { + let temp_dir = tempdir().expect("Couldn't make temp dir"); + + let DataAsStrings { + metafits, + mut vis, + mwafs: _, + srclist, + } = get_reduced_1090008640_uvfits(); + let args = DiCalArgs { + data_args: InputVisArgs { + files: Some(vec![metafits, vis.swap_remove(0)]), + // To be consistent with other data quality tests, add these flags. + fine_chan_flags: Some(vec![0, 1, 2, 16, 30, 31]), + ..Default::default() + }, + srclist_args: SkyModelWithVetoArgs { + source_list: Some(srclist), + ..Default::default() + }, + beam_args: BeamArgs { + no_beam: true, + ..Default::default() + }, + calibration_args: DiCalCliArgs { + solutions: Some(vec![temp_dir.path().join("hyp_sols.fits")]), + ..Default::default() + }, + ..Default::default() + }; + + let params = args.parse().unwrap(); + let CalVis { + vis_data, + vis_model, + .. + } = params + .get_cal_vis() + .expect("Couldn't read data and generate a model"); + test_1090008640_quality(params, vis_data.view(), vis_model.view()); } diff --git a/src/cli/dipole_gains.rs b/src/cli/dipole_gains.rs index 3e605da2..35ee9a13 100644 --- a/src/cli/dipole_gains.rs +++ b/src/cli/dipole_gains.rs @@ -11,6 +11,8 @@ use clap::Parser; use log::info; use mwalib::{MetafitsContext, MwalibError}; +use super::common::display_warnings; + /// Print information on the dipole gains listed by a metafits file. #[derive(Parser, Debug)] pub struct DipoleGainsArgs { @@ -35,6 +37,8 @@ impl DipoleGainsArgs { } } + display_warnings(); + if all_unity.len() == meta.num_ants { info!("All dipoles on all tiles have a gain of 1.0!"); } else { diff --git a/src/cli/error.rs b/src/cli/error.rs new file mode 100644 index 00000000..ee78759a --- /dev/null +++ b/src/cli/error.rs @@ -0,0 +1,470 @@ +// This Source Code Form is subject to the terms of the Mozilla Public +// License, v. 2.0. If a copy of the MPL was not distributed with this +// file, You can obtain one at http://mozilla.org/MPL/2.0/. + +//! Error type for all hyperdrive-related errors. This should be the *only* +//! error enum that is publicly visible. + +use thiserror::Error; + +use super::{ + common::InputVisArgsError, + di_calibrate::DiCalArgsError, + solutions::{SolutionsApplyArgsError, SolutionsPlotError}, + srclist::SrclistByBeamError, + vis_convert::VisConvertArgsError, + vis_simulate::VisSimulateArgsError, + vis_subtract::VisSubtractArgsError, +}; +use crate::{ + beam::BeamError, + io::{ + read::VisReadError, + write::{FileWriteError, VisWriteError}, + GlobError, + }, + model::ModelError, + params::{DiCalibrateError, VisConvertError, VisSimulateError, VisSubtractError}, + solutions::{SolutionsReadError, SolutionsWriteError}, + srclist::{ReadSourceListError, SrclistError, WriteSourceListError}, +}; + +const URL: &str = "https://MWATelescope.github.io/mwa_hyperdrive"; + +/// The *only* publicly visible error from hyperdrive. Each error message should +/// include the URL, unless it's "generic". +#[derive(Error, Debug)] +pub enum HyperdriveError { + /// An error related to di-calibrate. + #[error("{0}\n\nSee for more info: {URL}/user/di_cal/intro.html")] + DiCalibrate(String), + + /// An error related to solutions-apply. + #[error("{0}\n\nSee for more info: {URL}/user/solutions_apply/intro.html")] + SolutionsApply(String), + + /// An error related to solutions-plot. 
+ #[error("{0}\n\nSee for more info: {URL}/user/plotting.html")] + SolutionsPlot(String), + + /// An error related to vis-simulate. + #[error("{0}\n\nSee for more info: {URL}/user/vis_simulate/intro.html")] + VisSimulate(String), + + /// An error related to vis-subtract. + #[error("{0}\n\nSee for more info: {URL}/user/vis_subtract/intro.html")] + VisSubtract(String), + + /// Generic error surrounding source lists. + #[error("{0}\n\nSee for more info: {URL}/defs/source_lists.html")] + Srclist(String), + + /// Generic error surrounding calibration solutions. + #[error("{0}\n\nSee for more info: {URL}/defs/cal_sols.html")] + Solutions(String), + + /// Error specific to hyperdrive calibration solutions. + #[error("{0}\n\nSee for more info: {URL}/defs/cal_sols_hyp.html")] + SolutionsHyp(String), + + /// Error specific to AO calibration solutions. + #[error("{0}\n\nSee for more info: {URL}/defs/cal_sols_ao.html")] + SolutionsAO(String), + + /// Error specific to RTS calibration solutions. + #[error("{0}\n\nSee for more info: {URL}/defs/cal_sols_rts.html")] + SolutionsRts(String), + + /// An error related to reading visibilities. + #[error("{0}\n\nSee for more info: {URL}/defs/vis_formats_read.html")] + VisRead(String), + + /// An error related to reading visibilities. + #[error("{0}\n\nSee for more info: {URL}/defs/vis_formats_write.html")] + VisWrite(String), + + /// An error related to averaging. + #[error("{0}\n\nSee for more info: {URL}/defs/vis_formats_write.html#visibility-averaging")] + Averaging(String), + + /// An error related to raw MWA data corrections. + #[error("{0}\n\nSee for more info: {URL}/defs/mwa/corrections.html")] + RawDataCorrections(String), + + /// An error related to metafits files. + #[error("{0}\n\nSee for more info: {URL}/defs/mwa/metafits.html")] + Metafits(String), + + /// An error related to dipole delays. + #[error("{0}\n\nYou may be able to fix this by supplying a metafits file or manually specifying the MWA dipole delays.\n\nSee for more info: {URL}/defs/mwa/delays.html")] + Delays(String), + + /// An error related to mwaf files. + #[error("{0}\n\nSee for more info: {URL}/defs/mwa/mwaf.html")] + Mwaf(String), + + /// An error related to mwalib. + #[error("{0}\n\nSee for more info: {URL}/defs/mwa/mwalib.html")] + Mwalib(String), + + /// An error related to beam code. + #[error("{0}\n\nSee for more info: {URL}/defs/beam.html")] + Beam(String), + + /// An error related to argument files. + #[error("{0}\n\nSee for more info: {URL}/defs/arg_file.html")] + ArgFile(String), + + /// A cfitsio error. Because these are usually quite spartan, some + /// suggestions are provided here. + #[error("cfitsio error: {0}\n\nIf you don't know what this means, try turning up verbosity (-v or -vv) and maybe disabling progress bars.")] + Cfitsio(String), + + /// A generic error that can't be clarified further with documentation, e.g. + /// IO errors. + #[error("{0}")] + Generic(String), +} + +// When changing the error propagation below, ensure `Self::from(e)` uses the +// correct `e`! + +// Binary sub-command errors. + +impl From for HyperdriveError { + fn from(e: DiCalArgsError) -> Self { + match e { + DiCalArgsError::NoOutput + | DiCalArgsError::AllBaselinesFlaggedFromUvwCutoffs + | DiCalArgsError::ParseUvwMin(_) + | DiCalArgsError::ParseUvwMax(_) => Self::DiCalibrate(e.to_string()), + DiCalArgsError::CalibrationOutputFile { .. 
+            DiCalArgsError::ParseCalTimeAverageFactor(_)
+            | DiCalArgsError::CalTimeFactorNotInteger
+            | DiCalArgsError::CalTimeResNotMultiple { .. }
+            | DiCalArgsError::CalTimeFactorZero => Self::Averaging(e.to_string()),
+            DiCalArgsError::IO(e) => Self::from(e),
+        }
+    }
+}
+
+impl From<DiCalibrateError> for HyperdriveError {
+    fn from(e: DiCalibrateError) -> Self {
+        let s = e.to_string();
+        match e {
+            DiCalibrateError::InsufficientMemory { .. } => Self::DiCalibrate(s),
+            DiCalibrateError::SolutionsRead(_) | DiCalibrateError::SolutionsWrite(_) => {
+                Self::Solutions(s)
+            }
+            DiCalibrateError::Fitsio(_) => Self::Cfitsio(s),
+            DiCalibrateError::VisRead(e) => Self::from(e),
+            DiCalibrateError::VisWrite(_) => Self::VisWrite(s),
+            DiCalibrateError::Model(_) | DiCalibrateError::IO(_) => Self::Generic(s),
+        }
+    }
+}
+
+impl From<SolutionsApplyArgsError> for HyperdriveError {
+    fn from(e: SolutionsApplyArgsError) -> Self {
+        let s = e.to_string();
+        match e {
+            SolutionsApplyArgsError::NoSolutions => Self::SolutionsApply(s),
+        }
+    }
+}
+
+impl From<SolutionsPlotError> for HyperdriveError {
+    fn from(e: SolutionsPlotError) -> Self {
+        let s = e.to_string();
+        match e {
+            #[cfg(not(feature = "plotting"))]
+            SolutionsPlotError::NoPlottingFeature => Self::SolutionsPlot(s),
+            SolutionsPlotError::SolutionsRead(_) => Self::Solutions(s),
+            SolutionsPlotError::Mwalib(_) => Self::Mwalib(s),
+            SolutionsPlotError::IO(_) => Self::Generic(s),
+            #[cfg(feature = "plotting")]
+            SolutionsPlotError::MetafitsNoAntennaNames => Self::Metafits(s),
+            #[cfg(feature = "plotting")]
+            SolutionsPlotError::Draw(_)
+            | SolutionsPlotError::NoInputs
+            | SolutionsPlotError::InvalidSolsFormat(_) => Self::Generic(s),
+        }
+    }
+}
+
+impl From<VisConvertArgsError> for HyperdriveError {
+    fn from(e: VisConvertArgsError) -> Self {
+        let s = e.to_string();
+        match e {
+            VisConvertArgsError::NoOutputs => Self::VisWrite(s),
+        }
+    }
+}
+
+impl From<VisConvertError> for HyperdriveError {
+    fn from(e: VisConvertError) -> Self {
+        match e {
+            VisConvertError::VisRead(e) => Self::from(e),
+            VisConvertError::VisWrite(e) => Self::from(e),
+            VisConvertError::IO(e) => Self::from(e),
+        }
+    }
+}
+
+impl From<VisSimulateArgsError> for HyperdriveError {
+    fn from(e: VisSimulateArgsError) -> Self {
+        let s = e.to_string();
+        match e {
+            VisSimulateArgsError::NoMetafits
+            | VisSimulateArgsError::MetafitsDoesntExist(_)
+            | VisSimulateArgsError::RaInvalid
+            | VisSimulateArgsError::DecInvalid
+            | VisSimulateArgsError::OnlyOneRAOrDec
+            | VisSimulateArgsError::FineChansZero
+            | VisSimulateArgsError::FineChansWidthTooSmall
+            | VisSimulateArgsError::ZeroTimeSteps
+            | VisSimulateArgsError::BadArrayPosition { .. } => Self::VisSimulate(s),
+        }
+    }
+}
+
+impl From<VisSimulateError> for HyperdriveError {
+    fn from(e: VisSimulateError) -> Self {
+        match e {
+            VisSimulateError::VisWrite(e) => Self::from(e),
+            VisSimulateError::Model(e) => Self::from(e),
+            VisSimulateError::IO(e) => Self::from(e),
+        }
+    }
+}
+
+impl From<VisSubtractArgsError> for HyperdriveError {
+    fn from(e: VisSubtractArgsError) -> Self {
+        let s = e.to_string();
+        match e {
+            VisSubtractArgsError::MissingSource { .. }
+            | VisSubtractArgsError::NoSources
+            | VisSubtractArgsError::AllSourcesFiltered => Self::VisSubtract(s),
+        }
+    }
+}
+
+impl From<VisSubtractError> for HyperdriveError {
+    fn from(e: VisSubtractError) -> Self {
+        match e {
+            VisSubtractError::VisRead(e) => Self::from(e),
+            VisSubtractError::VisWrite(e) => Self::from(e),
+            VisSubtractError::Model(e) => Self::from(e),
+            VisSubtractError::IO(e) => Self::from(e),
+            #[cfg(feature = "cuda")]
+            VisSubtractError::Cuda(e) => Self::from(e),
+        }
+    }
+}
+
+impl From<SrclistByBeamError> for HyperdriveError {
+    fn from(e: SrclistByBeamError) -> Self {
+        match e {
+            SrclistByBeamError::NoPhaseCentre => todo!(),
+            SrclistByBeamError::NoLst => todo!(),
+            SrclistByBeamError::NoFreqs => todo!(),
+            SrclistByBeamError::ReadSourceList(e) => Self::from(e),
+            SrclistByBeamError::WriteSourceList(e) => Self::from(e),
+            SrclistByBeamError::Beam(e) => Self::from(e),
+            SrclistByBeamError::Mwalib(e) => Self::from(e),
+            SrclistByBeamError::IO(e) => Self::from(e),
+        }
+    }
+}
+
+// Library code errors.
+
+impl From<InputVisArgsError> for HyperdriveError {
+    fn from(e: InputVisArgsError) -> Self {
+        let s = e.to_string();
+        match e {
+            InputVisArgsError::Raw(
+                crate::io::read::RawReadError::MwafFlagsMissingForTimestep { .. }
+                | crate::io::read::RawReadError::MwafMerge(_),
+            ) => Self::Mwaf(s),
+            InputVisArgsError::PfbParse(_) => Self::RawDataCorrections(s),
+            InputVisArgsError::DoesNotExist(_)
+            | InputVisArgsError::CouldNotRead(_)
+            | InputVisArgsError::PpdMetafitsUnsupported(_)
+            | InputVisArgsError::NotRecognised(_)
+            | InputVisArgsError::NoInputData
+            | InputVisArgsError::MultipleMetafits(_)
+            | InputVisArgsError::MultipleMeasurementSets(_)
+            | InputVisArgsError::MultipleUvfits(_)
+            | InputVisArgsError::MultipleSolutions(_)
+            | InputVisArgsError::InvalidDataInput(_)
+            | InputVisArgsError::BadArrayPosition { .. }
+            | InputVisArgsError::NoTimesteps
+            | InputVisArgsError::DuplicateTimesteps
+            | InputVisArgsError::UnavailableTimestep { .. }
+            | InputVisArgsError::NoTiles
+            | InputVisArgsError::BadTileIndexForFlagging { .. }
+            | InputVisArgsError::BadTileNameForFlagging(_)
+            | InputVisArgsError::NoChannels
+            | InputVisArgsError::FineChanFlagTooBig { .. }
+            | InputVisArgsError::FineChanFlagPerCoarseChanTooBig { .. }
+            | InputVisArgsError::Raw(_)
+            | InputVisArgsError::Ms(_)
+            | InputVisArgsError::Uvfits(_) => Self::VisRead(s),
+            InputVisArgsError::TileCountMismatch { .. } | InputVisArgsError::Solutions(_) => {
+                Self::Solutions(s)
+            }
+            InputVisArgsError::ParseTimeAverageFactor(_)
+            | InputVisArgsError::TimeFactorNotInteger
+            | InputVisArgsError::TimeResNotMultiple { .. }
+            | InputVisArgsError::ParseFreqAverageFactor(_)
+            | InputVisArgsError::FreqFactorNotInteger
+            | InputVisArgsError::FreqResNotMultiple { .. } => Self::Averaging(s),
+            InputVisArgsError::Glob(_) | InputVisArgsError::IO(_, _) => Self::Generic(s),
+        }
+    }
+}
+
+impl From<VisReadError> for HyperdriveError {
+    fn from(e: VisReadError) -> Self {
+        let s = e.to_string();
+        match e {
+            VisReadError::Raw(_) | VisReadError::MS(_) | VisReadError::Uvfits(_) => {
+                Self::VisRead(s)
+            }
+            VisReadError::BadArraySize { .. } => Self::Generic(s),
+        }
+    }
+}
+
+impl From<VisWriteError> for HyperdriveError {
+    fn from(e: VisWriteError) -> Self {
+        Self::VisWrite(e.to_string())
+    }
+}
+
+impl From<FileWriteError> for HyperdriveError {
+    fn from(e: FileWriteError) -> Self {
+        Self::VisWrite(e.to_string())
+    }
+}
+
+impl From<ReadSourceListError> for HyperdriveError {
+    fn from(e: ReadSourceListError) -> Self {
+        let s = e.to_string();
+        match e {
+            ReadSourceListError::IO(_) => Self::Generic(s),
+            _ => Self::Srclist(s),
+        }
+    }
+}
+
+impl From<WriteSourceListError> for HyperdriveError {
+    fn from(e: WriteSourceListError) -> Self {
+        let s = e.to_string();
+        match e {
+            WriteSourceListError::UnsupportedComponentType { .. }
+            | WriteSourceListError::UnsupportedFluxDensityType { .. }
+            | WriteSourceListError::InvalidHyperdriveFormat(_)
+            | WriteSourceListError::Sexagesimal(_) => Self::Srclist(s),
+            WriteSourceListError::IO(e) => Self::from(e),
+            WriteSourceListError::Yaml(_) | WriteSourceListError::Json(_) => Self::Generic(s),
+        }
+    }
+}
+
+impl From<SrclistError> for HyperdriveError {
+    fn from(e: SrclistError) -> Self {
+        let s = e.to_string();
+        match e {
+            SrclistError::ReadSourceList(e) => Self::from(e),
+            SrclistError::Beam(e) => Self::from(e),
+            SrclistError::WriteSourceList(_) => Self::Srclist(s),
+            SrclistError::MissingMetafits => Self::Metafits(s),
+            SrclistError::Mwalib(_) => Self::Mwalib(s),
+            SrclistError::IO(e) => Self::from(e),
+        }
+    }
+}
+
+impl From<SolutionsReadError> for HyperdriveError {
+    fn from(e: SolutionsReadError) -> Self {
+        let s = e.to_string();
+        match e {
+            SolutionsReadError::UnsupportedExt { .. } => Self::Solutions(s),
+            SolutionsReadError::BadShape { .. } | SolutionsReadError::ParsePfbFlavour(_) => {
+                Self::SolutionsHyp(s)
+            }
+            SolutionsReadError::AndreBinaryStr { .. }
+            | SolutionsReadError::AndreBinaryVal { .. } => Self::SolutionsAO(s),
+            SolutionsReadError::RtsMetafitsRequired | SolutionsReadError::Rts(_) => {
+                Self::SolutionsRts(s)
+            }
+            SolutionsReadError::Fits(_) | SolutionsReadError::Fitsio(_) => Self::Cfitsio(s),
+            SolutionsReadError::IO(e) => Self::from(e),
+        }
+    }
+}
+
+impl From<SolutionsWriteError> for HyperdriveError {
+    fn from(e: SolutionsWriteError) -> Self {
+        let s = e.to_string();
+        match e {
+            SolutionsWriteError::UnsupportedExt { .. } => Self::Solutions(s),
+            SolutionsWriteError::Fits(_) | SolutionsWriteError::Fitsio(_) => Self::Cfitsio(s),
+            SolutionsWriteError::IO(e) => Self::from(e),
+        }
+    }
+}
+
+impl From<BeamError> for HyperdriveError {
+    fn from(e: BeamError) -> Self {
+        let s = e.to_string();
+        match e {
+            BeamError::NoDelays
+            | BeamError::BadDelays
+            | BeamError::DelayGainsDimensionMismatch { .. } => Self::Delays(s),
+            BeamError::BadTileIndex { .. }
+            | BeamError::Hyperbeam(_)
+            | BeamError::HyperbeamInit(_) => Self::Beam(s),
+            #[cfg(feature = "cuda")]
+            BeamError::Cuda(_) => Self::Beam(s),
+        }
+    }
+}
+
+impl From<ModelError> for HyperdriveError {
+    fn from(e: ModelError) -> Self {
+        match e {
+            ModelError::Beam(e) => Self::from(e),
+
+            #[cfg(feature = "cuda")]
+            ModelError::Cuda(e) => Self::from(e),
+        }
+    }
+}
+
+impl From<GlobError> for HyperdriveError {
+    fn from(e: GlobError) -> Self {
+        Self::Generic(e.to_string())
+    }
+}
+
+impl From<std::io::Error> for HyperdriveError {
+    fn from(e: std::io::Error) -> Self {
+        Self::Generic(e.to_string())
+    }
+}
+
+impl From<mwalib::MwalibError> for HyperdriveError {
+    fn from(e: mwalib::MwalibError) -> Self {
+        Self::Mwalib(e.to_string())
+    }
+}
+
+#[cfg(feature = "cuda")]
+impl From<crate::cuda::CudaError> for HyperdriveError {
+    fn from(e: crate::cuda::CudaError) -> Self {
+        Self::Generic(e.to_string())
+    }
+}
diff --git a/src/cli/mod.rs b/src/cli/mod.rs
index a33e0673..46a0c2dd 100644
--- a/src/cli/mod.rs
+++ b/src/cli/mod.rs
@@ -2,8 +2,281 @@
 // License, v. 2.0. If a copy of the MPL was not distributed with this
 // file, You can obtain one at http://mozilla.org/MPL/2.0/.
-pub(crate) mod di_calibrate;
-pub(crate) mod dipole_gains;
-pub(crate) mod solutions;
-pub(crate) mod srclist;
-pub(crate) mod vis_utils;
+//! Command-line interface code. More specific options for `hyperdrive`
+//! subcommands are contained in modules.
+//!
+//! All booleans must have `#[serde(default)]` annotated, and anything that
+//! isn't a boolean must be optional. This allows all arguments to be optional
+//! *and* usable in an arguments file.
+//!
+//! Only 3 things should be public in this module: `Hyperdrive`,
+//! `Hyperdrive::run`, and `HyperdriveError`.
+
+#[macro_use]
+mod common;
+mod di_calibrate;
+mod dipole_gains;
+mod error;
+mod solutions;
+mod srclist;
+mod vis_convert;
+mod vis_simulate;
+mod vis_subtract;
+
+pub(crate) use common::Warn;
+pub use error::HyperdriveError;
+
+use std::path::PathBuf;
+
+use clap::{AppSettings, Args, Parser, Subcommand};
+use log::info;
+
+use crate::PROGRESS_BARS;
+
+// Add build-time information from the "built" crate.
+include!(concat!(env!("OUT_DIR"), "/built.rs"));
+
+#[derive(Debug, Parser)]
+#[clap(
+    version,
+    author,
+    about = r#"Calibration software for the Murchison Widefield Array (MWA) radio telescope
+Documentation: https://mwatelescope.github.io/mwa_hyperdrive
+Source: https://github.com/MWATelescope/mwa_hyperdrive"#
+)]
+#[clap(global_setting(AppSettings::DeriveDisplayOrder))]
+#[clap(disable_help_subcommand = true)]
+#[clap(infer_subcommands = true)]
+#[clap(propagate_version = true)]
+#[clap(infer_long_args = true)]
+pub struct Hyperdrive {
+    #[clap(flatten)]
+    global_opts: GlobalArgs,
+
+    #[clap(subcommand)]
+    command: Command,
+}
+
+#[derive(Debug, Args)]
+struct GlobalArgs {
+    /// Don't draw progress bars.
+    #[clap(long)]
+    #[clap(global = true)]
+    no_progress_bars: bool,
+
+    /// The verbosity of the program. Increase by specifying multiple times
+    /// (e.g. -vv). The default is to print only high-level information.
+    #[clap(short, long, parse(from_occurrences))]
+    #[clap(global = true)]
+    verbosity: u8,
+
+    /// Only verify that arguments were correctly ingested and print out
+    /// high-level information.
+    #[clap(long)]
+    #[clap(global = true)]
+    dry_run: bool,
+
+    /// Save the input arguments into a new TOML file that can be used to
+    /// reproduce this run.
+ #[clap(long)] + #[clap(global = true)] + save_toml: Option, +} + +#[derive(Debug, Subcommand)] +#[clap(arg_required_else_help = true)] +enum Command { + #[clap(alias = "calibrate")] + #[clap( + about = r#"Perform direction-independent calibration on the input MWA data. +https://mwatelescope.github.io/mwa_hyperdrive/user/di_cal/intro.html"# + )] + DiCalibrate(di_calibrate::DiCalArgs), + + #[clap(alias = "convert-vis")] + #[clap(about = r#"Convert visibilities from one type to another. +https://mwatelescope.github.io/mwa_hyperdrive/user/vis_convert/intro.html"#)] + VisConvert(vis_convert::VisConvertArgs), + + #[clap(alias = "simulate-vis")] + #[clap(about = r#"Simulate visibilities of a sky-model source list. +https://mwatelescope.github.io/mwa_hyperdrive/user/vis_simulate/intro.html"#)] + VisSimulate(vis_simulate::VisSimulateArgs), + + #[clap(alias = "subtract-vis")] + #[clap(about = "Subtract sky-model sources from supplied visibilities. +https://mwatelescope.github.io/mwa_hyperdrive/user/vis_subtract/intro.html")] + VisSubtract(vis_subtract::VisSubtractArgs), + + #[clap(alias = "apply-solutions")] + #[clap(about = r#"Apply calibration solutions to input data. +https://mwatelescope.github.io/mwa_hyperdrive/user/solutions_apply/intro.html"#)] + SolutionsApply(solutions::SolutionsApplyArgs), + + #[clap(alias = "plot-solutions")] + #[clap( + about = r#"Plot calibration solutions. Only available if compiled with the "plotting" feature. +https://mwatelescope.github.io/mwa_hyperdrive/user/plotting.html"# + )] + SolutionsPlot(solutions::SolutionsPlotArgs), + + #[clap(alias = "convert-solutions")] + #[clap(about = "Convert between calibration solution file formats.")] + SolutionsConvert(solutions::SolutionsConvertArgs), + + SrclistByBeam(srclist::SrclistByBeamArgs), + + SrclistConvert(srclist::SrclistConvertArgs), + + SrclistVerify(srclist::SrclistVerifyArgs), + + SrclistShift(srclist::SrclistShiftArgs), + + DipoleGains(dipole_gains::DipoleGainsArgs), +} + +impl Hyperdrive { + pub fn run(self) -> Result<(), HyperdriveError> { + // Set up logging. + let GlobalArgs { + verbosity, + dry_run, + no_progress_bars, + save_toml, + } = self.global_opts; + setup_logging(verbosity).expect("Failed to initialise logging."); + // Enable progress bars if the user didn't say "no progress bars". + if !no_progress_bars { + PROGRESS_BARS.store(true); + } + + // Print the version of hyperdrive and its build-time information. + let sub_command = match &self.command { + Command::DiCalibrate(_) => "di-calibrate", + Command::VisConvert(_) => "vis-convert", + Command::VisSimulate(_) => "vis-simulate", + Command::VisSubtract(_) => "vis-subtract", + Command::SolutionsApply(_) => "solutions-apply", + Command::SolutionsConvert(_) => "solutions-convert", + Command::SolutionsPlot(_) => "solutions-plot", + Command::SrclistByBeam(_) => "srclist-by-beam", + Command::SrclistConvert(_) => "srclist-convert", + Command::SrclistShift(_) => "srclist-shift", + Command::SrclistVerify(_) => "srclist-verify", + Command::DipoleGains(_) => "dipole-gains", + }; + info!("hyperdrive {} {}", sub_command, env!("CARGO_PKG_VERSION")); + display_build_info(); + + macro_rules! 
merge_save_run { + ($args:expr) => {{ + let args = $args.merge()?; + if let Some(toml) = save_toml { + use std::{ + fs::File, + io::{BufWriter, Write}, + }; + + let mut f = BufWriter::new(File::create(toml)?); + let toml_str = toml::to_string(&args).expect("toml serialisation error"); + f.write_all(toml_str.as_bytes())?; + } + args.run(dry_run)?; + }}; + } + + match self.command { + Command::DiCalibrate(args) => { + merge_save_run!(args) + } + + Command::VisConvert(args) => { + merge_save_run!(args) + } + + Command::VisSimulate(args) => { + merge_save_run!(args) + } + + Command::VisSubtract(args) => { + merge_save_run!(args) + } + + Command::SolutionsApply(args) => { + merge_save_run!(args) + } + + Command::SolutionsConvert(args) => { + args.run()?; + } + + Command::SolutionsPlot(args) => { + args.run()?; + } + + // Source list utilities. + Command::SrclistByBeam(args) => args.run()?, + Command::SrclistConvert(args) => args.run()?, + Command::SrclistShift(args) => args.run()?, + Command::SrclistVerify(args) => args.run()?, + + // Misc. utilities. + Command::DipoleGains(args) => args.run()?, + } + + info!("hyperdrive {} complete.", sub_command); + Ok(()) + } +} + +/// Activate a logger. All log messages are put onto `stdout`. `env_logger` +/// automatically only uses colours and fancy symbols if we're on a tty (e.g. a +/// terminal); piped output will be formatted sensibly. Source code lines are +/// displayed in log messages when verbosity >= 3. +fn setup_logging(verbosity: u8) -> Result<(), log::SetLoggerError> { + let mut builder = env_logger::Builder::from_default_env(); + builder.target(env_logger::Target::Stdout); + builder.format_target(false); + match verbosity { + 0 => builder.filter_level(log::LevelFilter::Info), + 1 => builder.filter_level(log::LevelFilter::Debug), + 2 => builder.filter_level(log::LevelFilter::Trace), + _ => { + builder.filter_level(log::LevelFilter::Trace); + builder.format(|buf, record| { + use std::io::Write; + + let timestamp = buf.timestamp(); + let level = record.level(); + let target = record.target(); + let line = record.line().unwrap_or(0); + let message = record.args(); + + writeln!(buf, "[{timestamp} {level} {target}:{line}] {message}") + }) + } + }; + builder.init(); + + Ok(()) +} + +/// Write many info-level log lines of how this executable was compiled. +fn display_build_info() { + let dirty = match GIT_DIRTY { + Some(true) => " (dirty)", + _ => "", + }; + match GIT_COMMIT_HASH_SHORT { + Some(hash) => { + info!("Compiled on git commit hash: {hash}{dirty}"); + } + None => info!("Compiled on git commit hash: "), + } + if let Some(hr) = GIT_HEAD_REF { + info!(" git head ref: {}", hr); + } + info!(" {}", BUILT_TIME_UTC); + info!(" with compiler {}", RUSTC_VERSION); + info!(""); +} diff --git a/src/cli/solutions/apply/mod.rs b/src/cli/solutions/apply/mod.rs index d5502beb..1c4cb6a0 100644 --- a/src/cli/solutions/apply/mod.rs +++ b/src/cli/solutions/apply/mod.rs @@ -5,1075 +5,206 @@ //! Given input data and a calibration solutions file, apply the solutions and //! write out the calibrated visibilities. 
-mod error; #[cfg(test)] mod tests; -pub(crate) use error::SolutionsApplyError; - -use std::{ - collections::HashSet, - path::{Path, PathBuf}, - str::FromStr, - thread, -}; +use std::path::PathBuf; use clap::Parser; -use crossbeam_channel::{bounded, Receiver, Sender}; -use crossbeam_utils::atomic::AtomicCell; -use hifitime::Duration; -use indicatif::{MultiProgress, ProgressBar, ProgressDrawTarget, ProgressStyle}; -use itertools::Itertools; -use log::{debug, info, log_enabled, trace, warn, Level::Debug}; -use marlu::{Jones, LatLngHeight}; -use ndarray::{prelude::*, ArcArray2}; -use scopeguard::defer_on_unwind; -use vec1::{vec1, Vec1}; +use log::{debug, info, trace}; +use serde::{Deserialize, Serialize}; use crate::{ - averaging::{ - parse_freq_average_factor, parse_time_average_factor, timesteps_to_timeblocks, - AverageFactorError, - }, - context::ObsContext, - filenames::InputDataTypes, - help_texts::{ARRAY_POSITION_HELP, MS_DATA_COL_NAME_HELP}, - io::{ - read::{ - pfb_gains::PfbFlavour, MsReader, RawDataReader, UvfitsReader, VisInputType, VisRead, - VisReadError, - }, - write::{can_write_to_file, write_vis, VisOutputType, VisTimestep, VIS_OUTPUT_EXTENSIONS}, - }, - math::TileBaselineFlags, - messages, - solutions::CalibrationSolutions, + cli::common::{display_warnings, InputVisArgs, OutputVisArgs, ARG_FILE_HELP}, + io::write::VIS_OUTPUT_EXTENSIONS, + params::SolutionsApplyParams, + solutions::CAL_SOLUTION_EXTENSIONS, HyperdriveError, }; pub(crate) const DEFAULT_OUTPUT_VIS_FILENAME: &str = "hyperdrive_calibrated.uvfits"; lazy_static::lazy_static! { + static ref SOLS_INPUT_HELP: String = + format!("Path to the calibration solutions file to be applied. Supported formats: {}", *CAL_SOLUTION_EXTENSIONS); + static ref OUTPUTS_HELP: String = format!("Paths to the output calibrated visibility files. Supported formats: {}. Default: {}", *VIS_OUTPUT_EXTENSIONS, DEFAULT_OUTPUT_VIS_FILENAME); - - static ref PFB_FLAVOUR_HELP: String = - format!("{}. Only useful if the input solutions don't specify that this correction should be applied.", *crate::help_texts::PFB_FLAVOUR_HELP); } -#[derive(Parser, Debug, Default)] -pub struct SolutionsApplyArgs { - /// Paths to the input data files to apply solutions to. These can include a - /// metafits file, a measurement set and/or uvfits files. - #[clap(short, long, multiple_values(true), help_heading = "INPUT FILES")] - data: Vec, - - /// Path to the calibration solutions file. - #[clap(short, long, help_heading = "INPUT FILES")] - solutions: PathBuf, +#[derive(Parser, Debug, Clone, Default, Serialize, Deserialize)] +pub(crate) struct SolutionsApplyArgs { + #[clap(name = "ARGUMENTS_FILE", help = ARG_FILE_HELP.as_str(), parse(from_os_str))] + args_file: Option, - /// The timesteps to use from the input data. The default is to use all - /// unflagged timesteps. - #[clap(long, multiple_values(true), help_heading = "INPUT FILES")] - timesteps: Option>, + #[clap(flatten)] + #[serde(rename = "data")] + #[serde(default)] + data_args: InputVisArgs, - /// Use all timesteps in the data, including flagged ones. The default is to - /// use all unflagged timesteps. - #[clap(long, conflicts_with("timesteps"), help_heading = "INPUT FILES")] - use_all_timesteps: bool, + /// Path to the calibration solutions file to be applied. 
+ #[clap(short, long, help = SOLS_INPUT_HELP.as_str(), help_heading = "INPUT DATA")] + solutions: Option, #[clap( - long, help = ARRAY_POSITION_HELP.as_str(), help_heading = "INPUT FILES", - number_of_values = 3, - allow_hyphen_values = true, - value_names = &["LONG_DEG", "LAT_DEG", "HEIGHT_M"] - )] - array_position: Option>, - - #[clap(long, help = MS_DATA_COL_NAME_HELP, help_heading = "INPUT FILES")] - ms_data_column_name: Option, - - /// Use a DUT1 value of 0 seconds rather than what is in the input data. - #[clap(long, help_heading = "INPUT FILES")] - ignore_dut1: bool, - - /// Additional tiles to be flagged. These values correspond to either the - /// values in the "Antenna" column of HDU 2 in the metafits file (e.g. 0 3 - /// 127), or the "TileName" (e.g. Tile011). - #[clap(long, multiple_values(true), help_heading = "FLAGGING")] - tile_flags: Option>, - - /// If specified, pretend that all tiles are unflagged in the input data. - #[clap(long, help_heading = "FLAGGING")] - ignore_input_data_tile_flags: bool, - - /// If specified, pretend that all tiles are unflagged in the input - /// solutions. - #[clap(long, help_heading = "FLAGGING")] - ignore_input_solutions_tile_flags: bool, - - #[clap( - short, long, multiple_values(true), help = OUTPUTS_HELP.as_str(), + short = 'o', + long, + multiple_values(true), + help = OUTPUTS_HELP.as_str(), help_heading = "OUTPUT FILES" )] outputs: Option>, - /// When writing out calibrated visibilities, average this many timesteps - /// together. Also supports a target time resolution (e.g. 8s). The value - /// must be a multiple of the input data's time resolution. The default is - /// to preserve the input data's time resolution. e.g. If the input data is - /// in 0.5s resolution and this variable is 4, then we average 2s worth of - /// calibrated data together before writing the data out. If the variable is - /// instead 4s, then 8 calibrated timesteps are averaged together before - /// writing the data out. + /// When writing out visibilities, average this many timesteps together. + /// Also supports a target time resolution (e.g. 8s). The value must be a + /// multiple of the input data's time resolution. The default is no + /// averaging, i.e. a value of 1. Examples: If the input data is in 0.5s + /// resolution and this variable is 4, then we average 2s worth of data + /// together before writing the data out. If the variable is instead 4s, + /// then 8 timesteps are averaged together before writing the data out. #[clap(long, help_heading = "OUTPUT FILES")] - time_average: Option, - - /// When writing out calibrated visibilities, average this many fine freq. - /// channels together. Also supports a target freq. resolution (e.g. 80kHz). - /// The value must be a multiple of the input data's freq. resolution. The - /// default is to preserve the input data's freq. resolution. e.g. If the - /// input data is in 40kHz resolution and this variable is 4, then we - /// average 160kHz worth of calibrated data together before writing the data - /// out. If the variable is instead 80kHz, then 2 calibrated fine freq. - /// channels are averaged together before writing the data out. + output_vis_time_average: Option, + + /// When writing out visibilities, average this many fine freq. channels + /// together. Also supports a target freq. resolution (e.g. 80kHz). The + /// value must be a multiple of the input data's freq. resolution. The + /// default is no averaging, i.e. a value of 1. 
Examples: If the input data + /// is in 40kHz resolution and this variable is 4, then we average 160kHz + /// worth of data together before writing the data out. If the variable is + /// instead 80kHz, then 2 fine freq. channels are averaged together before + /// writing the data out. #[clap(long, help_heading = "OUTPUT FILES")] - freq_average: Option, - - /// Don't include autocorrelations in the output visibilities. + output_vis_freq_average: Option, + + /// Rather than writing out the entire input bandwidth, write out only the + /// smallest contiguous band. e.g. Typical 40 kHz MWA data has 768 channels, + /// but the first 2 and last 2 channels are usually flagged. Turning this + /// option on means that 764 channels would be written out instead of 768. + /// Note that other flagged channels in the band are unaffected, because the + /// data written out must be contiguous. #[clap(long, help_heading = "OUTPUT FILES")] - no_autos: bool, - - #[clap(long, help = PFB_FLAVOUR_HELP.as_str(), help_heading = "RAW MWA DATA CORRECTIONS")] - pfb_flavour: Option, - - /// When reading in raw MWA data, don't apply digital gains. Only useful if - /// the input solutions don't specify that this correction should be - /// applied. - #[clap(long, help_heading = "RAW MWA DATA CORRECTIONS")] - no_digital_gains: bool, - - /// When reading in raw MWA data, don't apply cable length corrections. Only - /// useful if the input solutions don't specify that this correction should - /// be applied. - #[clap(long, help_heading = "RAW MWA DATA CORRECTIONS")] - no_cable_length_correction: bool, - - /// When reading in raw MWA data, don't apply geometric corrections. Only - /// useful if the input solutions don't specify that this correction should - /// be applied. - #[clap(long, help_heading = "RAW MWA DATA CORRECTIONS")] - no_geometric_correction: bool, - - /// Don't draw progress bars. - #[clap(long, help_heading = "USER INTERFACE")] - no_progress_bars: bool, + #[serde(default)] + output_smallest_contiguous_band: bool, } impl SolutionsApplyArgs { - pub fn run(self, dry_run: bool) -> Result<(), HyperdriveError> { - apply_solutions(self, dry_run)?; - Ok(()) - } -} - -fn apply_solutions(args: SolutionsApplyArgs, dry_run: bool) -> Result<(), SolutionsApplyError> { - debug!("{:#?}", args); - - // Expose all the struct fields to ensure they're all used. - let SolutionsApplyArgs { - data, - solutions, - timesteps, - use_all_timesteps, - array_position, - ms_data_column_name, - ignore_dut1, - tile_flags, - ignore_input_data_tile_flags, - ignore_input_solutions_tile_flags, - outputs, - time_average, - freq_average, - no_autos, - pfb_flavour, - no_digital_gains, - no_cable_length_correction, - no_geometric_correction, - no_progress_bars, - } = args; - - // Get the input data types. - let input_data_types = InputDataTypes::new(&data)?; - - // We don't necessarily need a metafits file, but if there's multiple of - // them, we complain. - let metafits = { - if let Some(ms) = input_data_types.metafits.as_ref() { - if ms.len() > 1 { - return Err(SolutionsApplyError::MultipleMetafits(ms.clone())); - } - Some(ms.first().as_ref()) - } else { - None - } - }; - - // Read the solutions before the input data; if something is wrong with - // them, then we can bail much faster. - let sols = CalibrationSolutions::read_solutions_from_ext_inner(&solutions, metafits)?; - - messages::CalSolDetails { - filename: &solutions, - sols: &sols, - } - .print(); - - // Use corrections specified by the solutions if they exist. 
Otherwise, we - // start with defaults. - let mut raw_data_corrections = sols.raw_data_corrections.unwrap_or_default(); - if let Some(s) = pfb_flavour { - let pfb_flavour = PfbFlavour::parse(&s)?; - raw_data_corrections.pfb_flavour = pfb_flavour; - }; - if no_digital_gains { - raw_data_corrections.digital_gains = false; - } - if no_cable_length_correction { - raw_data_corrections.cable_length = false; - } - if no_geometric_correction { - raw_data_corrections.geometric = false; - } - debug!("Raw data corrections with user input: {raw_data_corrections:?}"); - - // If the user supplied the array position, unpack it here. - let array_position = match array_position { - Some(pos) => { - if pos.len() != 3 { - return Err(SolutionsApplyError::BadArrayPosition { pos }); - } - Some(LatLngHeight { - longitude_rad: pos[0].to_radians(), - latitude_rad: pos[1].to_radians(), - height_metres: pos[2], + /// Both command-line and file arguments overlap in terms of what is + /// available; this function consolidates everything that was specified into + /// a single struct. Where applicable, it will prefer CLI parameters over + /// those in the file. + /// + /// The argument to this function is the path to the arguments file. + /// + /// This function should only ever merge arguments, and not try to make + /// sense of them. + pub(crate) fn merge(self) -> Result { + debug!("Merging command-line arguments with the argument file"); + + let cli_args = self; + + if let Some(arg_file) = cli_args.args_file { + // Read in the file arguments. Ensure all of the file args are + // accounted for by pattern matching. + let SolutionsApplyArgs { + args_file: _, + data_args, + solutions, + outputs, + output_vis_time_average, + output_vis_freq_average, + output_smallest_contiguous_band, + } = unpack_arg_file!(arg_file); + + // Merge all the arguments, preferring the CLI args when available. + Ok(SolutionsApplyArgs { + args_file: None, + data_args: cli_args.data_args.merge(data_args), + solutions: cli_args.solutions.or(solutions), + outputs: cli_args.outputs.or(outputs), + output_vis_time_average: cli_args + .output_vis_time_average + .or(output_vis_time_average), + output_vis_freq_average: cli_args + .output_vis_freq_average + .or(output_vis_freq_average), + output_smallest_contiguous_band: cli_args.output_smallest_contiguous_band + || output_smallest_contiguous_band, }) - } - None => None, - }; - - // Prepare an input data reader. - let input_data: Box = match ( - input_data_types.metafits, - input_data_types.gpuboxes, - input_data_types.mwafs, - input_data_types.ms, - input_data_types.uvfits, - ) { - // Valid input for reading raw data. - (Some(meta), Some(gpuboxes), mwafs, None, None) => { - // Ensure that there's only one metafits. - let meta = if meta.len() > 1 { - return Err(SolutionsApplyError::MultipleMetafits(meta)); - } else { - meta.first() - }; - - debug!("gpubox files: {:?}", &gpuboxes); - debug!("mwaf files: {:?}", &mwafs); - - let input_data = Box::new(RawDataReader::new( - meta, - &gpuboxes, - mwafs.as_deref(), - raw_data_corrections, - array_position, - )?); - - messages::InputFileDetails::Raw { - obsid: input_data.get_obs_context().obsid.unwrap(), - gpubox_count: gpuboxes.len(), - metafits_file_name: meta.display().to_string(), - mwaf: input_data.get_flags(), - raw_data_corrections, - } - .print("Applying solutions to"); // Print some high-level information. - - input_data - } - - // Valid input for reading a measurement set. 
- (meta, None, None, Some(ms), None) => { - // Only one MS is supported at the moment. - let ms: PathBuf = if ms.len() > 1 { - return Err(SolutionsApplyError::MultipleMeasurementSets(ms)); - } else { - ms.first().clone() - }; - - // Ensure that there's only one metafits. - let meta: Option<&Path> = match meta.as_ref() { - None => None, - Some(m) => { - if m.len() > 1 { - return Err(SolutionsApplyError::MultipleMetafits(m.clone())); - } else { - Some(m.first().as_path()) - } - } - }; - - let input_data = MsReader::new(ms.clone(), ms_data_column_name, meta, array_position) - .map_err(VisReadError::from)?; - - messages::InputFileDetails::MeasurementSet { - obsid: input_data.get_obs_context().obsid, - file_name: ms.display().to_string(), - metafits_file_name: meta.map(|m| m.display().to_string()), - } - .print("Applying solutions to"); - - Box::new(input_data) - } - - // Valid input for reading uvfits files. - (meta, None, None, None, Some(uvfits)) => { - // Only one uvfits is supported at the moment. - let uvfits: PathBuf = if uvfits.len() > 1 { - return Err(SolutionsApplyError::MultipleUvfits(uvfits)); - } else { - uvfits.first().clone() - }; - - // Ensure that there's only one metafits. - let meta: Option<&Path> = match meta.as_ref() { - None => None, - Some(m) => { - if m.len() > 1 { - return Err(SolutionsApplyError::MultipleMetafits(m.clone())); - } else { - Some(m.first()) - } - } - }; - - let input_data = UvfitsReader::new(uvfits.clone(), meta, array_position) - .map_err(VisReadError::from)?; - - messages::InputFileDetails::UvfitsFile { - obsid: input_data.get_obs_context().obsid, - file_name: uvfits.display().to_string(), - metafits_file_name: meta.map(|m| m.display().to_string()), - } - .print("Applying solutions to"); - - Box::new(input_data) - } - - // The following matches are for invalid combinations of input - // files. Make an error message for the user. 
- (Some(_), _, None, None, None) => { - let msg = "Received only a metafits file; a uvfits file, a measurement set or gpubox files are required."; - return Err(SolutionsApplyError::InvalidDataInput(msg)); - } - (Some(_), _, Some(_), None, None) => { - let msg = "Received only a metafits file and mwaf files; gpubox files are required."; - return Err(SolutionsApplyError::InvalidDataInput(msg)); - } - (None, Some(_), _, None, None) => { - let msg = "Received gpuboxes without a metafits file; this is not supported."; - return Err(SolutionsApplyError::InvalidDataInput(msg)); - } - (None, None, Some(_), None, None) => { - let msg = - "Received mwaf files without gpuboxes and a metafits file; this is not supported."; - return Err(SolutionsApplyError::InvalidDataInput(msg)); - } - (_, Some(_), _, Some(_), None) => { - let msg = "Received gpuboxes and measurement set files; this is not supported."; - return Err(SolutionsApplyError::InvalidDataInput(msg)); - } - (_, Some(_), _, None, Some(_)) => { - let msg = "Received gpuboxes and uvfits files; this is not supported."; - return Err(SolutionsApplyError::InvalidDataInput(msg)); - } - (_, _, _, Some(_), Some(_)) => { - let msg = "Received uvfits and measurement set files; this is not supported."; - return Err(SolutionsApplyError::InvalidDataInput(msg)); - } - (_, _, Some(_), Some(_), _) => { - let msg = "Received mwafs and measurement set files; this is not supported."; - return Err(SolutionsApplyError::InvalidDataInput(msg)); - } - (_, _, Some(_), _, Some(_)) => { - let msg = "Received mwafs and uvfits files; this is not supported."; - return Err(SolutionsApplyError::InvalidDataInput(msg)); - } - (None, None, None, None, None) => return Err(SolutionsApplyError::NoInputData), - }; - - // Warn the user if we're applying solutions to raw data without corrections. - if matches!(input_data.get_input_data_type(), VisInputType::Raw) - && sols.raw_data_corrections.is_none() - { - warn!("The calibration solutions do not list raw data corrections."); - warn!("Defaults and any user inputs are being used."); - }; - - let obs_context = input_data.get_obs_context(); - let no_autos = if !obs_context.autocorrelations_present { - info!("No auto-correlations in the input data; none will be written out"); - true - } else if no_autos { - info!("Ignoring auto-correlations in the input data; none will be written out"); - true - } else { - info!("Will write out calibrated cross- and auto-correlations"); - false - }; - let total_num_tiles = obs_context.get_total_num_tiles(); - - // We can't do anything if the number of tiles in the data is different to - // that of the solutions. - if total_num_tiles != sols.di_jones.len_of(Axis(1)) { - return Err(SolutionsApplyError::TileCountMismatch { - data: total_num_tiles, - solutions: sols.di_jones.len_of(Axis(1)), - }); - } - - // Assign the tile flags. The flagged tiles in the solutions are always - // used. - let tile_flags = { - // Need to convert indices into strings to use the `get_tile_flags` - // method below. 
- let mut sol_flags: Vec = if ignore_input_solutions_tile_flags { - debug!("Including any tiles with only NaN for solutions"); - vec![] } else { - debug!( - "There are {} tiles with only NaN for solutions; considering them as flagged tiles", - sols.flagged_tiles.len() - ); - sols.flagged_tiles.iter().map(|i| format!("{i}")).collect() - }; - if let Some(user_tile_flags) = tile_flags { - debug!("Using additional user tile flags: {user_tile_flags:?}"); - sol_flags.extend(user_tile_flags); + Ok(cli_args) } - if ignore_input_data_tile_flags { - debug!("Not using flagged tiles in the input data"); - } else { - debug!( - "Using input data tile flags: {:?}", - obs_context.flagged_tiles - ); - } - obs_context.get_tile_flags(ignore_input_data_tile_flags, Some(&sol_flags))? - }; - let num_unflagged_tiles = total_num_tiles - tile_flags.len(); - if log_enabled!(Debug) { - debug!("Tile indices, names and statuses:"); - obs_context - .tile_names - .iter() - .enumerate() - .map(|(i, name)| { - let flagged = tile_flags.contains(&i); - (i, name, if flagged { " flagged" } else { "unflagged" }) - }) - .for_each(|(i, name, status)| { - debug!(" {:3}: {:10}: {}", i, name, status); - }) - } - if num_unflagged_tiles == 0 { - return Err(SolutionsApplyError::NoTiles); } - let array_position = obs_context.array_position; - messages::ArrayDetails { - array_position: Some(array_position), - array_latitude_j2000: None, - total_num_tiles, - num_unflagged_tiles, - flagged_tiles: &tile_flags - .iter() - .cloned() - .sorted() - .map(|i| (obs_context.tile_names[i].as_str(), i)) - .collect::>(), - } - .print(); - let tile_baseline_flags = TileBaselineFlags::new(total_num_tiles, tile_flags); - let timesteps = match (use_all_timesteps, timesteps) { - (true, _) => Vec1::try_from(obs_context.all_timesteps.as_slice()), - (false, None) => Vec1::try_from(obs_context.unflagged_timesteps.as_slice()), - (false, Some(mut ts)) => { - // Make sure there are no duplicates. - let timesteps_hashset: HashSet<&usize> = ts.iter().collect(); - if timesteps_hashset.len() != ts.len() { - return Err(SolutionsApplyError::DuplicateTimesteps); - } + fn parse(self) -> Result { + debug!("{:#?}", self); - // Ensure that all specified timesteps are actually available. - for &t in &ts { - if obs_context.timestamps.get(t).is_none() { - return Err(SolutionsApplyError::UnavailableTimestep { - got: t, - last: obs_context.timestamps.len() - 1, - }); - } - } + let Self { + args_file: _, + mut data_args, + solutions, + outputs, + output_vis_time_average, + output_vis_freq_average, + output_smallest_contiguous_band, + } = self; - ts.sort_unstable(); - Vec1::try_from_vec(ts) - } - } - .map_err(|_| SolutionsApplyError::NoTimesteps)?; + match (solutions, data_args.files.as_mut()) { + // Add the user-specified solutions to the file list. + (Some(s), Some(f)) => f.push(s), - let dut1 = if ignore_dut1 { None } else { obs_context.dut1 }; + // No solutions specified to solutions-apply; if no solutions were + // given to data_args, then we'll need to complain. 
+ (None, _) => (), - messages::ObservationDetails { - dipole_delays: None, - beam_file: None, - num_tiles_with_dead_dipoles: None, - phase_centre: obs_context.phase_centre, - pointing_centre: obs_context.pointing_centre, - dut1, - lmst: None, - lmst_j2000: None, - available_timesteps: Some(&obs_context.all_timesteps), - unflagged_timesteps: Some(&obs_context.unflagged_timesteps), - using_timesteps: Some(×teps), - first_timestamp: Some(obs_context.timestamps[*timesteps.first()]), - last_timestamp: if timesteps.len() > 1 { - Some(obs_context.timestamps[*timesteps.last()]) - } else { - None - }, - time_res: obs_context.time_res, - total_num_channels: obs_context.fine_chan_freqs.len(), - num_unflagged_channels: None, - flagged_chans_per_coarse_chan: None, - first_freq_hz: Some(*obs_context.fine_chan_freqs.first() as f64), - last_freq_hz: Some(*obs_context.fine_chan_freqs.last() as f64), - first_unflagged_freq_hz: None, - last_unflagged_freq_hz: None, - freq_res_hz: obs_context.freq_res, - } - .print(); - - // Handle output visibility arguments. - let (time_average_factor, freq_average_factor) = { - // Parse and verify user input (specified resolutions must - // evenly divide the input data's resolutions). - let time_factor = parse_time_average_factor( - obs_context.time_res, - time_average.as_deref(), - 1, - ) - .map_err(|e| match e { - AverageFactorError::Zero => SolutionsApplyError::OutputVisTimeAverageFactorZero, - AverageFactorError::NotInteger => SolutionsApplyError::OutputVisTimeFactorNotInteger, - AverageFactorError::NotIntegerMultiple { out, inp } => { - SolutionsApplyError::OutputVisTimeResNotMultiple { out, inp } - } - AverageFactorError::Parse(e) => SolutionsApplyError::ParseOutputVisTimeAverageFactor(e), - })?; - let freq_factor = parse_freq_average_factor( - obs_context.freq_res, - freq_average.as_deref(), - 1, - ) - .map_err(|e| match e { - AverageFactorError::Zero => SolutionsApplyError::OutputVisFreqAverageFactorZero, - AverageFactorError::NotInteger => SolutionsApplyError::OutputVisFreqFactorNotInteger, - AverageFactorError::NotIntegerMultiple { out, inp } => { - SolutionsApplyError::OutputVisFreqResNotMultiple { out, inp } - } - AverageFactorError::Parse(e) => SolutionsApplyError::ParseOutputVisFreqAverageFactor(e), - })?; - - (time_factor, freq_factor) - }; - - let outputs = match outputs { - None => vec1![( - PathBuf::from("hyp_calibrated.uvfits"), - VisOutputType::Uvfits - )], - Some(v) => { - let mut outputs = Vec::with_capacity(v.len()); - for file in v { - // Is the output file type supported? - let ext = file.extension().and_then(|os_str| os_str.to_str()); - match ext.and_then(|s| VisOutputType::from_str(s).ok()) { - Some(vis_type) => { - trace!("{} is a visibility output", file.display()); - can_write_to_file(&file)?; - outputs.push((file, vis_type)); - } - None => return Err(SolutionsApplyError::InvalidOutputFormat(file)), - } - } - Vec1::try_from_vec(outputs).map_err(|_| SolutionsApplyError::NoOutput)? + // Solutions were given, but no data_args. Well, we need + // visibilities, so parsing data_args will fail. 
+ (Some(_), None) => (), } - }; - - messages::OutputFileDetails { - output_solutions: &[], - vis_type: "calibrated", - output_vis: Some(&outputs), - input_vis_time_res: obs_context.time_res, - input_vis_freq_res: obs_context.freq_res, - output_vis_time_average_factor: time_average_factor, - output_vis_freq_average_factor: freq_average_factor, - } - .print(); - - if dry_run { - info!("Dry run -- exiting now."); - return Ok(()); - } - apply_solutions_inner( - &*input_data, - &sols, - ×teps, - array_position, - dut1.unwrap_or_else(|| Duration::from_seconds(0.0)), - no_autos, - &tile_baseline_flags, - // TODO: Provide CLI options - &HashSet::new(), - false, - &outputs, - time_average_factor, - freq_average_factor, - no_progress_bars, - ) -} - -#[allow(clippy::too_many_arguments)] -pub(super) fn apply_solutions_inner( - input_data: &dyn VisRead, - sols: &CalibrationSolutions, - timesteps: &Vec1, - array_position: LatLngHeight, - dut1: Duration, - no_autos: bool, - tile_baseline_flags: &TileBaselineFlags, - flagged_fine_chans: &HashSet, - ignore_input_data_fine_channel_flags: bool, - outputs: &Vec1<(PathBuf, VisOutputType)>, - output_vis_time_average_factor: usize, - output_vis_freq_average_factor: usize, - no_progress_bars: bool, -) -> Result<(), SolutionsApplyError> { - let obs_context = input_data.get_obs_context(); - let fine_chan_flags = { - let mut flagged_fine_chans = flagged_fine_chans.clone(); - if !ignore_input_data_fine_channel_flags { - flagged_fine_chans.extend(obs_context.flagged_fine_chans.iter().copied()); + let input_vis_params = data_args.parse("Applying solutions")?; + if input_vis_params.solutions.is_none() { + return Err(SolutionsApplyArgsError::NoSolutions.into()); } - flagged_fine_chans - }; - - let timeblocks = timesteps_to_timeblocks( - &obs_context.timestamps, - output_vis_time_average_factor, - timesteps, - ); - - // Channel for applying solutions. - let (tx_data, rx_data) = bounded(3); - // Channel for writing calibrated visibilities. - let (tx_write, rx_write) = bounded(3); - - // Progress bars. - let multi_progress = MultiProgress::with_draw_target(if no_progress_bars { - ProgressDrawTarget::hidden() - } else { - ProgressDrawTarget::stdout() - }); - let read_progress = multi_progress.add( - ProgressBar::new(timesteps.len() as _) - .with_style( - ProgressStyle::default_bar() - .template("{msg:18}: [{wide_bar:.blue}] {pos:2}/{len:2} timesteps ({elapsed_precise}<{eta_precise})").unwrap() - .progress_chars("=> "), - ) - .with_position(0) - .with_message("Reading data"), - ); - let apply_progress = multi_progress.add( - ProgressBar::new(timesteps.len() as _) - .with_style( - ProgressStyle::default_bar() - .template("{msg:18}: [{wide_bar:.blue}] {pos:2}/{len:2} timesteps ({elapsed_precise}<{eta_precise})").unwrap() - .progress_chars("=> "), - ) - .with_position(0) - .with_message("Applying solutions"), - ); - let write_progress = multi_progress.add( - ProgressBar::new(timeblocks.len() as _) - .with_style( - ProgressStyle::default_bar() - .template("{msg:18}: [{wide_bar:.blue}] {pos:2}/{len:2} timeblocks ({elapsed_precise}<{eta_precise})").unwrap() - .progress_chars("=> "), - ) - .with_position(0) - .with_message("Writing data"), - ); - - // Use a variable to track whether any threads have an issue. - let error = AtomicCell::new(false); - - info!("Reading input data, applying, and writing"); - let scoped_threads_result = thread::scope(|s| { - // Input visibility-data reading thread. - let data_handle = s.spawn(|| { - // If a panic happens, update our atomic error. 
- defer_on_unwind! { error.store(true); } - read_progress.tick(); - - let result = read_vis( - obs_context, - tile_baseline_flags, - input_data, - timesteps, - no_autos, - tx_data, - &error, - read_progress, - ); - // If the result of reading data was an error, allow the other - // threads to see this so they can abandon their work early. - if result.is_err() { - error.store(true); - } - result - }); - // Solutions applying thread. - let apply_handle = s.spawn(|| { - defer_on_unwind! { error.store(true); } - apply_progress.tick(); - - let result = apply_solutions_thread( - obs_context, - sols, - tile_baseline_flags, - &fine_chan_flags, - rx_data, - tx_write, - &error, - apply_progress, - ); - if result.is_err() { - error.store(true); - } - result - }); - - // Calibrated vis writing thread. - let write_handle = s.spawn(|| { - defer_on_unwind! { error.store(true); } - write_progress.tick(); - - // If we're not using autos, "disable" the `unflagged_tiles_iter` by - // making it not iterate over anything. - let total_num_tiles = if no_autos { - 0 - } else { - obs_context.get_total_num_tiles() - }; - let unflagged_tiles_iter = (0..total_num_tiles) - .filter(|i_tile| !tile_baseline_flags.flagged_tiles.contains(i_tile)) - .map(|i_tile| (i_tile, i_tile)); - // Form (sorted) unflagged baselines from our cross- and - // auto-correlation baselines. - let unflagged_cross_and_auto_baseline_tile_pairs = tile_baseline_flags - .tile_to_unflagged_cross_baseline_map - .keys() - .copied() - .chain(unflagged_tiles_iter) - .sorted() - .collect::>(); - let fine_chan_freqs = obs_context.fine_chan_freqs.mapped_ref(|&f| f as f64); + let output_vis_params = OutputVisArgs { + outputs, + output_vis_time_average, + output_vis_freq_average, + } + .parse( + input_vis_params.time_res, + input_vis_params.spw.freq_res, + &input_vis_params.timeblocks.mapped_ref(|tb| tb.median), + output_smallest_contiguous_band, + DEFAULT_OUTPUT_VIS_FILENAME, + Some("calibrated"), + )?; - let result = write_vis( - outputs, - array_position, - obs_context.phase_centre, - obs_context.pointing_centre, - &obs_context.tile_xyzs, - &obs_context.tile_names, - obs_context.obsid, - &obs_context.timestamps, - timesteps, - &timeblocks, - obs_context.guess_time_res(), - dut1, - obs_context.guess_freq_res(), - &fine_chan_freqs, - &unflagged_cross_and_auto_baseline_tile_pairs, - &HashSet::new(), - output_vis_time_average_factor, - output_vis_freq_average_factor, - input_data.get_marlu_mwa_info().as_ref(), - rx_write, - &error, - Some(write_progress), - ); - if result.is_err() { - error.store(true); - } - result - }); + display_warnings(); - // Join all thread handles. This propagates any errors and lets us know - // if any threads panicked, if panics aren't aborting as per the - // Cargo.toml. (It would be nice to capture the panic information, if - // it's possible, but I don't know how, so panics are currently - // aborting.) - let result = data_handle.join().unwrap(); - let result = result.and_then(|_| apply_handle.join().unwrap()); - result.and_then(|_| { - write_handle - .join() - .unwrap() - .map_err(SolutionsApplyError::from) + Ok(SolutionsApplyParams { + input_vis_params, + output_vis_params, }) - }); - - // Propagate errors and print out the write message. 
- let s = scoped_threads_result?; - info!("{s}"); - - Ok(()) -} - -#[allow(clippy::too_many_arguments)] -fn read_vis( - obs_context: &ObsContext, - tile_baseline_flags: &TileBaselineFlags, - input_data: &dyn VisRead, - timesteps: &Vec1, - no_autos: bool, - tx: Sender, - error: &AtomicCell, - progress_bar: ProgressBar, -) -> Result<(), SolutionsApplyError> { - let num_unflagged_tiles = tile_baseline_flags.unflagged_auto_index_to_tile_map.len(); - let num_unflagged_cross_baselines = (num_unflagged_tiles * (num_unflagged_tiles - 1)) / 2; - - let cross_vis_shape = ( - obs_context.fine_chan_freqs.len(), - num_unflagged_cross_baselines, - ); - let auto_vis_shape = (obs_context.fine_chan_freqs.len(), num_unflagged_tiles); - - // Send the data as timesteps. - for ×tep in timesteps { - let timestamp = obs_context.timestamps[timestep]; - debug!( - "Reading timestep {timestep} (GPS {})", - timestamp.to_gpst_seconds() - ); - - let mut cross_data_fb: ArcArray2> = ArcArray2::zeros(cross_vis_shape); - let mut cross_weights_fb: ArcArray2 = ArcArray2::zeros(cross_vis_shape); - let mut autos = if no_autos { - None - } else { - Some(( - ArcArray::zeros(auto_vis_shape), - ArcArray::zeros(auto_vis_shape), - )) - }; + } - if let Some((auto_data_fb, auto_weights_fb)) = autos.as_mut() { - input_data.read_crosses_and_autos( - cross_data_fb.view_mut(), - cross_weights_fb.view_mut(), - auto_data_fb.view_mut(), - auto_weights_fb.view_mut(), - timestep, - tile_baseline_flags, - // We want to read in all channels, even if they're flagged. - // Channels will get flagged later based on the calibration - // solutions, the input data flags and user flags. - &HashSet::new(), - )?; - } else { - input_data.read_crosses( - cross_data_fb.view_mut(), - cross_weights_fb.view_mut(), - timestep, - tile_baseline_flags, - &HashSet::new(), - )?; - } + pub(crate) fn run(self, dry_run: bool) -> Result<(), HyperdriveError> { + debug!("Converting arguments into parameters"); + trace!("{:#?}", self); + let params = self.parse()?; - // Should we continue? - if error.load() { + if dry_run { + info!("Dry run -- exiting now."); return Ok(()); } - match tx.send(VisTimestep { - cross_data_fb, - cross_weights_fb, - autos, - timestamp, - }) { - Ok(()) => (), - // If we can't send the message, it's because the channel - // has been closed on the other side. That should only - // happen because the writer has exited due to error; in - // that case, just exit this thread. - Err(_) => return Ok(()), - } - - progress_bar.inc(1); + params.run()?; + Ok(()) } - - debug!("Finished reading"); - progress_bar.abandon_with_message("Finished reading visibilities"); - Ok(()) } -#[allow(clippy::too_many_arguments)] -fn apply_solutions_thread( - obs_context: &ObsContext, - solutions: &CalibrationSolutions, - tile_baseline_flags: &TileBaselineFlags, - fine_chan_flags: &HashSet, - rx: Receiver, - tx: Sender, - error: &AtomicCell, - progress_bar: ProgressBar, -) -> Result<(), SolutionsApplyError> { - for VisTimestep { - mut cross_data_fb, - mut cross_weights_fb, - mut autos, - timestamp, - } in rx.iter() - { - // Should we continue? - if error.load() { - return Ok(()); - } - - let span = *obs_context.timestamps.last() - *obs_context.timestamps.first(); - let timestamp_fraction = ((timestamp - *obs_context.timestamps.first()).to_seconds() - / span.to_seconds()) - // Stop stupid values. - .clamp(0.0, 0.99); - - // Find solutions corresponding to this timestamp. 
- let sols = solutions.get_timeblock(timestamp, timestamp_fraction); - - for (i_baseline, (mut cross_data_f, mut cross_weights_f)) in cross_data_fb - .axis_iter_mut(Axis(1)) - .zip_eq(cross_weights_fb.axis_iter_mut(Axis(1))) - .enumerate() - { - let (tile1, tile2) = tile_baseline_flags.unflagged_cross_baseline_to_tile_map - .get(&i_baseline) - .copied() - .unwrap_or_else(|| { - panic!("Couldn't find baseline index {i_baseline} in unflagged_cross_baseline_to_tile_map") - }); - // TODO: Allow solutions to have a different number of channels than - // the data. - - // Get the solutions for both tiles and apply them. - let sols_tile1 = sols.slice(s![tile1, ..]); - let sols_tile2 = sols.slice(s![tile2, ..]); - cross_data_f - .iter_mut() - .zip_eq(cross_weights_f.iter_mut()) - .zip_eq(sols_tile1.iter()) - .zip_eq(sols_tile2.iter()) - .enumerate() - .for_each(|(i_chan, (((data, weight), sol1), sol2))| { - // One of the tiles doesn't have a solution; flag. - if sol1.any_nan() || sol2.any_nan() { - *weight = -weight.abs(); - *data = Jones::default(); - } else { - if fine_chan_flags.contains(&i_chan) { - // The channel is flagged, but we still have a solution for it. - *weight = -weight.abs(); - } - // Promote the data before demoting it again. - let d: Jones = Jones::from(*data); - *data = Jones::from((*sol1 * d) * sol2.h()); - } - }); - } - - if let Some((auto_data_fb, auto_weights_fb)) = autos.as_mut() { - for (i_tile, (mut auto_data_f, mut auto_weights_f)) in auto_data_fb - .axis_iter_mut(Axis(1)) - .zip_eq(auto_weights_fb.axis_iter_mut(Axis(1))) - .enumerate() - { - let i_tile = tile_baseline_flags - .unflagged_auto_index_to_tile_map - .get(&i_tile) - .copied() - .unwrap_or_else(|| { - panic!( - "Couldn't find auto index {i_tile} in unflagged_auto_index_to_tile_map" - ) - }); - - // Get the solutions for the tile and apply it twice. - let sols = sols.slice(s![i_tile, ..]); - auto_data_f - .iter_mut() - .zip_eq(auto_weights_f.iter_mut()) - .zip_eq(sols.iter()) - .enumerate() - .for_each(|(i_chan, ((data, weight), sol))| { - // No solution; flag. - if sol.any_nan() { - *weight = -weight.abs(); - *data = Jones::default(); - } else { - if fine_chan_flags.contains(&i_chan) { - // The channel is flagged, but we still have a solution for it. - *weight = -weight.abs(); - } - // Promote the data before demoting it again. - let d: Jones = Jones::from(*data); - *data = Jones::from((*sol * d) * sol.h()); - } - }); - } - } - - // Send the calibrated visibilities to the writer. - match tx.send(VisTimestep { - cross_data_fb, - cross_weights_fb, - autos, - timestamp, - }) { - Ok(()) => (), - // If we can't send the message, it's because the channel - // has been closed on the other side. That should only - // happen because the writer has exited due to error; in - // that case, just exit this thread. 
- Err(_) => return Ok(()), - } - progress_bar.inc(1); - } - debug!("Finished applying"); - progress_bar.abandon_with_message("Finished applying solutions"); - Ok(()) +#[derive(thiserror::Error, Debug)] +pub(crate) enum SolutionsApplyArgsError { + #[error("No calibration solutions were supplied")] + NoSolutions, } diff --git a/src/cli/solutions/apply/tests.rs b/src/cli/solutions/apply/tests.rs index fda14d2c..f5db520f 100644 --- a/src/cli/solutions/apply/tests.rs +++ b/src/cli/solutions/apply/tests.rs @@ -7,82 +7,73 @@ use std::{collections::HashSet, path::Path}; use approx::{assert_abs_diff_eq, assert_relative_eq}; +use crossbeam_utils::atomic::AtomicCell; +use itertools::{izip, Itertools}; use marlu::Jones; use ndarray::prelude::*; use serial_test::serial; use tempfile::TempDir; -use vec1::vec1; +use vec1::Vec1; use super::*; use crate::{ + cli::{common::InputVisArgs, vis_convert::VisConvertArgs}, io::read::{ fits::{fits_get_required_key, fits_open, fits_open_hdu}, - pfb_gains::PfbFlavour, - RawDataCorrections, + MsReader, UvfitsReader, VisRead, }, - tests::reduced_obsids::{ - get_reduced_1090008640, get_reduced_1090008640_ms, get_reduced_1090008640_uvfits, + math::TileBaselineFlags, + tests::{ + get_reduced_1090008640_ms, get_reduced_1090008640_raw, get_reduced_1090008640_raw_pbs, + get_reduced_1090008640_uvfits, DataAsPathBufs, DataAsStrings, }, + CalibrationSolutions, }; -fn test_solutions_apply_trivial(input_data: &dyn VisRead, metafits: &Path) { +fn test_solutions_apply_trivial(mut args: SolutionsApplyArgs) { + let tmp_dir = TempDir::new().unwrap(); + let output = tmp_dir.path().join("test.uvfits"); + let error = AtomicCell::new(false); + let metafits = PathBuf::from(args.data_args.files.as_ref().unwrap().last().unwrap()); + // Make some solutions that are all identity; the output visibilities should // be the same as the input. + let sols_file = get_1090008640_identity_solutions_file(tmp_dir.path()); + args.solutions = Some(sols_file.display().to_string()); + args.outputs = Some(vec![output.clone()]); + args.output_vis_time_average = None; + args.output_vis_freq_average = None; + + let mut params = args.parse().unwrap(); + // Get the reference visibilities. 
- let obs_context = input_data.get_obs_context(); - let flagged_tiles = obs_context.get_tile_flags(false, None).unwrap(); - assert!(flagged_tiles.is_empty()); + let obs_context = params.input_vis_params.vis_reader.get_obs_context(); let total_num_tiles = obs_context.get_total_num_tiles(); let total_num_baselines = (total_num_tiles * (total_num_tiles - 1)) / 2; let total_num_channels = obs_context.fine_chan_freqs.len(); - let tile_baseline_flags = TileBaselineFlags::new(total_num_tiles, flagged_tiles); - let mut flagged_fine_chans = HashSet::new(); + let tile_baseline_flags = ¶ms.input_vis_params.tile_baseline_flags; + let flagged_tiles = &tile_baseline_flags.flagged_tiles; + assert!(flagged_tiles.is_empty()); + let flagged_fine_chans = ¶ms.input_vis_params.spw.flagged_chan_indices; + let mut ref_crosses = Array2::zeros((total_num_channels, total_num_baselines)); let mut ref_cross_weights = Array2::zeros((total_num_channels, total_num_baselines)); let mut ref_autos = Array2::zeros((total_num_channels, total_num_tiles)); let mut ref_auto_weights = Array2::zeros((total_num_channels, total_num_tiles)); - input_data - .read_crosses_and_autos( + params + .input_vis_params + .read_timeblock( + params.input_vis_params.timeblocks.first(), ref_crosses.view_mut(), ref_cross_weights.view_mut(), - ref_autos.view_mut(), - ref_auto_weights.view_mut(), - obs_context.unflagged_timesteps[0], - &tile_baseline_flags, - &flagged_fine_chans, + Some((ref_autos.view_mut(), ref_auto_weights.view_mut())), + &error, ) .unwrap(); - - let mut sols = CalibrationSolutions { - di_jones: Array3::from_elem((1, total_num_tiles, total_num_channels), Jones::identity()), - flagged_tiles: vec![], - flagged_chanblocks: vec![], - ..Default::default() - }; - let timesteps = Vec1::try_from_vec(obs_context.unflagged_timesteps.clone()).unwrap(); - let tmp_dir = TempDir::new().unwrap(); - let output = tmp_dir.path().join("test.uvfits"); - let outputs = vec1![(output.clone(), VisOutputType::Uvfits)]; - - apply_solutions_inner( - input_data, - &sols, - ×teps, - LatLngHeight::mwa(), - Duration::default(), - false, - &tile_baseline_flags, - &flagged_fine_chans, - true, - &outputs, - 1, - 1, - true, - ) - .unwrap(); + params.run().unwrap(); // Read the output visibilities. - let output_data = UvfitsReader::new(output.clone(), Some(metafits), None).unwrap(); + let output_data = UvfitsReader::new(output.clone(), Some(&metafits), None).unwrap(); let mut crosses = Array2::zeros((total_num_channels, total_num_baselines)); let mut cross_weights = Array2::zeros((total_num_channels, total_num_baselines)); let mut autos = Array2::zeros((total_num_channels, total_num_tiles)); @@ -94,8 +85,8 @@ fn test_solutions_apply_trivial(input_data: &dyn VisRead, metafits: &Path) { autos.view_mut(), auto_weights.view_mut(), 0, - &tile_baseline_flags, - &flagged_fine_chans, + tile_baseline_flags, + flagged_fine_chans, ) .unwrap(); @@ -106,26 +97,17 @@ fn test_solutions_apply_trivial(input_data: &dyn VisRead, metafits: &Path) { // Now make the solutions all "2"; the output visibilities should be 4x the // input. - sols.di_jones.mapv_inplace(|j| j * 2.0); - apply_solutions_inner( - input_data, - &sols, - ×teps, - LatLngHeight::mwa(), - Duration::default(), - false, - &tile_baseline_flags, - &flagged_fine_chans, - true, - &outputs, - 1, - 1, - true, - ) - .unwrap(); + params + .input_vis_params + .solutions + .as_mut() + .unwrap() + .di_jones + .mapv_inplace(|j| j * 2.0); + params.run().unwrap(); // Read the output visibilities. 
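On the "4x" expectation above: the solution is applied once per tile, so doubling every Jones matrix quadruples a cross-correlation. A quick stand-alone check with illustrative values only:

use approx::assert_abs_diff_eq;
use marlu::Jones;

fn main() {
    let data = Jones::<f64>::identity() * 3.0;
    let sol = Jones::<f64>::identity() * 2.0;
    // (2·I) · D · (2·I)† = 4·D
    assert_abs_diff_eq!((sol * data) * sol.h(), data * 4.0);
}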
- let output_data = UvfitsReader::new(output.clone(), Some(metafits), None).unwrap(); + let output_data = UvfitsReader::new(output.clone(), Some(&metafits), None).unwrap(); crosses.fill(Jones::default()); cross_weights.fill(0.0); autos.fill(Jones::default()); @@ -137,8 +119,8 @@ fn test_solutions_apply_trivial(input_data: &dyn VisRead, metafits: &Path) { autos.view_mut(), auto_weights.view_mut(), 0, - &tile_baseline_flags, - &flagged_fine_chans, + tile_baseline_flags, + flagged_fine_chans, ) .unwrap(); @@ -148,32 +130,22 @@ fn test_solutions_apply_trivial(input_data: &dyn VisRead, metafits: &Path) { assert_abs_diff_eq!(auto_weights, ref_auto_weights); // Now make the solutions equal to the tile index. - sols.di_jones + params + .input_vis_params + .solutions + .as_mut() + .unwrap() + .di_jones .slice_mut(s![0, .., ..]) .outer_iter_mut() .enumerate() .for_each(|(i_tile, mut sols)| { sols.fill(Jones::identity() * (i_tile + 1) as f64); }); - apply_solutions_inner( - input_data, - &sols, - ×teps, - LatLngHeight::mwa(), - Duration::default(), - false, - &tile_baseline_flags, - &flagged_fine_chans, - true, - &outputs, - 1, - 1, - true, - ) - .unwrap(); + params.run().unwrap(); // Read the output visibilities. - let output_data = UvfitsReader::new(output.clone(), Some(metafits), None).unwrap(); + let output_data = UvfitsReader::new(output.clone(), Some(&metafits), None).unwrap(); crosses.fill(Jones::default()); cross_weights.fill(0.0); autos.fill(Jones::default()); @@ -185,8 +157,8 @@ fn test_solutions_apply_trivial(input_data: &dyn VisRead, metafits: &Path) { autos.view_mut(), auto_weights.view_mut(), 0, - &tile_baseline_flags, - &flagged_fine_chans, + tile_baseline_flags, + flagged_fine_chans, ) .unwrap(); @@ -225,70 +197,56 @@ fn test_solutions_apply_trivial(input_data: &dyn VisRead, metafits: &Path) { assert_abs_diff_eq!(auto_weights, ref_auto_weights); // Use tile indices for solutions again, but now flag some tiles. - let mut flagged_tiles = tile_baseline_flags.flagged_tiles; - flagged_tiles.insert(10); - flagged_tiles.insert(78); - let tile_baseline_flags = TileBaselineFlags::new(total_num_tiles, flagged_tiles); - for f in &tile_baseline_flags.flagged_tiles { - sols.flagged_tiles.push(*f); - } + let flags = [10, 78]; + params.input_vis_params.tile_baseline_flags = + TileBaselineFlags::new(total_num_tiles, HashSet::from(flags)); + let tile_baseline_flags = ¶ms.input_vis_params.tile_baseline_flags; // Re-generate the reference data. 
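The array shapes below shrink once tiles are flagged. For the 128-tile test observation used here the counts are easy to verify in isolation; this also explains the 8256 "baselines" expected by the uvfits tests further down, which include the autos:

fn main() {
    let n_cross = |n_tiles: usize| n_tiles * (n_tiles - 1) / 2;
    // Full array: 128 tiles -> 8128 cross baselines, or 8256 uvfits rows per
    // timestep once the 128 autos are included.
    assert_eq!(n_cross(128), 8128);
    // Flagging tiles 10 and 78 leaves 126 tiles -> 7875 cross baselines.
    assert_eq!(n_cross(126), 7875);
}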
let num_unflagged_tiles = total_num_tiles - tile_baseline_flags.flagged_tiles.len(); let num_unflagged_cross_baselines = (num_unflagged_tiles * (num_unflagged_tiles - 1)) / 2; - let mut ref_crosses = Array2::zeros((total_num_channels, num_unflagged_cross_baselines)); - let mut ref_cross_weights = Array2::zeros((total_num_channels, num_unflagged_cross_baselines)); - let mut ref_autos = Array2::zeros((total_num_channels, num_unflagged_tiles)); - let mut ref_auto_weights = Array2::zeros((total_num_channels, num_unflagged_tiles)); - input_data - .read_crosses_and_autos( - ref_crosses.view_mut(), - ref_cross_weights.view_mut(), - ref_autos.view_mut(), - ref_auto_weights.view_mut(), - obs_context.unflagged_timesteps[0], - &tile_baseline_flags, - &flagged_fine_chans, - ) - .unwrap(); - - apply_solutions_inner( - input_data, - &sols, - ×teps, - LatLngHeight::mwa(), - Duration::default(), - false, - &tile_baseline_flags, - &flagged_fine_chans, - true, - &outputs, - 1, - 1, - true, - ) - .unwrap(); + let mut ref_crosses_fb = Array2::zeros((total_num_channels, num_unflagged_cross_baselines)); + let mut ref_cross_weights_fb = + Array2::zeros((total_num_channels, num_unflagged_cross_baselines)); + let mut ref_autos_fb = Array2::zeros((total_num_channels, num_unflagged_tiles)); + let mut ref_auto_weights_fb = Array2::zeros((total_num_channels, num_unflagged_tiles)); + params.run().unwrap(); // Read the output visibilities. - let output_data = UvfitsReader::new(output.clone(), Some(metafits), None).unwrap(); - let mut crosses = Array2::zeros((total_num_channels, num_unflagged_cross_baselines)); - let mut cross_weights = Array2::zeros((total_num_channels, num_unflagged_cross_baselines)); - let mut autos = Array2::zeros((total_num_channels, num_unflagged_tiles)); - let mut auto_weights = Array2::zeros((total_num_channels, num_unflagged_tiles)); + let output_data = UvfitsReader::new(output.clone(), Some(&metafits), None).unwrap(); + let mut crosses_fb = Array2::zeros((total_num_channels, num_unflagged_cross_baselines)); + let mut cross_weights_fb = Array2::zeros((total_num_channels, num_unflagged_cross_baselines)); + let mut autos_fb = Array2::zeros((total_num_channels, num_unflagged_tiles)); + let mut auto_weights_fb = Array2::zeros((total_num_channels, num_unflagged_tiles)); output_data .read_crosses_and_autos( - crosses.view_mut(), - cross_weights.view_mut(), - autos.view_mut(), - auto_weights.view_mut(), + crosses_fb.view_mut(), + cross_weights_fb.view_mut(), + autos_fb.view_mut(), + auto_weights_fb.view_mut(), 0, - &tile_baseline_flags, - &flagged_fine_chans, + tile_baseline_flags, + flagged_fine_chans, ) .unwrap(); - for (i_baseline, (baseline, ref_baseline)) in crosses + // Read in the newly-flagged data without any solutions being applied. 
+ params + .input_vis_params + .vis_reader + .read_crosses_and_autos( + ref_crosses_fb.view_mut(), + ref_cross_weights_fb.view_mut(), + ref_autos_fb.view_mut(), + ref_auto_weights_fb.view_mut(), + obs_context.unflagged_timesteps[0], + tile_baseline_flags, + flagged_fine_chans, + ) + .unwrap(); + + for (i_baseline, (baseline, ref_baseline)) in crosses_fb .axis_iter(Axis(1)) - .zip_eq(ref_crosses.axis_iter(Axis(1))) + .zip_eq(ref_crosses_fb.axis_iter(Axis(1))) .enumerate() { let (tile1, tile2) = tile_baseline_flags.unflagged_cross_baseline_to_tile_map[&i_baseline]; @@ -300,9 +258,9 @@ fn test_solutions_apply_trivial(input_data: &dyn VisRead, metafits: &Path) { max_relative = 1e-7 ); } - for (i_tile, (tile, ref_tile)) in autos + for (i_tile, (tile, ref_tile)) in autos_fb .axis_iter(Axis(1)) - .zip_eq(ref_autos.axis_iter(Axis(1))) + .zip_eq(ref_autos_fb.axis_iter(Axis(1))) .enumerate() { let tile_factor = tile_baseline_flags.unflagged_auto_index_to_tile_map[&i_tile]; @@ -310,77 +268,90 @@ fn test_solutions_apply_trivial(input_data: &dyn VisRead, metafits: &Path) { ref_tile.mapv(Jones::::from) * (tile_factor + 1) as f64 * (tile_factor + 1) as f64; assert_relative_eq!(tile.mapv(Jones::::from), ref_tile, max_relative = 1e-7); } - assert_abs_diff_eq!(cross_weights, ref_cross_weights); - assert_abs_diff_eq!(auto_weights, ref_auto_weights); + assert_abs_diff_eq!(cross_weights_fb, ref_cross_weights_fb); + assert_abs_diff_eq!(auto_weights_fb, ref_auto_weights_fb); // Finally, flag some channels. - flagged_fine_chans.insert(3); - flagged_fine_chans.insert(18); - input_data + let flags = [3, 18]; + params + .input_vis_params + .spw + .flagged_chan_indices + .extend(flags); + params + .input_vis_params + .spw + .flagged_chanblock_indices + .extend(flags); + let flagged_fine_chans = ¶ms.input_vis_params.spw.flagged_chan_indices; + // Remove the newly-flagged chanblocks. (This is awkward because SPWs + // weren't designed to be modified.) + params.input_vis_params.spw.chanblocks = (0..) + .zip(params.input_vis_params.spw.chanblocks.into_iter()) + .filter(|(i, _)| !flagged_fine_chans.contains(i)) + .map(|(_, c)| c) + .collect(); + params.run().unwrap(); + + // Read the output visibilities. + let output_data = UvfitsReader::new(output, Some(&metafits), None).unwrap(); + crosses_fb.fill(Jones::default()); + cross_weights_fb.fill(0.0); + autos_fb.fill(Jones::default()); + auto_weights_fb.fill(0.0); + output_data .read_crosses_and_autos( - ref_crosses.view_mut(), - ref_cross_weights.view_mut(), - ref_autos.view_mut(), - ref_auto_weights.view_mut(), - obs_context.unflagged_timesteps[0], - &tile_baseline_flags, + crosses_fb.view_mut(), + cross_weights_fb.view_mut(), + autos_fb.view_mut(), + auto_weights_fb.view_mut(), + 0, + tile_baseline_flags, // We want to read all channels, even the flagged ones. &HashSet::new(), ) .unwrap(); - apply_solutions_inner( - input_data, - &sols, - ×teps, - LatLngHeight::mwa(), - Duration::default(), - false, - &tile_baseline_flags, - &flagged_fine_chans, - true, - &outputs, - 1, - 1, - true, - ) - .unwrap(); - - // Read the output visibilities. - let output_data = UvfitsReader::new(output, Some(metafits), None).unwrap(); - crosses.fill(Jones::default()); - cross_weights.fill(0.0); - autos.fill(Jones::default()); - auto_weights.fill(0.0); - output_data + // Read in the raw data without any solutions or flags being applied. 
+ params + .input_vis_params + .vis_reader .read_crosses_and_autos( - crosses.view_mut(), - cross_weights.view_mut(), - autos.view_mut(), - auto_weights.view_mut(), - 0, - &tile_baseline_flags, + ref_crosses_fb.view_mut(), + ref_cross_weights_fb.view_mut(), + ref_autos_fb.view_mut(), + ref_auto_weights_fb.view_mut(), + obs_context.unflagged_timesteps[0], + tile_baseline_flags, &HashSet::new(), ) .unwrap(); - for (i_baseline, (baseline, ref_baseline)) in crosses + // Manually flag the flagged channels. + for &f in flagged_fine_chans { + let f = usize::from(f); + ref_crosses_fb.slice_mut(s![f, ..]).fill(Jones::default()); + ref_cross_weights_fb.slice_mut(s![f, ..]).fill(-0.0); + ref_autos_fb.slice_mut(s![f, ..]).fill(Jones::default()); + ref_auto_weights_fb.slice_mut(s![f, ..]).fill(-0.0); + } + for (i_baseline, (baseline_f, ref_baseline_f)) in crosses_fb .axis_iter(Axis(1)) - .zip_eq(ref_crosses.axis_iter(Axis(1))) + .zip_eq(ref_crosses_fb.axis_iter(Axis(1))) .enumerate() { let (tile1, tile2) = tile_baseline_flags.unflagged_cross_baseline_to_tile_map[&i_baseline]; - let ref_baseline = - ref_baseline.mapv(Jones::::from) * (tile1 + 1) as f64 * (tile2 + 1) as f64; + let ref_baseline_f = + ref_baseline_f.mapv(Jones::::from) * (tile1 + 1) as f64 * (tile2 + 1) as f64; assert_relative_eq!( - baseline.mapv(Jones::::from), - ref_baseline, + baseline_f.mapv(Jones::::from), + ref_baseline_f, max_relative = 1e-7 ); } - for (i_tile, (tile, ref_tile)) in autos + for (i_tile, (tile, ref_tile)) in autos_fb .axis_iter(Axis(1)) - .zip_eq(ref_autos.axis_iter(Axis(1))) + .zip_eq(ref_autos_fb.axis_iter(Axis(1))) .enumerate() { let tile_factor = tile_baseline_flags.unflagged_auto_index_to_tile_map[&i_tile]; @@ -388,40 +359,34 @@ fn test_solutions_apply_trivial(input_data: &dyn VisRead, metafits: &Path) { ref_tile.mapv(Jones::::from) * (tile_factor + 1) as f64 * (tile_factor + 1) as f64; assert_relative_eq!(tile.mapv(Jones::::from), ref_tile, max_relative = 1e-7); } - // Manually negate the weights corresponding to our flagged channels. 
- for c in flagged_fine_chans { - ref_cross_weights - .slice_mut(s![c, ..]) - .map_inplace(|w| *w = -w.abs()); - ref_auto_weights - .slice_mut(s![c, ..]) - .map_inplace(|w| *w = -w.abs()); - } - assert_abs_diff_eq!(cross_weights, ref_cross_weights); - assert_abs_diff_eq!(auto_weights, ref_auto_weights); + assert_abs_diff_eq!(cross_weights_fb, ref_cross_weights_fb); + assert_abs_diff_eq!(auto_weights_fb, ref_auto_weights_fb); } #[test] fn test_solutions_apply_trivial_raw() { - let cal_args = get_reduced_1090008640(false, false); - let mut data = cal_args.data.unwrap().into_iter(); - let metafits = PathBuf::from(data.next().unwrap()); - let gpubox = PathBuf::from(data.next().unwrap()); - let input_data = RawDataReader::new( - &metafits, - &[gpubox], - None, - RawDataCorrections { - pfb_flavour: PfbFlavour::None, - digital_gains: false, - cable_length: false, - geometric: false, + let DataAsStrings { + metafits, + vis: mut files, + mwafs: _, + srclist: _, + } = get_reduced_1090008640_raw(); + files.push(metafits); + + let args = SolutionsApplyArgs { + data_args: InputVisArgs { + files: Some(files), + pfb_flavour: Some("none".to_string()), + no_digital_gains: false, + no_cable_length_correction: false, + no_geometric_correction: false, + ignore_input_data_fine_channel_flags: true, + ..Default::default() }, - None, - ) - .unwrap(); + ..Default::default() + }; - test_solutions_apply_trivial(&input_data, &metafits) + test_solutions_apply_trivial(args) } // If all data-reading routines are working correctly, these extra tests are @@ -430,29 +395,54 @@ fn test_solutions_apply_trivial_raw() { #[test] #[serial] fn test_solutions_apply_trivial_ms() { - let cal_args = get_reduced_1090008640_ms(); - let mut data = cal_args.data.unwrap().into_iter(); - let metafits = PathBuf::from(data.next().unwrap()); - let ms = PathBuf::from(data.next().unwrap()); - let input_data = MsReader::new(ms, None, Some(&metafits), None).unwrap(); + let DataAsStrings { + metafits, + vis: mut files, + mwafs: _, + srclist: _, + } = get_reduced_1090008640_ms(); + files.push(metafits); + + let args = SolutionsApplyArgs { + data_args: InputVisArgs { + files: Some(files), + ignore_input_data_fine_channel_flags: true, + ..Default::default() + }, + ..Default::default() + }; - test_solutions_apply_trivial(&input_data, &metafits) + test_solutions_apply_trivial(args) } #[test] fn test_solutions_apply_trivial_uvfits() { - let cal_args = get_reduced_1090008640_uvfits(); - let mut data = cal_args.data.unwrap().into_iter(); - let metafits = PathBuf::from(data.next().unwrap()); - let uvfits = PathBuf::from(data.next().unwrap()); - let input_data = UvfitsReader::new(uvfits, Some(&metafits), None).unwrap(); + let DataAsStrings { + metafits, + vis: mut files, + mwafs: _, + srclist: _, + } = get_reduced_1090008640_uvfits(); + files.push(metafits); + + let args = SolutionsApplyArgs { + data_args: InputVisArgs { + files: Some(files), + ignore_input_data_fine_channel_flags: true, + ..Default::default() + }, + ..Default::default() + }; - test_solutions_apply_trivial(&input_data, &metafits) + test_solutions_apply_trivial(args) } pub(crate) fn get_1090008640_identity_solutions_file(tmp_dir: &Path) -> PathBuf { let sols = CalibrationSolutions { di_jones: Array3::from_elem((1, 128, 32), Jones::identity()), + chanblock_freqs: Some( + Vec1::try_from(Array1::linspace(196495000.0, 197735000.0, 32).into_raw_vec()).unwrap(), + ), ..Default::default() }; let file = tmp_dir.join("sols.fits"); @@ -463,20 +453,16 @@ pub(crate) fn 
get_1090008640_identity_solutions_file(tmp_dir: &Path) -> PathBuf #[test] fn test_1090008640_solutions_apply_writes_vis_uvfits() { let tmp_dir = TempDir::new().expect("couldn't make tmp dir"); - let args = get_reduced_1090008640(false, false); - let data = args.data.unwrap(); - let metafits = &data[0]; - let gpubox = &data[1]; + let DataAsPathBufs { metafits, vis, .. } = get_reduced_1090008640_raw_pbs(); let solutions = get_1090008640_identity_solutions_file(tmp_dir.path()); let out_vis_path = tmp_dir.path().join("vis.uvfits"); #[rustfmt::skip] let args = SolutionsApplyArgs::parse_from([ "solutions-apply", - "--data", metafits, gpubox, + "--data", &format!("{}", metafits.display()), &format!("{}", vis[0].display()), "--solutions", &format!("{}", solutions.display()), "--outputs", &format!("{}", out_vis_path.display()), - "--no-progress-bars", ]); // Run solutions-apply and check that it succeeds @@ -503,21 +489,17 @@ fn test_1090008640_solutions_apply_writes_vis_uvfits() { #[test] fn test_1090008640_solutions_apply_writes_vis_uvfits_no_autos() { let tmp_dir = TempDir::new().expect("couldn't make tmp dir"); - let args = get_reduced_1090008640(false, false); - let data = args.data.unwrap(); - let metafits = &data[0]; - let gpubox = &data[1]; + let DataAsPathBufs { metafits, vis, .. } = get_reduced_1090008640_raw_pbs(); let solutions = get_1090008640_identity_solutions_file(tmp_dir.path()); let out_vis_path = tmp_dir.path().join("vis.uvfits"); #[rustfmt::skip] let args = SolutionsApplyArgs::parse_from([ "solutions-apply", - "--data", metafits, gpubox, + "--data", &format!("{}", metafits.display()), &format!("{}", vis[0].display()), "--solutions", &format!("{}", solutions.display()), "--outputs", &format!("{}", out_vis_path.display()), "--no-autos", - "--no-progress-bars", ]); // Run solutions-apply and check that it succeeds @@ -544,10 +526,9 @@ fn test_1090008640_solutions_apply_writes_vis_uvfits_no_autos() { #[test] fn test_1090008640_solutions_apply_writes_vis_uvfits_avg_freq() { let tmp_dir = TempDir::new().expect("couldn't make tmp dir"); - let args = get_reduced_1090008640(false, false); - let data = args.data.unwrap(); - let metafits = &data[0]; - let gpubox = &data[1]; + let DataAsStrings { + metafits, mut vis, .. 
+ } = get_reduced_1090008640_raw(); let solutions = get_1090008640_identity_solutions_file(tmp_dir.path()); let out_vis_path = tmp_dir.path().join("vis.uvfits"); @@ -556,11 +537,10 @@ fn test_1090008640_solutions_apply_writes_vis_uvfits_avg_freq() { #[rustfmt::skip] let args = SolutionsApplyArgs::parse_from([ "solutions-apply", - "--data", metafits, gpubox, + "--data", &metafits, &vis.swap_remove(0), "--solutions", &format!("{}", solutions.display()), "--outputs", &format!("{}", out_vis_path.display()), "--freq-average", &format!("{freq_avg_factor}"), - "--no-progress-bars", ]); // Run solutions-apply and check that it succeeds @@ -571,7 +551,7 @@ fn test_1090008640_solutions_apply_writes_vis_uvfits_avg_freq() { assert!(out_vis_path.exists(), "out vis file not written"); let exp_timesteps = 1; let exp_baselines = 8256; - let exp_channels = 32 / freq_avg_factor; + let exp_channels = 16; let mut out_vis = fits_open(&out_vis_path).unwrap(); let hdu0 = fits_open_hdu(&mut out_vis, 0).unwrap(); @@ -581,6 +561,7 @@ fn test_1090008640_solutions_apply_writes_vis_uvfits_avg_freq() { exp_timesteps * exp_baselines ); let num_fine_freq_chans: String = fits_get_required_key(&mut out_vis, &hdu0, "NAXIS4").unwrap(); + std::fs::copy(out_vis_path, PathBuf::from("/tmp/hyp_test.uvfits")).unwrap(); assert_eq!(num_fine_freq_chans.parse::().unwrap(), exp_channels); } @@ -588,10 +569,7 @@ fn test_1090008640_solutions_apply_writes_vis_uvfits_avg_freq() { #[serial] fn test_1090008640_solutions_apply_writes_vis_uvfits_and_ms() { let tmp_dir = TempDir::new().expect("couldn't make tmp dir"); - let args = get_reduced_1090008640(false, false); - let data = args.data.unwrap(); - let metafits = &data[0]; - let gpubox = &data[1]; + let DataAsPathBufs { metafits, vis, .. } = get_reduced_1090008640_raw_pbs(); let solutions = get_1090008640_identity_solutions_file(tmp_dir.path()); let out_uvfits_path = tmp_dir.path().join("vis.uvfits"); let out_ms_path = tmp_dir.path().join("vis.ms"); @@ -599,12 +577,11 @@ fn test_1090008640_solutions_apply_writes_vis_uvfits_and_ms() { #[rustfmt::skip] let args = SolutionsApplyArgs::parse_from([ "solutions-apply", - "--data", metafits, gpubox, + "--data", &format!("{}", metafits.display()), &format!("{}", vis[0].display()), "--solutions", &format!("{}", solutions.display()), "--outputs", &format!("{}", out_uvfits_path.display()), &format!("{}", out_ms_path.display()), - "--no-progress-bars", ]); // Run solutions-apply and check that it succeeds @@ -616,15 +593,14 @@ fn test_1090008640_solutions_apply_writes_vis_uvfits_and_ms() { let exp_timesteps = 1; let exp_channels = 32; - let uvfits_data = - UvfitsReader::new(out_uvfits_path, Some(&PathBuf::from(metafits)), None).unwrap(); + let uvfits_data = UvfitsReader::new(out_uvfits_path, Some(&metafits), None).unwrap(); let uvfits_ctx = uvfits_data.get_obs_context(); // check ms file has been created, is readable assert!(out_ms_path.exists(), "out vis file not written"); - let ms_data = MsReader::new(out_ms_path, None, Some(&PathBuf::from(metafits)), None).unwrap(); + let ms_data = MsReader::new(out_ms_path, None, Some(&metafits), None).unwrap(); let ms_ctx = ms_data.get_obs_context(); @@ -656,20 +632,20 @@ fn test_1090008640_solutions_apply_correct_vis() { let flagged_tiles = HashSet::from([1, 3, 5]); - let args = get_reduced_1090008640_uvfits(); - let data = args.data.unwrap(); - let metafits = &data[0]; - let uvfits = &data[1]; + let DataAsStrings { + metafits, mut vis, .. 
+ } = get_reduced_1090008640_uvfits(); + let metafits_pb = PathBuf::from(&metafits); + let uvfits = vis.swap_remove(0); let vis_out = tmp_dir.path().join("vis.uvfits"); let vis_out_string = vis_out.display().to_string(); #[rustfmt::skip] let mut args = vec![ "solutions-apply", - "--data", metafits, uvfits, + "--data", &metafits, &uvfits, "--solutions", &sols_file_string, "--outputs", &vis_out_string, - "--no-progress-bars" ]; let flag_strings = flagged_tiles .iter() @@ -689,8 +665,8 @@ fn test_1090008640_solutions_apply_correct_vis() { assert!(vis_out.exists(), "out vis file not written"); let uncal_reader = - UvfitsReader::new(PathBuf::from(uvfits), Some(&PathBuf::from(metafits)), None).unwrap(); - let cal_reader = UvfitsReader::new(vis_out, Some(&PathBuf::from(metafits)), None).unwrap(); + UvfitsReader::new(PathBuf::from(uvfits.clone()), Some(&metafits_pb), None).unwrap(); + let cal_reader = UvfitsReader::new(vis_out, Some(&metafits_pb), None).unwrap(); let obs_context = cal_reader.get_obs_context(); let total_num_tiles = obs_context.get_total_num_tiles(); @@ -814,13 +790,9 @@ fn test_1090008640_solutions_apply_correct_vis() { #[rustfmt::skip] let mut args = vec![ "solutions-apply", - "--data", metafits, uvfits, + "--data", &metafits, &uvfits, "--solutions", &sols_file_string, "--outputs", &vis_out_string, - "--no-progress-bars", - // Deliberately ignore solution tile flags, otherwise the code will - // write out a different number of baselines. - "--ignore-input-solutions-tile-flags" ]; if !flag_strings.is_empty() { args.push("--tile-flags"); @@ -835,7 +807,7 @@ fn test_1090008640_solutions_apply_correct_vis() { assert!(result.is_ok(), "result={:?} not ok", result.err().unwrap()); assert!(vis_out.exists(), "out vis file not written"); - let cal2_reader = UvfitsReader::new(vis_out, Some(&PathBuf::from(metafits)), None).unwrap(); + let cal2_reader = UvfitsReader::new(vis_out, Some(&metafits_pb), None).unwrap(); let mut cal2_cross_vis_data = Array2::zeros(( obs_context.fine_chan_freqs.len(), num_unflagged_cross_baselines, @@ -912,3 +884,299 @@ fn test_1090008640_solutions_apply_correct_vis() { } } } + +#[test] +fn test_solutions_apply_works_with_implicit_or_explicit_sols() { + let DataAsStrings { + metafits, + vis: mut files, + mwafs: _, + srclist: _, + } = get_reduced_1090008640_raw(); + files.push(metafits.clone()); + + let mut args = SolutionsApplyArgs { + data_args: InputVisArgs { + files: Some(files), + pfb_flavour: Some("none".to_string()), + no_digital_gains: false, + no_cable_length_correction: false, + no_geometric_correction: false, + ignore_input_data_fine_channel_flags: true, + ..Default::default() + }, + ..Default::default() + }; + + let tmp_dir = TempDir::new().unwrap(); + let explicit_output = tmp_dir.path().join("explicit.uvfits"); + let sols_file = get_1090008640_identity_solutions_file(tmp_dir.path()); + args.solutions = Some(sols_file.display().to_string()); + args.outputs = Some(vec![explicit_output.clone()]); + args.clone().parse().unwrap().run().unwrap(); + + let implicit_output = tmp_dir.path().join("implicit.uvfits"); + args.data_args + .files + .as_mut() + .unwrap() + .push(sols_file.display().to_string()); + args.solutions = None; + args.outputs = Some(vec![implicit_output.clone()]); + args.parse().unwrap().run().unwrap(); + + // The visibilities should be exactly the same. 
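The test above builds the argument structs directly; the same explicit/implicit pairing can be written through the CLI parser. The file names here are placeholders, and this assumes `--data` backs `InputVisArgs::files`, which is what the programmatic form above relies on:

use clap::Parser;

fn illustrate_cli_equivalence() {
    // Explicitly, via --solutions:
    let _explicit = SolutionsApplyArgs::parse_from([
        "solutions-apply",
        "--data", "obs.metafits", "obs.uvfits",
        "--solutions", "sols.fits",
        "--outputs", "out.uvfits",
    ]);
    // Implicitly, with the solutions file supplied as just another input file:
    let _implicit = SolutionsApplyArgs::parse_from([
        "solutions-apply",
        "--data", "obs.metafits", "obs.uvfits", "sols.fits",
        "--outputs", "out.uvfits",
    ]);
}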
+ let error = AtomicCell::new(false); + let explicit_vis_params = InputVisArgs { + files: Some(vec![ + metafits.clone(), + explicit_output.display().to_string(), + ]), + ..Default::default() + } + .parse("") + .unwrap(); + let obs_context = explicit_vis_params.get_obs_context(); + let total_num_tiles = obs_context.get_total_num_tiles(); + let total_num_baselines = (total_num_tiles * (total_num_tiles - 1)) / 2; + let total_num_channels = obs_context.fine_chan_freqs.len(); + let mut explicit_crosses = Array3::zeros(( + explicit_vis_params.timeblocks.len(), + total_num_channels, + total_num_baselines, + )); + let mut explicit_cross_weights = Array3::zeros(( + explicit_vis_params.timeblocks.len(), + total_num_channels, + total_num_baselines, + )); + let mut explicit_autos = Array3::zeros(( + explicit_vis_params.timeblocks.len(), + total_num_channels, + total_num_tiles, + )); + let mut explicit_auto_weights = Array3::zeros(( + explicit_vis_params.timeblocks.len(), + total_num_channels, + total_num_tiles, + )); + for (timeblock, crosses, cross_weights, autos, auto_weights) in izip!( + explicit_vis_params.timeblocks.iter(), + explicit_crosses.outer_iter_mut(), + explicit_cross_weights.outer_iter_mut(), + explicit_autos.outer_iter_mut(), + explicit_auto_weights.outer_iter_mut() + ) { + explicit_vis_params + .read_timeblock( + timeblock, + crosses, + cross_weights, + Some((autos, auto_weights)), + &error, + ) + .unwrap(); + } + + let implicit_vis_params = InputVisArgs { + files: Some(vec![metafits, implicit_output.display().to_string()]), + ..Default::default() + } + .parse("") + .unwrap(); + let obs_context = implicit_vis_params.get_obs_context(); + assert_eq!(total_num_tiles, obs_context.get_total_num_tiles()); + assert_eq!( + total_num_baselines, + (total_num_tiles * (total_num_tiles - 1)) / 2 + ); + assert_eq!(total_num_channels, obs_context.fine_chan_freqs.len()); + let mut implicit_crosses = Array3::zeros(( + implicit_vis_params.timeblocks.len(), + total_num_channels, + total_num_baselines, + )); + let mut implicit_cross_weights = Array3::zeros(( + implicit_vis_params.timeblocks.len(), + total_num_channels, + total_num_baselines, + )); + let mut implicit_autos = Array3::zeros(( + implicit_vis_params.timeblocks.len(), + total_num_channels, + total_num_tiles, + )); + let mut implicit_auto_weights = Array3::zeros(( + implicit_vis_params.timeblocks.len(), + total_num_channels, + total_num_tiles, + )); + for (timeblock, crosses, cross_weights, autos, auto_weights) in izip!( + implicit_vis_params.timeblocks.iter(), + implicit_crosses.outer_iter_mut(), + implicit_cross_weights.outer_iter_mut(), + implicit_autos.outer_iter_mut(), + implicit_auto_weights.outer_iter_mut() + ) { + implicit_vis_params + .read_timeblock( + timeblock, + crosses, + cross_weights, + Some((autos, auto_weights)), + &error, + ) + .unwrap(); + } + + assert_abs_diff_eq!(explicit_crosses, implicit_crosses); + assert_abs_diff_eq!(explicit_cross_weights, implicit_cross_weights); + assert_abs_diff_eq!(explicit_autos, implicit_autos); + assert_abs_diff_eq!(explicit_auto_weights, implicit_auto_weights); +} + +#[test] +fn test_solutions_apply_needs_sols_but_is_otherwise_vis_convert() { + let DataAsStrings { + metafits, + vis: mut files, + mwafs: _, + srclist: _, + } = get_reduced_1090008640_raw(); + files.push(metafits.clone()); + + let mut args = SolutionsApplyArgs { + data_args: InputVisArgs { + files: Some(files), + pfb_flavour: Some("none".to_string()), + no_digital_gains: false, + no_cable_length_correction: false, + 
no_geometric_correction: false, + ignore_input_data_fine_channel_flags: true, + ..Default::default() + }, + ..Default::default() + }; + + let tmp_dir = TempDir::new().unwrap(); + let apply_output = tmp_dir.path().join("apply.uvfits"); + let sols_file = get_1090008640_identity_solutions_file(tmp_dir.path()); + args.solutions = Some(sols_file.display().to_string()); + args.outputs = Some(vec![apply_output.clone()]); + args.clone().parse().unwrap().run().unwrap(); + + let convert_output = tmp_dir.path().join("convert.uvfits"); + let args = VisConvertArgs { + data_args: args.data_args, + outputs: Some(vec![convert_output.clone()]), + ..Default::default() + }; + args.parse().unwrap().run().unwrap(); + + // Seeing as the solutions are identities, the visibilities should be + // exactly the same. + let error = AtomicCell::new(false); + let apply_vis_params = InputVisArgs { + files: Some(vec![metafits.clone(), apply_output.display().to_string()]), + ..Default::default() + } + .parse("") + .unwrap(); + let obs_context = apply_vis_params.get_obs_context(); + let total_num_tiles = obs_context.get_total_num_tiles(); + let total_num_baselines = (total_num_tiles * (total_num_tiles - 1)) / 2; + let total_num_channels = obs_context.fine_chan_freqs.len(); + let mut apply_crosses = Array3::zeros(( + apply_vis_params.timeblocks.len(), + total_num_channels, + total_num_baselines, + )); + let mut apply_cross_weights = Array3::zeros(( + apply_vis_params.timeblocks.len(), + total_num_channels, + total_num_baselines, + )); + let mut apply_autos = Array3::zeros(( + apply_vis_params.timeblocks.len(), + total_num_channels, + total_num_tiles, + )); + let mut apply_auto_weights = Array3::zeros(( + apply_vis_params.timeblocks.len(), + total_num_channels, + total_num_tiles, + )); + for (timeblock, crosses, cross_weights, autos, auto_weights) in izip!( + apply_vis_params.timeblocks.iter(), + apply_crosses.outer_iter_mut(), + apply_cross_weights.outer_iter_mut(), + apply_autos.outer_iter_mut(), + apply_auto_weights.outer_iter_mut() + ) { + apply_vis_params + .read_timeblock( + timeblock, + crosses, + cross_weights, + Some((autos, auto_weights)), + &error, + ) + .unwrap(); + } + + let convert_vis_params = InputVisArgs { + files: Some(vec![metafits, convert_output.display().to_string()]), + ..Default::default() + } + .parse("") + .unwrap(); + let obs_context = convert_vis_params.get_obs_context(); + assert_eq!(total_num_tiles, obs_context.get_total_num_tiles()); + assert_eq!( + total_num_baselines, + (total_num_tiles * (total_num_tiles - 1)) / 2 + ); + assert_eq!(total_num_channels, obs_context.fine_chan_freqs.len()); + let mut convert_crosses = Array3::zeros(( + convert_vis_params.timeblocks.len(), + total_num_channels, + total_num_baselines, + )); + let mut convert_cross_weights = Array3::zeros(( + convert_vis_params.timeblocks.len(), + total_num_channels, + total_num_baselines, + )); + let mut convert_autos = Array3::zeros(( + convert_vis_params.timeblocks.len(), + total_num_channels, + total_num_tiles, + )); + let mut convert_auto_weights = Array3::zeros(( + convert_vis_params.timeblocks.len(), + total_num_channels, + total_num_tiles, + )); + for (timeblock, crosses, cross_weights, autos, auto_weights) in izip!( + convert_vis_params.timeblocks.iter(), + convert_crosses.outer_iter_mut(), + convert_cross_weights.outer_iter_mut(), + convert_autos.outer_iter_mut(), + convert_auto_weights.outer_iter_mut() + ) { + convert_vis_params + .read_timeblock( + timeblock, + crosses, + cross_weights, + Some((autos, auto_weights)), + 
&error, + ) + .unwrap(); + } + + assert_abs_diff_eq!(apply_crosses, convert_crosses); + assert_abs_diff_eq!(apply_cross_weights, convert_cross_weights); + assert_abs_diff_eq!(apply_autos, convert_autos); + assert_abs_diff_eq!(apply_auto_weights, convert_auto_weights); +} diff --git a/src/cli/solutions/convert.rs b/src/cli/solutions/convert.rs index cc2359f4..8ebe2d16 100644 --- a/src/cli/solutions/convert.rs +++ b/src/cli/solutions/convert.rs @@ -9,10 +9,10 @@ use std::path::PathBuf; use clap::Parser; use log::info; -use crate::{solutions::CalibrationSolutions, HyperdriveError}; +use crate::{cli::common::display_warnings, solutions::CalibrationSolutions, HyperdriveError}; #[derive(Parser, Debug, Default)] -pub struct SolutionsConvertArgs { +pub(crate) struct SolutionsConvertArgs { /// The path to the input file. If this is a directory instead, then we /// attempt to read RTS calibration files in the directory. #[clap(name = "INPUT_SOLUTIONS_FILE", parse(from_os_str))] @@ -34,6 +34,8 @@ impl SolutionsConvertArgs { CalibrationSolutions::read_solutions_from_ext(&self.input, self.metafits.as_ref())?; sols.write_solutions_from_ext(&self.output)?; + display_warnings(); + info!( "Converted {} to {}", self.input.display(), diff --git a/src/cli/solutions/mod.rs b/src/cli/solutions/mod.rs index d7428693..b3386898 100644 --- a/src/cli/solutions/mod.rs +++ b/src/cli/solutions/mod.rs @@ -2,6 +2,10 @@ // License, v. 2.0. If a copy of the MPL was not distributed with this // file, You can obtain one at http://mozilla.org/MPL/2.0/. -pub(crate) mod apply; -pub(crate) mod convert; -pub(crate) mod plot; +mod apply; +mod convert; +mod plot; + +pub(super) use apply::{SolutionsApplyArgs, SolutionsApplyArgsError}; +pub(super) use convert::SolutionsConvertArgs; +pub(super) use plot::{SolutionsPlotArgs, SolutionsPlotError}; diff --git a/src/cli/solutions/plot/mod.rs b/src/cli/solutions/plot/mod.rs index 12c418fd..31518cea 100644 --- a/src/cli/solutions/plot/mod.rs +++ b/src/cli/solutions/plot/mod.rs @@ -15,7 +15,7 @@ use clap::Parser; use crate::HyperdriveError; #[derive(Parser, Debug, Default)] -pub struct SolutionsPlotArgs { +pub(crate) struct SolutionsPlotArgs { #[clap(name = "SOLUTIONS_FILES", parse(from_os_str))] files: Vec, @@ -65,7 +65,7 @@ pub struct SolutionsPlotArgs { impl SolutionsPlotArgs { #[cfg(not(feature = "plotting"))] - pub fn run(self) -> Result<(), HyperdriveError> { + pub(crate) fn run(self) -> Result<(), HyperdriveError> { // Plotting is an optional feature. This is because it doesn't look // possible to statically compile the C dependencies needed for // plotting. If the "plotting" feature isn't available, warn the user @@ -74,7 +74,7 @@ impl SolutionsPlotArgs { } #[cfg(feature = "plotting")] - pub fn run(self) -> Result<(), HyperdriveError> { + pub(crate) fn run(self) -> Result<(), HyperdriveError> { plotting::plot_all_sol_files(self)?; Ok(()) } @@ -207,6 +207,9 @@ mod plotting { ); let tile_names = sols.tile_names.as_ref().or(mwalib_tile_names.as_ref()); if tile_names.is_none() && !warned_no_tile_names { + // N.B. Not using `crate::cli::Warn` here because multiple + // calibration solutions may be plotted, and we want the user to + // see the warnings for each file. 
warn!("No metafits supplied; the obsid and tile names won't be on the plots"); warned_no_tile_names = true; } diff --git a/src/cli/srclist/by_beam/mod.rs b/src/cli/srclist/by_beam/mod.rs index e23922f0..3d887356 100644 --- a/src/cli/srclist/by_beam/mod.rs +++ b/src/cli/srclist/by_beam/mod.rs @@ -8,28 +8,27 @@ mod tests; use std::{ - fs::File, - io::{BufWriter, Write}, + borrow::Cow, path::{Path, PathBuf}, str::FromStr, }; use clap::Parser; use itertools::Itertools; -use log::{debug, info, trace, warn}; -use marlu::RADec; +use log::{debug, info, trace}; +use marlu::{LatLngHeight, RADec}; use crate::{ - beam::{create_fee_beam_object, Delays}, - constants::{DEFAULT_CUTOFF_DISTANCE, DEFAULT_VETO_THRESHOLD}, - help_texts::{ - SOURCE_DIST_CUTOFF_HELP, SOURCE_LIST_INPUT_TYPE_HELP, SOURCE_LIST_OUTPUT_TYPE_HELP, - VETO_THRESHOLD_HELP, + beam::Delays, + cli::common::{ + display_warnings, BeamArgs, Warn, ARRAY_POSITION_HELP, SOURCE_DIST_CUTOFF_HELP, + SOURCE_LIST_INPUT_TYPE_HELP, SOURCE_LIST_OUTPUT_TYPE_HELP, VETO_THRESHOLD_HELP, }, + constants::{DEFAULT_CUTOFF_DISTANCE, DEFAULT_VETO_THRESHOLD}, metafits::get_dipole_delays, srclist::{ - ao, hyperdrive, read::read_source_list_file, rts, veto_sources, woden, HyperdriveFileType, - SourceList, SourceListType, SrclistError, WriteSourceListError, + read::read_source_list_file, veto_sources, write_source_list, ReadSourceListError, + SourceList, SourceListType, WriteSourceListError, }, HyperdriveError, }; @@ -45,7 +44,7 @@ pub struct SrclistByBeamArgs { #[clap( name = "INPUT_SOURCE_LIST", parse(from_os_str), - help_heading = "INPUT/OUTPUT FILES" + help_heading = "INPUT FILES" )] input_source_list: PathBuf, @@ -54,29 +53,60 @@ pub struct SrclistByBeamArgs { #[clap( name = "OUTPUT_SOURCE_LIST", parse(from_os_str), - help_heading = "INPUT/OUTPUT FILES" + help_heading = "OUTPUT FILES" )] output_source_list: Option, - #[clap(short = 'i', long, parse(from_str), help = SOURCE_LIST_INPUT_TYPE_HELP.as_str(), help_heading = "INPUT/OUTPUT FILES")] + #[clap(short = 'i', long, parse(from_str), help = SOURCE_LIST_INPUT_TYPE_HELP.as_str(), help_heading = "INPUT FILES")] input_type: Option, - #[clap(short = 'o', long, parse(from_str), help = SOURCE_LIST_OUTPUT_TYPE_HELP.as_str(), help_heading = "INPUT/OUTPUT FILES")] + #[clap(short = 'o', long, parse(from_str), help = SOURCE_LIST_OUTPUT_TYPE_HELP.as_str(), help_heading = "OUTPUT FILES")] output_type: Option, - /// Path to the metafits file. + /// Path to the metafits file, which contains the metadata needed to veto + /// sources. + #[clap(short = 'm', long, parse(from_str), help_heading = "METADATA")] + metafits: Option, + + #[clap( + long, help = ARRAY_POSITION_HELP.as_str(), help_heading = "METADATA", + number_of_values = 3, + allow_hyphen_values = true, + value_names = &["LONG_DEG", "LAT_DEG", "HEIGHT_M"] + )] + array_position: Option>, + + /// The LST in radians. Overrides the value in the metafits. + #[clap( + long = "lst", + help_heading = "METADATA", + allow_hyphen_values = true, + required_unless_present = "metafits" + )] + lst_rad: Option, + + /// The RA and Dec. phase centre of the observation in degrees. Overrides + /// the value in metafits. #[clap( - short = 'm', long, - parse(from_str), - help_heading = "INPUT/OUTPUT FILES" + help_heading = "METADATA", + number_of_values = 2, + allow_hyphen_values = true, + value_names = &["RA", "DEC"], + required_unless_present = "metafits" )] - metafits: PathBuf, + phase_centre: Option>, - /// Path to the MWA FEE beam file. 
If this is not specified, then the - /// MWA_BEAM_FILE environment variable should contain the path. - #[clap(short, long, help_heading = "INPUT/OUTPUT FILES")] - beam_file: Option, + /// A representative sample of frequencies in the observation [Hz]; it's + /// typical to use the centre frequencies of each MWA coarse channel. + /// Overrides the coarse channels in the metafits. + #[clap( + long = "freqs", + help_heading = "METADATA", + multiple_values(true), + required_unless_present = "metafits" + )] + freqs_hz: Option>, /// Reduce the input source list to the brightest N sources and write them /// to the output source list. If the input source list has less than N @@ -113,270 +143,293 @@ pub struct SrclistByBeamArgs { /// the base source; this is very important for RTS DI calibration. #[clap(long, help_heading = "RTS-ONLY ARGUMENTS")] rts_base_source: Option, + + #[clap(flatten)] + beam_args: BeamArgs, } impl SrclistByBeamArgs { - /// Run [by_beam] with these arguments. - pub fn run(&self) -> Result<(), HyperdriveError> { + /// Run [`by_beam`] with these arguments. + pub fn run(self) -> Result<(), HyperdriveError> { by_beam( &self.input_source_list, - self.output_source_list.as_ref(), - self.input_type.as_ref(), - self.output_type.as_ref(), + self.output_source_list.as_deref(), + self.input_type.as_deref(), + self.output_type.as_deref(), self.number, - &self.metafits, + self.metafits.as_deref(), + self.array_position.as_ref().map(|a| LatLngHeight { + longitude_rad: a[0].to_radians(), + latitude_rad: a[1].to_radians(), + height_metres: a[2], + }), + self.lst_rad, + self.phase_centre + .as_ref() + .map(|p| RADec::from_degrees(p[0], p[1])), + self.freqs_hz.as_deref(), self.source_dist_cutoff, self.veto_threshold, - self.beam_file.as_ref(), self.filter_points, self.filter_gaussians, self.filter_shapelets, self.collapse_into_single_source, - self.rts_base_source.as_ref(), + self.rts_base_source.as_deref(), + self.beam_args, )?; Ok(()) } } +struct Metadata<'a> { + phase_centre: RADec, + array_position: LatLngHeight, + lst_rad: f64, + freqs_hz: Cow<'a, [f64]>, + dipole_delays: Option, +} + #[allow(clippy::too_many_arguments)] -fn by_beam, S: AsRef>( - input_path: P, - output_path: Option

, - input_type: Option, - output_type: Option, - num_sources: usize, - metafits: P, +fn by_beam( + input_path: &Path, + output_path: Option<&Path>, + input_type: Option<&str>, + output_type: Option<&str>, + mut num_sources: usize, + metafits: Option<&Path>, + array_position: Option, + lst_rad: Option, + phase_centre: Option, + freqs_hz: Option<&[f64]>, source_dist_cutoff: Option, veto_threshold: Option, - beam_file: Option

, filter_points: bool, filter_gaussians: bool, filter_shapelets: bool, collapse_into_single_source: bool, - rts_base_source: Option, -) -> Result<(), SrclistError> { - fn inner( - input_path: &Path, - output_path: Option<&Path>, - input_type: Option<&str>, - output_type: Option<&str>, - mut num_sources: usize, - metafits: &Path, - source_dist_cutoff: Option, - veto_threshold: Option, - beam_file: Option<&Path>, - filter_points: bool, - filter_gaussians: bool, - filter_shapelets: bool, - collapse_into_single_source: bool, - rts_base_source: Option<&str>, - ) -> Result<(), SrclistError> { - // Read the input source list. - let input_type = input_type.and_then(|t| SourceListType::from_str(t).ok()); - let (sl, sl_type) = crate::misc::expensive_op( - || read_source_list_file(input_path, input_type), - "Still reading source list file", - )?; - if input_type.is_none() { - info!( - "Successfully read {} as a {}-style source list", - input_path.display(), - sl_type - ); - } - let counts = sl.get_counts(); + rts_base_source: Option<&str>, + beam_args: BeamArgs, +) -> Result<(), SrclistByBeamError> { + // Read the input source list. + let input_type = input_type.and_then(|t| SourceListType::from_str(t).ok()); + let (sl, sl_type) = crate::misc::expensive_op( + || read_source_list_file(input_path, input_type), + "Still reading source list file", + )?; + if input_type.is_none() { info!( - "{} points, {} gaussians, {} shapelets", - counts.num_points, counts.num_gaussians, counts.num_shapelets + "Successfully read {} as a {}-style source list", + input_path.display(), + sl_type ); + } + let counts = sl.get_counts(); + info!( + "{} points, {} gaussians, {} shapelets", + counts.num_points, counts.num_gaussians, counts.num_shapelets + ); + + // Handle the output path and type. + let output_path = match output_path { + Some(p) => p.to_path_buf(), + None => { + let input_path_base = input_path + .file_stem() + .and_then(|os_str| os_str.to_str()) + .expect("Input file didn't have a filename stem"); + let input_path_ext = input_path + .extension() + .and_then(|os_str| os_str.to_str()) + .expect("Input file didn't have an extension"); + let output_pb = + PathBuf::from(format!("{input_path_base}_{num_sources}.{input_path_ext}")); + debug!("Writing reduced source list to {}", output_pb.display()); + output_pb + } + }; - // Handle the output path and type. - let output_path = match output_path { - Some(p) => p.to_path_buf(), - None => { - let input_path_base = input_path - .file_stem() - .and_then(|os_str| os_str.to_str()) - .expect("Input file didn't have a filename stem"); - let input_path_ext = input_path - .extension() - .and_then(|os_str| os_str.to_str()) - .expect("Input file didn't have an extension"); - let output_pb = - PathBuf::from(format!("{input_path_base}_{num_sources}.{input_path_ext}")); - debug!("Writing reduced source list to {}", output_pb.display()); - output_pb - } - }; - let output_ext = output_path.extension().and_then(|e| e.to_str()); - let hyp_file_type = output_ext.and_then(|e| HyperdriveFileType::from_str(e).ok()); - let output_type = match (output_type, &hyp_file_type) { - (Some(t), _) => { - // Try to parse the specified output type. - match SourceListType::from_str(t) { - Ok(t) => t, - Err(_) => return Err(WriteSourceListError::InvalidFormat.into()), - } - } - - (None, Some(_)) => SourceListType::Hyperdrive, - - // Use the input source list type as the output type. - (None, None) => sl_type, - }; - + let metadata = if let Some(metafits) = metafits { // Open the metafits. 
trace!("Attempting to open the metafits file"); - let meta = mwalib::MetafitsContext::new(metafits, None)?; - let ra_phase_centre = meta - .ra_phase_center_degrees - .unwrap_or(meta.ra_tile_pointing_degrees); - let dec_phase_centre = meta - .dec_phase_center_degrees - .unwrap_or(meta.dec_tile_pointing_degrees); - let phase_centre = RADec::from_degrees(ra_phase_centre, dec_phase_centre); - debug!("Using {} as the phase centre", phase_centre); - let lst = meta.lst_rad; - debug!("Using {}° as the LST", lst.to_degrees()); - let coarse_chan_freqs: Vec = meta - .metafits_coarse_chans - .iter() - .map(|cc| cc.chan_centre_hz as _) - .collect(); - debug!( - "Using coarse channel frequencies [MHz]: {}", - coarse_chan_freqs - .iter() - .map(|cc_freq_hz| format!("{:.2}", *cc_freq_hz / 1e6)) - .join(", ") - ); + let metafits = mwalib::MetafitsContext::new(metafits, None)?; - // Set up the beam. We use the ideal delays for all tiles because we - // don't want to use any dead dipoles. - let mut dipole_delays = Delays::Full(get_dipole_delays(&meta)); + let mut dipole_delays = Delays::Full(get_dipole_delays(&metafits)); dipole_delays.set_to_ideal_delays(); - let beam = create_fee_beam_object(beam_file, 1, dipole_delays, None)?; - - // Apply any filters. - let mut sl = if filter_points || filter_gaussians || filter_shapelets { - let sl = sl.filter(filter_points, filter_gaussians, filter_shapelets); - let counts = sl.get_counts(); - debug!( - "After filtering, there are {} points, {} gaussians, {} shapelets", - counts.num_points, counts.num_gaussians, counts.num_shapelets - ); - sl - } else { - sl - }; - // Veto sources. - veto_sources( - &mut sl, - phase_centre, - lst, - marlu::constants::MWA_LAT_RAD, - &coarse_chan_freqs, - &*beam, - None, - source_dist_cutoff.unwrap_or(DEFAULT_CUTOFF_DISTANCE), - veto_threshold.unwrap_or(DEFAULT_VETO_THRESHOLD), - )?; - // Were any sources left after vetoing? - if sl.is_empty() { - return Err(SrclistError::NoSourcesAfterVeto); + let mut metadata = Metadata { + phase_centre: RADec::from_degrees( + metafits + .ra_phase_center_degrees + .unwrap_or(metafits.ra_tile_pointing_degrees), + metafits + .dec_phase_center_degrees + .unwrap_or(metafits.dec_tile_pointing_degrees), + ), + array_position: LatLngHeight::mwa(), + lst_rad: metafits.lst_rad, + freqs_hz: metafits + .metafits_coarse_chans + .iter() + .map(|cc| cc.chan_centre_hz as _) + .collect(), + dipole_delays: Some(dipole_delays), }; - // If requested, collapse the source list. 
- sl = if collapse_into_single_source { - let base = rts_base_source - .unwrap_or(sl.get_index(0).unwrap().0) - .to_owned(); - let mut collapsed = SourceList::new(); - let base = sl.remove_entry(&base).unwrap(); - let mut num_collapsed_components = base.1.components.len() - 1; - collapsed.insert(base.0, base.1); - let base_src = collapsed.get_index_mut(0).unwrap().1; - let mut base_comps = vec![].into_boxed_slice(); - std::mem::swap(&mut base_src.components, &mut base_comps); - let mut base_comps = base_comps.into_vec(); - sl.into_iter() - .take(num_sources) - .flat_map(|(_, src)| src.components.to_vec()) - .for_each(|comp| { - num_collapsed_components += 1; - base_comps.push(comp); - }); - std::mem::swap(&mut base_src.components, &mut base_comps.into_boxed_slice()); - info!( - "Collapsed {num_sources} into 1 base source with {num_collapsed_components} components" - ); - num_sources = 1; - collapsed - } else { - if rts_base_source.is_some() { - warn!("RTS base source was supplied, but we're not collapsing the source list into a single source."); - } - sl - }; + // Override metafits values with anything that was manually specified. + if let Some(phase_centre) = phase_centre { + metadata.phase_centre = phase_centre; + } + if let Some(array_position) = array_position { + metadata.array_position = array_position; + } + if let Some(lst_rad) = lst_rad { + metadata.lst_rad = lst_rad; + } + if let Some(freqs_hz) = freqs_hz { + metadata.freqs_hz = freqs_hz.into(); + } - // Write the output source list. - // TODO: De-duplicate this code. - trace!("Attempting to write output source list"); - let mut f = BufWriter::new(File::create(&output_path)?); - - match (output_type, hyp_file_type) { - (SourceListType::Hyperdrive, None) => { - return Err(WriteSourceListError::InvalidHyperdriveFormat( - output_ext.unwrap_or("").to_string(), - ) - .into()) - } - (SourceListType::Rts, _) => { - rts::write_source_list(&mut f, &sl, Some(num_sources))?; - info!("Wrote rts-style source list to {}", output_path.display()); - } - (SourceListType::AO, _) => { - ao::write_source_list(&mut f, &sl, Some(num_sources))?; - info!("Wrote ao-style source list to {}", output_path.display()); - } - (SourceListType::Woden, _) => { - woden::write_source_list(&mut f, &sl, Some(num_sources))?; - info!("Wrote woden-style source list to {}", output_path.display()); - } - (_, Some(HyperdriveFileType::Yaml)) => { - hyperdrive::source_list_to_yaml(&mut f, &sl, Some(num_sources))?; - info!( - "Wrote hyperdrive-style source list to {}", - output_path.display() - ); - } - (_, Some(HyperdriveFileType::Json)) => { - hyperdrive::source_list_to_json(&mut f, &sl, Some(num_sources))?; - info!( - "Wrote hyperdrive-style source list to {}", - output_path.display() - ); - } + metadata + } else { + Metadata { + phase_centre: match phase_centre { + Some(p) => p, + None => return Err(SrclistByBeamError::NoPhaseCentre), + }, + array_position: match array_position { + Some(a) => a, + None => LatLngHeight::mwa(), + }, + lst_rad: match lst_rad { + Some(l) => l, + None => return Err(SrclistByBeamError::NoLst), + }, + freqs_hz: match freqs_hz { + Some(f) => f.into(), + None => return Err(SrclistByBeamError::NoFreqs), + }, + // If the user didn't specify delays on the command line, and delays + // are needed to set up the beam object, an error will be generated + // below. 
+ dipole_delays: None, + } + }; + + debug!("Using {} as the phase centre", metadata.phase_centre); + debug!("Using {}° as the LST", metadata.lst_rad.to_degrees()); + debug!( + "Using coarse channel frequencies [MHz]: {}", + metadata + .freqs_hz + .iter() + .map(|freq_hz| format!("{:.2}", *freq_hz / 1e6)) + .join(", ") + ); + + // Set up the beam. We use the ideal delays for all tiles because we + // don't want to use any dead dipoles. + info!(""); + let beam = beam_args.parse(1, metadata.dipole_delays, None, None)?; + + // Apply any filters. + let mut sl = if filter_points || filter_gaussians || filter_shapelets { + let sl = sl.filter(filter_points, filter_gaussians, filter_shapelets); + let counts = sl.get_counts(); + debug!( + "After filtering, there are {} points, {} gaussians, {} shapelets", + counts.num_points, counts.num_gaussians, counts.num_shapelets + ); + sl + } else { + sl + }; + + // Veto sources. + veto_sources( + &mut sl, + metadata.phase_centre, + metadata.lst_rad, + metadata.array_position.latitude_rad, + &metadata.freqs_hz, + &*beam, + None, + source_dist_cutoff.unwrap_or(DEFAULT_CUTOFF_DISTANCE), + veto_threshold.unwrap_or(DEFAULT_VETO_THRESHOLD), + )?; + // Were any sources left after vetoing? + if sl.is_empty() { + return Err(ReadSourceListError::NoSourcesAfterVeto.into()); + }; + + // If requested, collapse the source list. + sl = if collapse_into_single_source { + let base = rts_base_source + .unwrap_or(sl.get_index(0).unwrap().0) + .to_owned(); + let mut collapsed = SourceList::new(); + let base = sl.remove_entry(&base).unwrap(); + let mut num_collapsed_components = base.1.components.len() - 1; + collapsed.insert(base.0, base.1); + let base_src = collapsed.get_index_mut(0).unwrap().1; + let mut base_comps = vec![].into_boxed_slice(); + std::mem::swap(&mut base_src.components, &mut base_comps); + let mut base_comps = base_comps.to_vec(); + sl.into_iter() + .take(num_sources) + .flat_map(|(_, src)| src.components.to_vec()) + .for_each(|comp| { + num_collapsed_components += 1; + base_comps.push(comp); + }); + std::mem::swap(&mut base_src.components, &mut base_comps.into_boxed_slice()); + info!( + "Collapsed {num_sources} into 1 base source with {num_collapsed_components} components" + ); + num_sources = 1; + collapsed + } else { + if rts_base_source.is_some() { + "RTS base source was supplied, but we're not collapsing the source list into a single source.".warn(); } - f.flush()?; + sl + }; - Ok(()) - } - inner( - input_path.as_ref(), - output_path.as_ref().map(|f| f.as_ref()), - input_type.as_ref().map(|f| f.as_ref()), - output_type.as_ref().map(|f| f.as_ref()), - num_sources, - metafits.as_ref(), - source_dist_cutoff, - veto_threshold, - beam_file.as_ref().map(|f| f.as_ref()), - filter_points, - filter_gaussians, - filter_shapelets, - collapse_into_single_source, - rts_base_source.as_ref().map(|f| f.as_ref()), - ) + write_source_list( + &sl, + &output_path, + sl_type, + output_type.and_then(|s| SourceListType::from_str(s).ok()), + Some(num_sources), + )?; + + display_warnings(); + + Ok(()) +} + +#[derive(thiserror::Error, Debug)] +pub(crate) enum SrclistByBeamError { + #[error("No metafits was supplied and no phase centre was specified; cannot continue")] + NoPhaseCentre, + + #[error("No metafits was supplied and no LST was specified; cannot continue")] + NoLst, + + #[error("No metafits was supplied and no frequencies were specified; cannot continue")] + NoFreqs, + + #[error(transparent)] + ReadSourceList(#[from] ReadSourceListError), + + #[error(transparent)] + 
WriteSourceList(#[from] WriteSourceListError), + + #[error(transparent)] + Beam(#[from] crate::beam::BeamError), + + #[error(transparent)] + Mwalib(#[from] mwalib::MwalibError), + + #[error(transparent)] + IO(#[from] std::io::Error), } diff --git a/src/cli/srclist/by_beam/tests.rs b/src/cli/srclist/by_beam/tests.rs index 91d148ee..7424fed6 100644 --- a/src/cli/srclist/by_beam/tests.rs +++ b/src/cli/srclist/by_beam/tests.rs @@ -7,7 +7,10 @@ use std::{fs::File, io::BufReader, path::PathBuf}; use approx::assert_abs_diff_eq; use super::SrclistByBeamArgs; -use crate::srclist::{hyperdrive::source_list_from_json, read::read_source_list_file}; +use crate::{ + cli::common::BeamArgs, + srclist::{hyperdrive::source_list_from_json, read::read_source_list_file}, +}; #[test] fn test_srclist_by_beam() { @@ -23,8 +26,11 @@ fn test_srclist_by_beam() { output_source_list: Some(temp.path().to_path_buf()), input_type: None, output_type: None, - metafits: PathBuf::from("test_files/1090008640/1090008640.metafits"), - beam_file: None, + metafits: Some(PathBuf::from("test_files/1090008640/1090008640.metafits")), + array_position: None, + lst_rad: None, + phase_centre: None, + freqs_hz: None, number: n, source_dist_cutoff: None, veto_threshold: None, @@ -33,6 +39,12 @@ fn test_srclist_by_beam() { filter_shapelets: false, collapse_into_single_source: false, rts_base_source: None, + beam_args: BeamArgs { + beam_file: None, + unity_dipole_gains: false, + delays: None, + no_beam: false, + }, } .run() .unwrap(); diff --git a/src/cli/srclist/convert.rs b/src/cli/srclist/convert.rs index 64d22db1..35418798 100644 --- a/src/cli/srclist/convert.rs +++ b/src/cli/srclist/convert.rs @@ -5,23 +5,23 @@ //! Code to convert between sky-model source list files. use std::{ - fs::File, - io::{BufWriter, Write}, path::{Path, PathBuf}, str::FromStr, }; use clap::Parser; use itertools::Itertools; -use log::{debug, info, trace, warn}; +use log::{debug, info, trace}; use marlu::RADec; use rayon::prelude::*; use crate::{ - help_texts::{SOURCE_LIST_INPUT_TYPE_HELP, SOURCE_LIST_OUTPUT_TYPE_HELP}, + cli::common::{ + display_warnings, Warn, SOURCE_LIST_INPUT_TYPE_HELP, SOURCE_LIST_OUTPUT_TYPE_HELP, + }, srclist::{ - ao, hyperdrive, read::read_source_list_file, rts, woden, HyperdriveFileType, SourceList, - SourceListType, SrclistError, WriteSourceListError, + read::read_source_list_file, write_source_list, HyperdriveFileType, SourceList, + SourceListType, SrclistError, }, HyperdriveError, }; @@ -123,7 +123,7 @@ fn convert, S: AsRef>( let output_ext = output_path.extension().and_then(|e| e.to_str()); let output_file_type = output_ext.and_then(|e| HyperdriveFileType::from_str(e).ok()); if output_type.is_none() && output_file_type.is_some() { - warn!("Assuming that the output file type is 'hyperdrive'"); + "Assuming that the output file type is 'hyperdrive'".warn(); } // Read the input source list. @@ -220,45 +220,9 @@ fn convert, S: AsRef>( // Write the output source list. 
trace!("Attempting to write output source list"); - let mut f = BufWriter::new(File::create(output_path)?); - - match (output_type, output_file_type) { - (_, Some(HyperdriveFileType::Yaml)) => { - hyperdrive::source_list_to_yaml(&mut f, &sl, None)?; - info!( - "Wrote hyperdrive-style source list to {}", - output_path.display() - ); - } - (_, Some(HyperdriveFileType::Json)) => { - hyperdrive::source_list_to_json(&mut f, &sl, None)?; - info!( - "Wrote hyperdrive-style source list to {}", - output_path.display() - ); - } - (Some(SourceListType::Hyperdrive), None) => { - return Err(WriteSourceListError::InvalidHyperdriveFormat( - output_ext.unwrap_or("").to_string(), - ) - .into()) - } - (Some(SourceListType::Rts), _) => { - rts::write_source_list(&mut f, &sl, None)?; - info!("Wrote rts-style source list to {}", output_path.display()); - } - (Some(SourceListType::AO), _) => { - ao::write_source_list(&mut f, &sl, None)?; - info!("Wrote ao-style source list to {}", output_path.display()); - } - (Some(SourceListType::Woden), _) => { - woden::write_source_list(&mut f, &sl, None)?; - info!("Wrote woden-style source list to {}", output_path.display()); - } - (None, None) => return Err(WriteSourceListError::NotEnoughInfo.into()), - } + write_source_list(&sl, output_path, sl_type, output_type, None)?; - f.flush()?; + display_warnings(); Ok(()) } diff --git a/src/cli/srclist/mod.rs b/src/cli/srclist/mod.rs index b0a3d1fc..cf3b8495 100644 --- a/src/cli/srclist/mod.rs +++ b/src/cli/srclist/mod.rs @@ -4,7 +4,12 @@ //! Utilities surrounding source lists. -pub(crate) mod by_beam; -pub(crate) mod convert; -pub(crate) mod shift; -pub(crate) mod verify; +mod by_beam; +mod convert; +mod shift; +mod verify; + +pub(super) use by_beam::{SrclistByBeamArgs, SrclistByBeamError}; +pub(super) use convert::SrclistConvertArgs; +pub(super) use shift::SrclistShiftArgs; +pub(super) use verify::SrclistVerifyArgs; diff --git a/src/cli/srclist/shift.rs b/src/cli/srclist/shift.rs index 52216e93..0ba2c1ce 100644 --- a/src/cli/srclist/shift.rs +++ b/src/cli/srclist/shift.rs @@ -13,16 +13,18 @@ use std::{ use clap::Parser; use indexmap::IndexMap; use itertools::Itertools; -use log::{debug, info, trace, warn}; +use log::{debug, info, trace}; use marlu::RADec; use rayon::prelude::*; use serde::Deserialize; use crate::{ - help_texts::{SOURCE_LIST_INPUT_TYPE_HELP, SOURCE_LIST_OUTPUT_TYPE_HELP}, + cli::common::{ + display_warnings, Warn, SOURCE_LIST_INPUT_TYPE_HELP, SOURCE_LIST_OUTPUT_TYPE_HELP, + }, srclist::{ - ao, hyperdrive, read::read_source_list_file, rts, woden, HyperdriveFileType, Source, - SourceList, SourceListType, SrclistError, WriteSourceListError, + read::read_source_list_file, rts, write_source_list, Source, SourceList, SourceListType, + SrclistError, WriteSourceListError, }, HyperdriveError, }; @@ -74,288 +76,224 @@ impl SrclistShiftArgs { shift( &self.source_list, &self.source_shifts, - self.output_source_list.as_ref(), - self.input_type.as_ref(), - self.output_type.as_ref(), + self.output_source_list.as_deref(), + self.input_type.as_deref(), + self.output_type.as_deref(), self.collapse_into_single_source, self.include_unshifted_sources, - self.metafits.as_ref(), + self.metafits.as_deref(), )?; Ok(()) } } #[allow(clippy::too_many_arguments)] -fn shift, S: AsRef>( - source_list_file: P, - source_shifts_file: P, - output_source_list_file: Option
<P>,
-    source_list_input_type: Option<S>,
-    source_list_output_type: Option<S>,
+fn shift(
+    source_list_file: &Path,
+    source_shifts_file: &Path,
+    output_source_list_file: Option<&Path>,
+    source_list_input_type: Option<&str>,
+    source_list_output_type: Option<&str>,
     collapse_into_single_source: bool,
     include_unshifted_sources: bool,
-    metafits_file: Option<P>
, + metafits_file: Option<&Path>, ) -> Result<(), SrclistError> { - fn inner( - source_list_file: &Path, - source_shifts_file: &Path, - output_source_list_file: Option<&Path>, - source_list_input_type: Option<&str>, - source_list_output_type: Option<&str>, - collapse_into_single_source: bool, - include_unshifted_sources: bool, - metafits_file: Option<&Path>, - ) -> Result<(), SrclistError> { - let output_path: PathBuf = match output_source_list_file { - Some(p) => p.to_path_buf(), - None => { - let input_path_base = source_list_file - .file_stem() - .and_then(|os_str| os_str.to_str()) - .expect("Input file didn't have a filename stem"); - let input_path_ext = source_list_file - .extension() - .and_then(|os_str| os_str.to_str()) - .expect("Input file didn't have an extension"); - let output_pb = - PathBuf::from(format!("{input_path_base}_shifted.{input_path_ext}")); - trace!("Writing shifted source list to {}", output_pb.display()); - output_pb - } - }; - let input_type = source_list_input_type.and_then(|t| SourceListType::from_str(t).ok()); - - let f = BufReader::new(File::open(source_shifts_file)?); - let source_shifts: BTreeMap = - serde_json::from_reader(f).map_err(WriteSourceListError::from)?; - let (sl, sl_type) = crate::misc::expensive_op( - || read_source_list_file(source_list_file, input_type), - "Still reading source list file", - )?; - info!( - "Successfully read {} as a {}-style source list", - source_list_file.display(), - sl_type - ); - let counts = sl.get_counts(); - info!( - "Input has {} points, {} gaussians, {} shapelets", - counts.num_points, counts.num_gaussians, counts.num_shapelets - ); + let output_path: PathBuf = match output_source_list_file { + Some(p) => p.to_path_buf(), + None => { + let input_path_base = source_list_file + .file_stem() + .and_then(|os_str| os_str.to_str()) + .expect("Input file didn't have a filename stem"); + let input_path_ext = source_list_file + .extension() + .and_then(|os_str| os_str.to_str()) + .expect("Input file didn't have an extension"); + let output_pb = PathBuf::from(format!("{input_path_base}_shifted.{input_path_ext}")); + trace!("Writing shifted source list to {}", output_pb.display()); + output_pb + } + }; + let input_type = source_list_input_type.and_then(|t| SourceListType::from_str(t).ok()); - let metafits: Option = - match (collapse_into_single_source, metafits_file) { - (false, _) => None, - (true, None) => return Err(SrclistError::MissingMetafits), - (true, Some(m)) => { - trace!("Attempting to open the metafits file"); - let m = mwalib::MetafitsContext::new(m, None)?; - Some(m) - } - }; + let f = BufReader::new(File::open(source_shifts_file)?); + let source_shifts: BTreeMap = + serde_json::from_reader(f).map_err(WriteSourceListError::from)?; + let (sl, sl_type) = crate::misc::expensive_op( + || read_source_list_file(source_list_file, input_type), + "Still reading source list file", + )?; + info!( + "Successfully read {} as a {}-style source list", + source_list_file.display(), + sl_type + ); + let counts = sl.get_counts(); + info!( + "Input has {} points, {} gaussians, {} shapelets", + counts.num_points, counts.num_gaussians, counts.num_shapelets + ); - // If this an RTS source list, then the order of the sources in the source - // list is important and must be preserved. When hyperdrive reads these - // source lists, the ordering is thrown away because it was not designed - // with this in mind (and, it should never consider it, as it's a detail - // that we never want to care about). 
Here, we know that we've read an RTS - // source list so we can (in a dirty fashion) get the order of the sources. - let source_name_order: Option> = match sl_type { - SourceListType::Rts => { - warn!("Preserving the order of the RTS sources"); - let mut names = vec![]; - let f = BufReader::new(File::open(source_list_file)?); - for line in f.lines() { - let line = line?; - if line.starts_with("SOURCE") { - // unwrap is safe because we successfully read the RTS - // source list earlier. - let source_name = line.split_whitespace().nth(1).unwrap(); - if include_unshifted_sources || source_shifts.contains_key(source_name) { - names.push(source_name.to_string()); - } - } - } - Some(names) + let metafits: Option = + match (collapse_into_single_source, metafits_file) { + (false, _) => None, + (true, None) => return Err(SrclistError::MissingMetafits), + (true, Some(m)) => { + trace!("Attempting to open the metafits file"); + let m = mwalib::MetafitsContext::new(m, None)?; + Some(m) } - _ => None, }; - // Filter any sources that aren't in the shifts file, and shift the - // sources. All components of a source get shifted the same amount. - let mut sl = { - let no_shift = RaDec { ra: 0.0, dec: 0.0 }; - let tmp_sl: IndexMap = sl - .into_iter() - .filter(|(name, _)| include_unshifted_sources || source_shifts.contains_key(name)) - .map(|(name, mut src)| { - let shift = if source_shifts.contains_key(&name) { - &source_shifts[&name] - } else { - &no_shift - }; - src.components.iter_mut().for_each(|comp| { - comp.radec.ra += shift.ra.to_radians(); - comp.radec.dec += shift.dec.to_radians(); - }); - (name, src) - }) - .collect(); - SourceList::from(tmp_sl) - }; - - // If requested, collapse the source list. - sl = if let Some(meta) = metafits { - let ra_phase_centre = meta - .ra_phase_center_degrees - .unwrap_or(meta.ra_tile_pointing_degrees); - let dec_phase_centre = meta - .dec_phase_center_degrees - .unwrap_or(meta.dec_tile_pointing_degrees); - let phase_centre = RADec::from_degrees(ra_phase_centre, dec_phase_centre); - debug!("Using {} as the phase centre", phase_centre); - let lst = meta.lst_rad; - debug!("Using {}° as the LST", lst.to_degrees()); - let coarse_chan_freqs: Vec = meta - .metafits_coarse_chans - .iter() - .map(|cc| cc.chan_centre_hz as _) - .collect(); - debug!( - "Using coarse channel frequencies [MHz]: {}", - coarse_chan_freqs - .iter() - .map(|cc_freq_hz| format!("{:.2}", *cc_freq_hz / 1e6)) - .join(", ") - ); - - let mut collapsed = SourceList::new(); - // If we're preserving the order of the RTS sources, then use the first - // source as the base. - if let Some(ordered) = &source_name_order { - let base_name = ordered.first().unwrap().clone(); - let base = sl.remove_entry(&base_name).unwrap(); - collapsed.insert(base_name, base.1); - let base_src = collapsed.get_mut(&base.0).unwrap(); - let mut base_comps = vec![].into_boxed_slice(); - std::mem::swap(&mut base_src.components, &mut base_comps); - let mut base_comps = base_comps.into_vec(); - - for name in &ordered[1..] { - for comp in sl[name].components.iter() { - base_comps.push(comp.clone()); + // If this an RTS source list, then the order of the sources in the source + // list is important and must be preserved. When hyperdrive reads these + // source lists, the ordering is thrown away because it was not designed + // with this in mind (and, it should never consider it, as it's a detail + // that we never want to care about). 
Here, we know that we've read an RTS + // source list so we can (in a dirty fashion) get the order of the sources. + let source_name_order: Option> = match sl_type { + SourceListType::Rts => { + "Preserving the order of the RTS sources".warn(); + let mut names = vec![]; + let f = BufReader::new(File::open(source_list_file)?); + for line in f.lines() { + let line = line?; + if line.starts_with("SOURCE") { + // unwrap is safe because we successfully read the RTS + // source list earlier. + let source_name = line.split_whitespace().nth(1).unwrap(); + if include_unshifted_sources || source_shifts.contains_key(source_name) { + names.push(source_name.to_string()); } } - std::mem::swap(&mut base_src.components, &mut base_comps.into_boxed_slice()); - } else { - // Use the apparently brightest source as the base. - let brightest = sl - .par_iter() - .map(|(name, src)| { - let stokes_i = src - .get_flux_estimates(150e6) - .iter() - .fold(0.0, |acc, fd| acc + fd.i); - (name, stokes_i) - }) - .max_by(|x, y| x.1.partial_cmp(&y.1).unwrap()) - .unwrap(); - let base_name = brightest.0.clone(); - - let base = sl.remove_entry(&base_name).unwrap(); - collapsed.insert(base_name, base.1); - let base_src = collapsed.get_mut(&base.0).unwrap(); - let mut base_comps = vec![].into_boxed_slice(); - std::mem::swap(&mut base_src.components, &mut base_comps); - let mut base_comps = base_comps.into_vec(); - sl.into_iter() - .flat_map(|(_, src)| src.components.to_vec()) - .for_each(|comp| base_comps.push(comp)); - std::mem::swap(&mut base_src.components, &mut base_comps.into_boxed_slice()); } + Some(names) + } + _ => None, + }; - collapsed - } else { - sl - }; - let counts = sl.get_counts(); - info!( - "Shifted {} points, {} gaussians, {} shapelets", - counts.num_points, counts.num_gaussians, counts.num_shapelets + // Filter any sources that aren't in the shifts file, and shift the + // sources. All components of a source get shifted the same amount. + let mut sl = { + let no_shift = RaDec { ra: 0.0, dec: 0.0 }; + let tmp_sl: IndexMap = sl + .into_iter() + .filter(|(name, _)| include_unshifted_sources || source_shifts.contains_key(name)) + .map(|(name, mut src)| { + let shift = if source_shifts.contains_key(&name) { + &source_shifts[&name] + } else { + &no_shift + }; + src.components.iter_mut().for_each(|comp| { + comp.radec.ra += shift.ra.to_radians(); + comp.radec.dec += shift.dec.to_radians(); + }); + (name, src) + }) + .collect(); + SourceList::from(tmp_sl) + }; + + // If requested, collapse the source list. + sl = if let Some(meta) = metafits { + let ra_phase_centre = meta + .ra_phase_center_degrees + .unwrap_or(meta.ra_tile_pointing_degrees); + let dec_phase_centre = meta + .dec_phase_center_degrees + .unwrap_or(meta.dec_tile_pointing_degrees); + let phase_centre = RADec::from_degrees(ra_phase_centre, dec_phase_centre); + debug!("Using {} as the phase centre", phase_centre); + let lst = meta.lst_rad; + debug!("Using {}° as the LST", lst.to_degrees()); + let coarse_chan_freqs: Vec = meta + .metafits_coarse_chans + .iter() + .map(|cc| cc.chan_centre_hz as _) + .collect(); + debug!( + "Using coarse channel frequencies [MHz]: {}", + coarse_chan_freqs + .iter() + .map(|cc_freq_hz| format!("{:.2}", *cc_freq_hz / 1e6)) + .join(", ") ); - // Write the output source list. 
- trace!("Attempting to write output source list"); - let mut f = std::io::BufWriter::new(File::create(&output_path)?); + let mut collapsed = SourceList::new(); + // If we're preserving the order of the RTS sources, then use the first + // source as the base. + if let Some(ordered) = &source_name_order { + let base_name = ordered.first().unwrap().clone(); + let base = sl.remove_entry(&base_name).unwrap(); + collapsed.insert(base_name, base.1); + let base_src = collapsed.get_mut(&base.0).unwrap(); + let mut base_comps = vec![].into_boxed_slice(); + std::mem::swap(&mut base_src.components, &mut base_comps); + let mut base_comps = base_comps.to_vec(); - let output_ext = output_path.extension().and_then(|e| e.to_str()); - let hyp_file_type = output_ext.and_then(|e| HyperdriveFileType::from_str(e).ok()); - let output_type = match (source_list_output_type, &hyp_file_type) { - (Some(t), _) => { - // Try to parse the specified output type. - match SourceListType::from_str(t) { - Ok(t) => t, - Err(_) => return Err(WriteSourceListError::InvalidFormat.into()), + for name in &ordered[1..] { + for comp in sl[name].components.iter() { + base_comps.push(comp.clone()); } } + std::mem::swap(&mut base_src.components, &mut base_comps.into_boxed_slice()); + } else { + // Use the apparently brightest source as the base. + let brightest = sl + .par_iter() + .map(|(name, src)| { + let stokes_i = src + .get_flux_estimates(150e6) + .iter() + .fold(0.0, |acc, fd| acc + fd.i); + (name, stokes_i) + }) + .max_by(|x, y| x.1.partial_cmp(&y.1).unwrap()) + .unwrap(); + let base_name = brightest.0.clone(); - (None, Some(_)) => SourceListType::Hyperdrive, + let base = sl.remove_entry(&base_name).unwrap(); + collapsed.insert(base_name, base.1); + let base_src = collapsed.get_mut(&base.0).unwrap(); + let mut base_comps = vec![].into_boxed_slice(); + std::mem::swap(&mut base_src.components, &mut base_comps); + let mut base_comps = base_comps.to_vec(); + sl.into_iter() + .flat_map(|(_, src)| src.components.to_vec()) + .for_each(|comp| base_comps.push(comp)); + std::mem::swap(&mut base_src.components, &mut base_comps.into_boxed_slice()); + } - // Use the input source list type as the output type. 
- (None, None) => sl_type, - }; + collapsed + } else { + sl + }; + let counts = sl.get_counts(); + info!( + "Shifted {} points, {} gaussians, {} shapelets", + counts.num_points, counts.num_gaussians, counts.num_shapelets + ); - match (output_type, hyp_file_type) { - (SourceListType::Hyperdrive, None) => { - return Err(WriteSourceListError::InvalidHyperdriveFormat( - output_ext.unwrap_or("").to_string(), - ) - .into()) - } - (SourceListType::Rts, _) => { - if let Some(source_name_order) = source_name_order { - rts::write_source_list_with_order(&mut f, &sl, source_name_order)?; - } else { - rts::write_source_list(&mut f, &sl, None)?; - } - info!("Wrote rts-style source list to {}", output_path.display()); - } - (SourceListType::AO, _) => { - ao::write_source_list(&mut f, &sl, None)?; - info!("Wrote ao-style source list to {}", output_path.display()); - } - (SourceListType::Woden, _) => { - woden::write_source_list(&mut f, &sl, None)?; - info!("Wrote woden-style source list to {}", output_path.display()); - } - (_, Some(HyperdriveFileType::Yaml)) => { - hyperdrive::source_list_to_yaml(&mut f, &sl, None)?; - info!( - "Wrote hyperdrive-style source list to {}", - output_path.display() - ); - } - (_, Some(HyperdriveFileType::Json)) => { - hyperdrive::source_list_to_json(&mut f, &sl, None)?; - info!( - "Wrote hyperdrive-style source list to {}", - output_path.display() - ); - } + match (sl_type, source_name_order) { + (SourceListType::Rts, Some(source_name_order)) => { + trace!("Attempting to write ordered RTS source list"); + let mut f = std::io::BufWriter::new(std::fs::File::create(&output_path)?); + rts::write_source_list_with_order(&mut f, &sl, source_name_order)?; } - Ok(()) + _ => write_source_list( + &sl, + &output_path, + sl_type, + source_list_output_type.and_then(|s| SourceListType::from_str(s).ok()), + None, + )?, } - inner( - source_list_file.as_ref(), - source_shifts_file.as_ref(), - output_source_list_file.as_ref().map(|f| f.as_ref()), - source_list_input_type.as_ref().map(|f| f.as_ref()), - source_list_output_type.as_ref().map(|f| f.as_ref()), - collapse_into_single_source, - include_unshifted_sources, - metafits_file.as_ref().map(|f| f.as_ref()), - ) + + display_warnings(); + + Ok(()) } #[derive(Deserialize)] diff --git a/src/cli/srclist/verify.rs b/src/cli/srclist/verify.rs index 69a30e01..83dc7d32 100644 --- a/src/cli/srclist/verify.rs +++ b/src/cli/srclist/verify.rs @@ -4,15 +4,17 @@ //! Code to verify sky-model source list files. -use std::fs::File; -use std::path::{Path, PathBuf}; -use std::str::FromStr; +use std::{ + fs::File, + path::{Path, PathBuf}, + str::FromStr, +}; use clap::Parser; use log::info; use crate::{ - help_texts::SOURCE_LIST_INPUT_TYPE_HELP, + cli::common::{display_warnings, SOURCE_LIST_INPUT_TYPE_HELP}, srclist::{ ao, hyperdrive, read::read_source_list_file, rts, woden, ComponentCounts, SourceListType, SrclistError, @@ -126,5 +128,7 @@ fn verify>( info!(""); } + display_warnings(); + Ok(()) } diff --git a/src/cli/vis_convert/mod.rs b/src/cli/vis_convert/mod.rs new file mode 100644 index 00000000..de2734c9 --- /dev/null +++ b/src/cli/vis_convert/mod.rs @@ -0,0 +1,180 @@ +// This Source Code Form is subject to the terms of the Mozilla Public +// License, v. 2.0. If a copy of the MPL was not distributed with this +// file, You can obtain one at http://mozilla.org/MPL/2.0/. 
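+
+//! Convert visibilities from one supported format to another, optionally
+//! averaging in time and frequency and/or writing out only the smallest
+//! contiguous band of unflagged channels.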
+ +#[cfg(test)] +mod tests; + +use std::path::PathBuf; + +use clap::Parser; +use log::{debug, info, trace}; +use serde::{Deserialize, Serialize}; + +use super::common::{InputVisArgs, OutputVisArgs, ARG_FILE_HELP}; +use crate::{ + cli::common::display_warnings, io::write::VIS_OUTPUT_EXTENSIONS, params::VisConvertParams, + HyperdriveError, +}; + +lazy_static::lazy_static! { + static ref OUTPUTS_HELP: String = + format!("Paths to the output visibility files. Supported formats: {}", *VIS_OUTPUT_EXTENSIONS); +} + +#[derive(Parser, Debug, Clone, Default, Serialize, Deserialize)] +pub(super) struct VisConvertArgs { + #[clap(name = "ARGUMENTS_FILE", help = ARG_FILE_HELP.as_str(), parse(from_os_str))] + pub(super) args_file: Option, + + #[clap(flatten)] + #[serde(rename = "data")] + #[serde(default)] + pub(super) data_args: InputVisArgs, + + #[clap( + short = 'o', + long, + multiple_values(true), + help = OUTPUTS_HELP.as_str(), + help_heading = "OUTPUT FILES" + )] + pub(super) outputs: Option>, + + /// When writing out visibilities, average this many timesteps together. + /// Also supports a target time resolution (e.g. 8s). The value must be a + /// multiple of the input data's time resolution. The default is no + /// averaging, i.e. a value of 1. Examples: If the input data is in 0.5s + /// resolution and this variable is 4, then we average 2s worth of data + /// together before writing the data out. If the variable is instead 4s, + /// then 8 timesteps are averaged together before writing the data out. + #[clap(long, help_heading = "OUTPUT FILES")] + pub(super) output_vis_time_average: Option, + + /// When writing out visibilities, average this many fine freq. channels + /// together. Also supports a target freq. resolution (e.g. 80kHz). The + /// value must be a multiple of the input data's freq. resolution. The + /// default is no averaging, i.e. a value of 1. Examples: If the input data + /// is in 40kHz resolution and this variable is 4, then we average 160kHz + /// worth of data together before writing the data out. If the variable is + /// instead 80kHz, then 2 fine freq. channels are averaged together before + /// writing the data out. + #[clap(long, help_heading = "OUTPUT FILES")] + pub(super) output_vis_freq_average: Option, + + /// Rather than writing out the entire input bandwidth, write out only the + /// smallest contiguous band. e.g. Typical 40 kHz MWA data has 768 channels, + /// but the first 2 and last 2 channels are usually flagged. Turning this + /// option on means that 764 channels would be written out instead of 768. + /// Note that other flagged channels in the band are unaffected, because the + /// data written out must be contiguous. + #[clap(long, help_heading = "OUTPUT FILES")] + #[serde(default)] + pub(super) output_smallest_contiguous_band: bool, +} + +impl VisConvertArgs { + /// Both command-line and file arguments overlap in terms of what is + /// available; this function consolidates everything that was specified into + /// a single struct. Where applicable, it will prefer CLI parameters over + /// those in the file. + /// + /// The argument to this function is the path to the arguments file. + /// + /// This function should only ever merge arguments, and not try to make + /// sense of them. + pub(super) fn merge(self) -> Result { + debug!("Merging command-line arguments with the argument file"); + + let cli_args = self; + + if let Some(arg_file) = cli_args.args_file { + // Read in the file arguments. 
Ensure all of the file args are + // accounted for by pattern matching. + let VisConvertArgs { + args_file: _, + data_args, + outputs, + output_vis_time_average, + output_vis_freq_average, + output_smallest_contiguous_band, + } = unpack_arg_file!(arg_file); + + // Merge all the arguments, preferring the CLI args when available. + Ok(VisConvertArgs { + args_file: None, + data_args: cli_args.data_args.merge(data_args), + outputs: cli_args.outputs.or(outputs), + output_vis_time_average: cli_args + .output_vis_time_average + .or(output_vis_time_average), + output_vis_freq_average: cli_args + .output_vis_freq_average + .or(output_vis_freq_average), + output_smallest_contiguous_band: cli_args.output_smallest_contiguous_band + || output_smallest_contiguous_band, + }) + } else { + Ok(cli_args) + } + } + + pub(super) fn parse(self) -> Result { + debug!("{:#?}", self); + + let Self { + args_file: _, + data_args, + outputs, + output_vis_time_average, + output_vis_freq_average, + output_smallest_contiguous_band, + } = self; + + if outputs.is_none() { + return Err(VisConvertArgsError::NoOutputs.into()); + } + + let input_vis_params = data_args.parse("Converting")?; + let output_vis_params = OutputVisArgs { + outputs, + output_vis_time_average, + output_vis_freq_average, + } + .parse( + input_vis_params.time_res, + input_vis_params.spw.freq_res, + &input_vis_params.timeblocks.mapped_ref(|tb| tb.median), + output_smallest_contiguous_band, + "", // Won't be used because the outputs are checked above. + None, + )?; + + display_warnings(); + + Ok(VisConvertParams { + input_vis_params, + output_vis_params, + }) + } + + pub(super) fn run(self, dry_run: bool) -> Result<(), HyperdriveError> { + debug!("Converting arguments into parameters"); + trace!("{:#?}", self); + let params = self.parse()?; + + if dry_run { + info!("Dry run -- exiting now."); + return Ok(()); + } + + params.run()?; + Ok(()) + } +} + +#[derive(thiserror::Error, Debug)] +pub(super) enum VisConvertArgsError { + #[error("No output visibility files were specified")] + NoOutputs, +} diff --git a/src/cli/vis_convert/tests.rs b/src/cli/vis_convert/tests.rs new file mode 100644 index 00000000..52f98d15 --- /dev/null +++ b/src/cli/vis_convert/tests.rs @@ -0,0 +1,152 @@ +// This Source Code Form is subject to the terms of the Mozilla Public +// License, v. 2.0. If a copy of the MPL was not distributed with this +// file, You can obtain one at http://mozilla.org/MPL/2.0/. 
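+
+// These tests run vis-convert on reduced raw MWA data (one coarse channel of
+// 32 fine channels at 40 kHz) and check how fine-channel flags interact with
+// --output-smallest-contiguous-band: with the default flags [0, 1, 16, 30, 31],
+// the smallest contiguous unflagged band is channels 2..=29 (28 channels), and
+// channel 16 stays flagged inside it because the written band must be contiguous.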
+ +use std::{ + num::NonZeroU16, + path::{Path, PathBuf}, +}; + +use clap::Parser; +use tempfile::TempDir; + +use super::VisConvertArgs; +use crate::{ + io::read::VisRead, + tests::{get_reduced_1090008640_raw, DataAsStrings}, + MsReader, UvfitsReader, +}; + +#[test] +fn test_per_coarse_chan_flags_and_smallest_contiguous_band_writing() { + let temp_dir = TempDir::new().expect("couldn't make tmp dir"); + let uvfits_converted = temp_dir.path().join("converted.uvfits"); + let ms_converted = temp_dir.path().join("converted.ms"); + + fn get_data_object( + output: &Path, + per_coarse_chan_flags: Option>, + output_smallest_contiguous_band: bool, + ) -> Box { + let DataAsStrings { + metafits, + vis, + mwafs: _, + srclist: _, + } = get_reduced_1090008640_raw(); + let metafits_pb = PathBuf::from(&metafits); + + let output_string = output.display().to_string(); + #[rustfmt::skip] + let mut args = vec![ + "vis-convert", + "--data", &vis[0], &metafits, + "--outputs", &output_string + ]; + if output_smallest_contiguous_band { + args.push("--output-smallest-contiguous-band"); + } + if let Some(per_coarse_chan_flags) = per_coarse_chan_flags.as_ref() { + args.push("--fine-chan-flags-per-coarse-chan"); + for f in per_coarse_chan_flags { + args.push(f.as_str()); + } + } + let vis_convert_args = VisConvertArgs::parse_from(args); + vis_convert_args.run(false).unwrap(); + + match output.extension().and_then(|os_str| os_str.to_str()) { + Some("uvfits") => { + Box::new(UvfitsReader::new(output.to_path_buf(), Some(&metafits_pb), None).unwrap()) + } + Some("ms") => Box::new( + MsReader::new(output.to_path_buf(), None, Some(&metafits_pb), None).unwrap(), + ), + _ => unreachable!(), + } + } + + for output_smallest_contiguous_band in [false, true] { + for output in [&uvfits_converted, &ms_converted] { + let data = get_data_object(output, None, output_smallest_contiguous_band); + let obs_context = data.get_obs_context(); + if output_smallest_contiguous_band { + assert_eq!(obs_context.fine_chan_freqs.len(), 28); + } else { + assert_eq!(obs_context.fine_chan_freqs.len(), 32); + }; + assert_eq!( + obs_context.num_fine_chans_per_coarse_chan, + Some(NonZeroU16::new(32).unwrap()) + ); + assert_eq!( + obs_context.mwa_coarse_chan_nums.as_deref(), + Some([154].as_slice()) + ); + match output.extension().and_then(|os_str| os_str.to_str()) { + Some("uvfits") => { + // uvfits currently doesn't try to determine this. + assert_eq!(obs_context.flagged_fine_chans_per_coarse_chan, None); + } + Some("ms") => { + assert_eq!( + obs_context.flagged_fine_chans_per_coarse_chan.as_deref(), + if output_smallest_contiguous_band { + // The MS reader can't tell if the edge channels are + // flagged or not; they're not available. + Some([16].as_slice()) + } else { + Some([0, 1, 16, 30, 31].as_slice()) + } + ); + } + _ => unreachable!(), + } + } + + // Now test with additional per-coarse-chan flags. 
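+        // With channels 2 and 5 flagged on top of [0, 1, 16, 30, 31], the smallest
+        // contiguous unflagged band runs from channel 3 to 29 (27 channels), and
+        // channels 5 and 16 remain flagged inside it.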
+ for output in [&uvfits_converted, &ms_converted] { + let data = get_data_object( + output, + Some(vec!["2".to_string(), "5".to_string()]), + output_smallest_contiguous_band, + ); + let obs_context = data.get_obs_context(); + if output_smallest_contiguous_band { + assert_eq!(obs_context.fine_chan_freqs.len(), 27); + } else { + assert_eq!(obs_context.fine_chan_freqs.len(), 32); + }; + assert_eq!( + obs_context.num_fine_chans_per_coarse_chan, + Some(NonZeroU16::new(32).unwrap()) + ); + assert_eq!( + obs_context.mwa_coarse_chan_nums.as_deref(), + Some([154].as_slice()) + ); + match output.extension().and_then(|os_str| os_str.to_str()) { + Some("uvfits") => { + // uvfits currently doesn't try to determine this. + assert_eq!( + obs_context.flagged_fine_chans_per_coarse_chan.as_deref(), + None + ); + } + Some("ms") => { + assert_eq!( + obs_context.flagged_fine_chans_per_coarse_chan.as_deref(), + if output_smallest_contiguous_band { + // The MS reader can't tell if the edge channels are + // flagged or not; they're not available. + Some([5, 16].as_slice()) + } else { + Some([0, 1, 2, 5, 16, 30, 31].as_slice()) + } + ); + } + _ => unreachable!(), + } + } + } +} diff --git a/src/cli/vis_simulate/mod.rs b/src/cli/vis_simulate/mod.rs new file mode 100644 index 00000000..2e460020 --- /dev/null +++ b/src/cli/vis_simulate/mod.rs @@ -0,0 +1,659 @@ +// This Source Code Form is subject to the terms of the Mozilla Public +// License, v. 2.0. If a copy of the MPL was not distributed with this +// file, You can obtain one at http://mozilla.org/MPL/2.0/. + +//! Generate sky-model visibilities from a sky-model source list. + +// TODO: Utilise tile flags. +// TODO: Allow the user to specify the mwa_version for the metafits file. + +use std::{ + collections::HashSet, + path::{Path, PathBuf}, +}; + +use clap::Parser; +use console::style; +use hifitime::{Duration, Epoch}; +use log::{debug, info, trace}; +use marlu::{precession::precess_time, LatLngHeight, RADec, XyzGeodetic}; +use mwalib::MetafitsContext; +use serde::{Deserialize, Serialize}; +use thiserror::Error; +use vec1::Vec1; + +use super::common::{ + display_warnings, BeamArgs, ModellingArgs, OutputVisArgs, SkyModelWithVetoArgs, ARG_FILE_HELP, + ARRAY_POSITION_HELP, +}; +use crate::{ + beam::Delays, + cli::common::InfoPrinter, + io::write::VIS_OUTPUT_EXTENSIONS, + math::TileBaselineFlags, + metafits::{get_dipole_delays, get_dipole_gains}, + params::VisSimulateParams, + srclist::ComponentCounts, + HyperdriveError, +}; + +const DEFAULT_OUTPUT_VIS_FILENAME: &str = "hyp_model.uvfits"; +const DEFAULT_NUM_FINE_CHANNELS: usize = 384; +const DEFAULT_FREQ_RES_KHZ: f64 = 80.0; +const DEFAULT_NUM_TIMESTEPS: usize = 14; +const DEFAULT_TIME_RES_SECONDS: f64 = 8.0; + +lazy_static::lazy_static! { + static ref NUM_FINE_CHANNELS_HELP: String = + format!("The total number of fine channels in the observation. Default: {DEFAULT_NUM_FINE_CHANNELS}"); + + static ref FREQ_RES_HELP: String = + format!("The fine-channel resolution [kHz]. Default: {DEFAULT_FREQ_RES_KHZ}"); + + static ref NUM_TIMESTEPS_HELP: String = + format!("The number of time steps used from the metafits epoch. Default: {DEFAULT_NUM_TIMESTEPS}"); + + static ref TIME_RES_HELP: String = + format!("The time resolution [seconds]. Default: {DEFAULT_TIME_RES_SECONDS}"); + + static ref OUTPUTS_HELP: String = + format!("Paths to the output visibility files. Supported formats: {}. 
Default: {}", *VIS_OUTPUT_EXTENSIONS, DEFAULT_OUTPUT_VIS_FILENAME); +} + +#[derive(Parser, Debug, Clone, Default, Serialize, Deserialize)] +pub(super) struct VisSimulateCliArgs { + /// Path to the metafits file. + #[clap(short, long, parse(from_str), help_heading = "INPUT FILES")] + pub(super) metafits: Option, + + /// Use this value as the DUT1 [seconds]. + #[clap(long, help_heading = "INPUT DATA")] + #[serde(default)] + dut1: Option, + + /// Use a DUT1 value of 0 seconds rather than what is in the metafits file. + #[clap(long, conflicts_with("dut1"), help_heading = "INPUT FILES")] + ignore_dut1: bool, + + /// The phase centre right ascension [degrees]. If this is not specified, + /// then the metafits phase/pointing centre is used. + #[clap(short, long, help_heading = "OBSERVATION PARAMETERS")] + ra: Option, + + /// The phase centre declination [degrees]. If this is not specified, then + /// the metafits phase/pointing centre is used. + #[clap(short, long, help_heading = "OBSERVATION PARAMETERS")] + dec: Option, + + #[clap( + short = 'c', + long, + help = NUM_FINE_CHANNELS_HELP.as_str(), + help_heading = "OBSERVATION PARAMETERS" + )] + pub(super) num_fine_channels: Option, + + #[clap( + short, + long, + help = FREQ_RES_HELP.as_str(), + help_heading = "OBSERVATION PARAMETERS" + )] + freq_res: Option, + + /// The centroid frequency of the simulation [MHz]. If this is not + /// specified, then the FREQCENT specified in the metafits is used. + #[clap(long, help_heading = "OBSERVATION PARAMETERS")] + middle_freq: Option, + + #[clap( + short = 't', + long, + help = NUM_TIMESTEPS_HELP.as_str(), + help_heading = "OBSERVATION PARAMETERS" + )] + pub(super) num_timesteps: Option, + + #[clap(long, help = TIME_RES_HELP.as_str(), help_heading = "OBSERVATION PARAMETERS")] + pub(super) time_res: Option, + + /// The time offset from the start [seconds]. The default start time is the + /// is the obsid. + #[clap(long, help_heading = "OBSERVATION PARAMETERS")] + time_offset: Option, + + #[clap( + long, help = ARRAY_POSITION_HELP.as_str(), help_heading = "OBSERVATION PARAMETERS", + number_of_values = 3, + allow_hyphen_values = true, + value_names = &["LONG_DEG", "LAT_DEG", "HEIGHT_M"] + )] + array_position: Option>, + + #[clap( + short = 'o', + long, + multiple_values(true), + help = OUTPUTS_HELP.as_str(), + help_heading = "OUTPUT FILES" + )] + pub(super) output_model_files: Option>, + + /// When writing out model visibilities, average this many timesteps + /// together. Also supports a target time resolution (e.g. 8s). The value + /// must be a multiple of the input data's time resolution. The default is + /// no averaging, i.e. a value of 1. Examples: If the input data is in 0.5s + /// resolution and this variable is 4, then we average 2s worth of data + /// together before writing the data out. If the variable is instead 4s, + /// then 8 timesteps are averaged together before writing the data out. + #[clap(long, help_heading = "OUTPUT FILES")] + output_model_time_average: Option, + + /// When writing out model visibilities, average this many fine freq. + /// channels together. Also supports a target freq. resolution (e.g. 80kHz). + /// The value must be a multiple of the input data's freq. resolution. The + /// default is no averaging, i.e. a value of 1. Examples: If the input data + /// is in 40kHz resolution and this variable is 4, then we average 160kHz + /// worth of data together before writing the data out. If the variable is + /// instead 80kHz, then 2 fine freq. 
channels are averaged together before + /// writing the data out. + #[clap(long, help_heading = "OUTPUT FILES")] + output_model_freq_average: Option, + + /// Remove any "point" components from the input sky model. + #[clap(long, help_heading = "SKY MODEL")] + filter_points: bool, + + /// Remove any "gaussian" components from the input sky model. + #[clap(long, help_heading = "SKY MODEL")] + filter_gaussians: bool, + + /// Remove any "shapelet" components from the input sky model. + #[clap(long, help_heading = "SKY MODEL")] + filter_shapelets: bool, +} + +#[derive(Parser, Debug, Clone, Default, Serialize, Deserialize)] +pub(super) struct VisSimulateArgs { + #[clap(name = "ARGUMENTS_FILE", help = ARG_FILE_HELP.as_str(), parse(from_os_str))] + pub(super) args_file: Option, + + #[clap(flatten)] + #[serde(rename = "beam")] + #[serde(default)] + pub(super) beam_args: BeamArgs, + + #[clap(flatten)] + #[serde(rename = "model")] + #[serde(default)] + pub(super) modelling_args: ModellingArgs, + + #[clap(flatten)] + #[serde(rename = "sky-model")] + #[serde(default)] + pub(super) srclist_args: SkyModelWithVetoArgs, + + #[clap(flatten)] + #[serde(rename = "vis-simulate")] + #[serde(default)] + pub(super) simulate_args: VisSimulateCliArgs, +} + +impl VisSimulateArgs { + /// Both command-line and file arguments overlap in terms of what is + /// available; this function consolidates everything that was specified into + /// a single struct. Where applicable, it will prefer CLI parameters over + /// those in the file. + /// + /// The argument to this function is the path to the arguments file. + /// + /// This function should only ever merge arguments, and not try to make + /// sense of them. + pub(super) fn merge(self) -> Result { + debug!("Merging command-line arguments with the argument file"); + + let cli_args = self; + + if let Some(arg_file) = cli_args.args_file { + // Read in the file arguments. Ensure all of the file args are + // accounted for by pattern matching. + let VisSimulateArgs { + args_file: _, + beam_args, + modelling_args, + srclist_args, + simulate_args, + } = unpack_arg_file!(arg_file); + + // Merge all the arguments, preferring the CLI args when available. + Ok(VisSimulateArgs { + args_file: None, + beam_args: cli_args.beam_args.merge(beam_args), + modelling_args: cli_args.modelling_args.merge(modelling_args), + srclist_args: cli_args.srclist_args.merge(srclist_args), + simulate_args: cli_args.simulate_args.merge(simulate_args), + }) + } else { + Ok(cli_args) + } + } + + fn parse(self) -> Result { + debug!("{:#?}", self); + + // Expose all the struct fields to ensure they're all used. + let VisSimulateArgs { + args_file: _, + beam_args, + modelling_args, + srclist_args, + simulate_args: + VisSimulateCliArgs { + metafits, + dut1, + ignore_dut1, + ra, + dec, + num_fine_channels, + freq_res, + middle_freq, + num_timesteps, + time_res, + time_offset, + array_position, + output_model_files, + output_model_time_average, + output_model_freq_average, + filter_points, + filter_gaussians, + filter_shapelets, + }, + } = self; + + // Read the metafits file with mwalib. + let metafits = if let Some(metafits) = metafits { + if !metafits.exists() { + return Err( + VisSimulateArgsError::MetafitsDoesntExist(metafits.into_boxed_path()).into(), + ); + } + MetafitsContext::new(metafits, None)? 
+ } else { + return Err(VisSimulateArgsError::NoMetafits.into()); + }; + + let mut metadata_printer = InfoPrinter::new( + format!("Simulating visibilities for obsid {}", metafits.obs_id).into(), + ); + metadata_printer.push_line(format!("with {}", metafits.metafits_filename).into()); + metadata_printer.display(); + + let mut coord_printer = InfoPrinter::new("Coordinates".into()); + // Get the phase centre. + let phase_centre = match (ra, dec, &metafits) { + (Some(ra), Some(dec), _) => { + // Verify that the input coordinates are sensible. + if !(0.0..=360.0).contains(&ra) { + return Err(VisSimulateArgsError::RaInvalid.into()); + } + if !(-90.0..=90.0).contains(&dec) { + return Err(VisSimulateArgsError::DecInvalid.into()); + } + RADec::from_degrees(ra, dec) + } + (Some(_), None, _) => return Err(VisSimulateArgsError::OnlyOneRAOrDec.into()), + (None, Some(_), _) => return Err(VisSimulateArgsError::OnlyOneRAOrDec.into()), + (None, None, m) => { + // The phase centre in a metafits file may not be present. If not, + // we have to use the pointing centre. + match (m.ra_phase_center_degrees, m.dec_phase_center_degrees) { + (Some(ra), Some(dec)) => RADec::from_degrees(ra, dec), + (None, None) => { + RADec::from_degrees(m.ra_tile_pointing_degrees, m.dec_tile_pointing_degrees) + } + _ => unreachable!(), + } + } + }; + let mut block = vec![]; + block.push( + style(" RA Dec") + .bold() + .to_string() + .into(), + ); + if let Some((ra, dec)) = metafits + .ra_phase_center_degrees + .zip(metafits.dec_phase_center_degrees) + { + block.push(format!("Phase centre: {:>8.4}° {:>8.4}° (J2000)", ra, dec).into()); + } + block.push( + format!( + "Pointing centre: {:>8.4}° {:>8.4}°", + metafits.ra_tile_pointing_degrees, metafits.dec_tile_pointing_degrees + ) + .into(), + ); + coord_printer.push_block(block); + + // If the user supplied the array position, unpack it here. + let array_position = match array_position { + Some(v) => { + if v.len() != 3 { + return Err(VisSimulateArgsError::BadArrayPosition { pos: v }.into()); + } + LatLngHeight { + longitude_rad: v[0].to_radians(), + latitude_rad: v[1].to_radians(), + height_metres: v[2], + } + } + None => LatLngHeight::mwa(), + }; + coord_printer.push_line( + format!( + "Array position: {:>8.4}° {:>8.4}° {:.4}m", + array_position.longitude_rad.to_degrees(), + array_position.latitude_rad.to_degrees(), + array_position.height_metres + ) + .into(), + ); + coord_printer.display(); + + // Get the geodetic XYZ coordinates of each of the MWA tiles. + let tile_xyzs = XyzGeodetic::get_tiles(&metafits, array_position.latitude_rad); + let tile_names: Vec = metafits + .antennas + .iter() + .map(|a| a.tile_name.clone()) + .collect(); + + // Prepare a map between baselines and their constituent tiles. 
+ let flagged_tiles = HashSet::new(); + let tile_baseline_flags = TileBaselineFlags::new(metafits.num_ants, flagged_tiles); + + let mut tile_printer = InfoPrinter::new("Tile info".into()); + tile_printer.push_line(format!("{} tiles", tile_xyzs.len()).into()); + tile_printer.display(); + + let time_res = Duration::from_seconds(time_res.unwrap_or(DEFAULT_TIME_RES_SECONDS)); + let timestamps = { + let num_timesteps = num_timesteps.unwrap_or(DEFAULT_NUM_TIMESTEPS); + let mut timestamps = Vec::with_capacity(num_timesteps); + let start_ns = metafits + .sched_start_gps_time_ms + .checked_mul(1_000_000) + .expect("does not overflow u64"); + let start = Epoch::from_gpst_nanoseconds(start_ns) + + time_res / 2 + + Duration::from_seconds(time_offset.unwrap_or_default()); + for i in 0..num_timesteps { + timestamps.push(start + time_res * i as i64); + } + Vec1::try_from_vec(timestamps).map_err(|_| VisSimulateArgsError::ZeroTimeSteps)? + }; + let dut1 = match (ignore_dut1, dut1) { + (true, _) => { + debug!("Ignoring metafits and user DUT1"); + Duration::default() + } + (false, Some(dut1)) => { + debug!("Using user DUT1"); + Duration::from_seconds(dut1) + } + (false, None) => { + debug!("Using metafits DUT1"); + metafits + .dut1 + .map(Duration::from_seconds) + .unwrap_or_default() + } + }; + let precession_info = precess_time( + array_position.longitude_rad, + array_position.latitude_rad, + phase_centre, + *timestamps.first(), + dut1, + ); + let (lst_rad, latitude_rad) = if !modelling_args.no_precession { + ( + precession_info.lmst_j2000, + precession_info.array_latitude_j2000, + ) + } else { + (precession_info.lmst, array_position.latitude_rad) + }; + + let mut time_printer = InfoPrinter::new("Time info".into()); + time_printer.push_line(format!("Simulating at resolution: {time_res}").into()); + time_printer.push_block(vec![ + format!("First timestamp: {}", timestamps.first()).into(), + format!( + "First timestamp (GPS): {}", + timestamps.first().to_gpst_seconds() + ) + .into(), + format!( + "Last timestamp (GPS): {}", + timestamps.last().to_gpst_seconds() + ) + .into(), + format!( + "First LMST: {:.6}° (J2000)", + precession_info.lmst_j2000.to_degrees() + ) + .into(), + ]); + time_printer.push_line(format!("DUT1: {:.10} s", dut1.to_seconds()).into()); + time_printer.display(); + + // Get the fine channel frequencies. + let freq_res = freq_res.unwrap_or(DEFAULT_FREQ_RES_KHZ); + let num_fine_channels = num_fine_channels.unwrap_or(DEFAULT_NUM_FINE_CHANNELS); + if freq_res < f64::EPSILON { + return Err(VisSimulateArgsError::FineChansWidthTooSmall.into()); + } + let middle_freq = middle_freq + .map(|f| f * 1e6) // MHz -> Hz + .unwrap_or(metafits.centre_freq_hz as _); + let freq_res = freq_res * 1e3; // kHz -> Hz + let fine_chan_freqs = { + let half_num_fine_chans = num_fine_channels as f64 / 2.0; + let mut fine_chan_freqs = Vec::with_capacity(num_fine_channels); + for i in 0..num_fine_channels { + fine_chan_freqs + .push(middle_freq - half_num_fine_chans * freq_res + freq_res * i as f64); + } + Vec1::try_from_vec(fine_chan_freqs).map_err(|_| VisSimulateArgsError::FineChansZero)? + }; + let coarse_chan_freqs = { + let (mut coarse_chan_freqs, mut coarse_chan_nums): (Vec, Vec) = + fine_chan_freqs + .iter() + .map(|&f| { + // MWA coarse channel numbers are a multiple of 1.28 MHz. + // This might change with MWAX, but ignore that until it + // becomes an issue; vis-simulate is mostly useful for + // testing. 
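+                        // e.g. fine channels near 197.12 MHz all round to coarse
+                        // channel 154 (154 * 1.28 MHz = 197.12 MHz).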
+ let cc_num = (f / 1.28e6).round(); + (cc_num * 1.28e6, cc_num as u32) + }) + .unzip(); + // Deduplicate. As `fine_chan_freqs` is always sorted, we don't need + // to sort here. + coarse_chan_freqs.dedup(); + coarse_chan_nums.dedup(); + debug!("MWA coarse channel numbers: {coarse_chan_nums:?}"); + // Convert the coarse channel numbers to a range starting from 1. + coarse_chan_freqs + }; + debug!( + "Coarse channel centre frequencies [Hz]: {:?}", + coarse_chan_freqs + ); + + let mut chan_printer = InfoPrinter::new("Channel info".into()); + chan_printer + .push_line(format!("Simulating at resolution: {:.2} kHz", freq_res / 1e3).into()); + chan_printer.push_block(vec![ + format!("Number of fine channels: {num_fine_channels}").into(), + format!( + "First fine-channel: {:.3} MHz", + *fine_chan_freqs.first() / 1e6 + ) + .into(), + format!( + "Last fine-channel: {:.3} MHz", + *fine_chan_freqs.last() / 1e6 + ) + .into(), + ]); + chan_printer.display(); + + let beam = beam_args.parse( + metafits.num_ants, + Some(Delays::Full(get_dipole_delays(&metafits))), + Some(get_dipole_gains(&metafits)), + None, + )?; + let modelling_params = modelling_args.parse(); + + let source_list = srclist_args.parse( + phase_centre, + lst_rad, + latitude_rad, + &coarse_chan_freqs, + &*beam, + )?; + + // Apply any filters. + let source_list = if filter_points || filter_gaussians || filter_shapelets { + let sl = source_list.filter(filter_points, filter_gaussians, filter_shapelets); + let ComponentCounts { + num_points, + num_gaussians, + num_shapelets, + .. + } = sl.get_counts(); + info!( + "After filtering, there are {num_points} points, {num_gaussians} gaussians, {num_shapelets} shapelets" + ); + sl + } else { + source_list + }; + + // Parse the output model vis args like normal output vis args, to + // re-use existing code (we only make the args distinct to make it clear + // that these visibilities are not calibrated, just the model vis). 
+ let output_vis_params = OutputVisArgs { + outputs: output_model_files, + output_vis_time_average: output_model_time_average, + output_vis_freq_average: output_model_freq_average, + } + .parse( + time_res, + freq_res, + ×tamps, + false, + DEFAULT_OUTPUT_VIS_FILENAME, + Some("simulated"), + )?; + + display_warnings(); + + Ok(VisSimulateParams { + source_list, + metafits, + output_vis_params, + phase_centre, + fine_chan_freqs, + freq_res_hz: freq_res, + tile_xyzs, + tile_names, + tile_baseline_flags, + timestamps, + time_res, + beam, + array_position, + dut1, + modelling_params, + }) + } + + pub(super) fn run(self, dry_run: bool) -> Result<(), HyperdriveError> { + debug!("Converting arguments into parameters"); + trace!("{:#?}", self); + let params = self.parse()?; + + if dry_run { + info!("Dry run -- exiting now."); + return Ok(()); + } + + params.run()?; + Ok(()) + } +} + +#[derive(Error, Debug)] +pub(super) enum VisSimulateArgsError { + #[error("No metafits file was supplied")] + NoMetafits, + + #[error("Metafits file '{0}' doesn't exist")] + MetafitsDoesntExist(Box), + + #[error("Right Ascension was not within 0 to 360!")] + RaInvalid, + + #[error("Declination was not within -90 to 90!")] + DecInvalid, + + #[error("One of RA and Dec was specified, but none or both are required!")] + OnlyOneRAOrDec, + + #[error("Number of fine channels cannot be 0!")] + FineChansZero, + + #[error("The fine channel resolution cannot be 0 or negative!")] + FineChansWidthTooSmall, + + #[error("Number of timesteps cannot be 0!")] + ZeroTimeSteps, + + #[error("Array position specified as {pos:?}, not [, , ]")] + BadArrayPosition { pos: Vec }, +} + +impl VisSimulateCliArgs { + fn merge(self, other: Self) -> Self { + Self { + metafits: self.metafits.or(other.metafits), + dut1: self.dut1.or(other.dut1), + ignore_dut1: self.ignore_dut1 || other.ignore_dut1, + ra: self.ra.or(other.ra), + dec: self.dec.or(other.dec), + num_fine_channels: self.num_fine_channels.or(other.num_fine_channels), + freq_res: self.freq_res.or(other.freq_res), + middle_freq: self.middle_freq.or(other.middle_freq), + num_timesteps: self.num_timesteps.or(other.num_timesteps), + time_res: self.time_res.or(other.time_res), + time_offset: self.time_offset.or(other.time_offset), + array_position: self.array_position.or(other.array_position), + output_model_files: self.output_model_files.or(other.output_model_files), + output_model_time_average: self + .output_model_time_average + .or(other.output_model_time_average), + output_model_freq_average: self + .output_model_freq_average + .or(other.output_model_freq_average), + filter_points: self.filter_points || other.filter_points, + filter_gaussians: self.filter_gaussians || other.filter_gaussians, + filter_shapelets: self.filter_shapelets || other.filter_shapelets, + } + } +} diff --git a/src/model/integration_tests.rs b/src/cli/vis_simulate/tests.rs similarity index 97% rename from src/model/integration_tests.rs rename to src/cli/vis_simulate/tests.rs index 5e51565a..760094c8 100644 --- a/src/model/integration_tests.rs +++ b/src/cli/vis_simulate/tests.rs @@ -68,18 +68,17 @@ fn test_1090008640_vis_simulate() { let temp_dir = TempDir::new().expect("couldn't make tmp dir"); let output_path = temp_dir.path().join("model.uvfits"); let args = get_reduced_1090008640(false, false); - let metafits = args.data.as_ref().unwrap()[0].clone(); + let metafits = args.data_args.files.as_ref().unwrap()[0].clone(); #[rustfmt::skip] let sim_args = VisSimulateArgs::parse_from([ "vis-simulate", "--metafits", &metafits, - 
"--source-list", &args.source_list.unwrap(), + "--source-list", &args.srclist_args.source_list.unwrap(), "--output-model-files", &format!("{}", output_path.display()), "--num-timesteps", &format!("{num_timesteps}"), "--num-fine-channels", &format!("{num_chans}"), "--veto-threshold", "0.0", // Don't complicate things with vetoing - "--no-progress-bars" ]); // Run vis-simulate and check that it succeeds @@ -349,16 +348,15 @@ fn test_1090008640_vis_simulate_cpu_gpu_match() { let temp_dir = TempDir::new().expect("couldn't make tmp dir"); let output_path = temp_dir.path().join("model.uvfits"); let args = get_reduced_1090008640(false, false); - let metafits = args.data.as_ref().unwrap()[0].clone(); + let metafits = args.data_args.files.as_ref().unwrap()[0].clone(); #[rustfmt::skip] let sim_args = VisSimulateArgs::parse_from([ "vis-simulate", "--metafits", &metafits, - "--source-list", &args.source_list.unwrap(), + "--source-list", &args.srclist_args.source_list.unwrap(), "--output-model-files", &format!("{}", output_path.display()), "--num-timesteps", &format!("{num_timesteps}"), "--num-fine-channels", &format!("{num_chans}"), - "--no-progress-bars", "--cpu", ]); let result = sim_args.run(false); @@ -398,16 +396,15 @@ fn test_1090008640_vis_simulate_cpu_gpu_match() { drop(uvfits); let args = get_reduced_1090008640(false, false); - let metafits = args.data.as_ref().unwrap()[0].clone(); + let metafits = args.data_args.files.as_ref().unwrap()[0].clone(); #[rustfmt::skip] let sim_args = VisSimulateArgs::parse_from([ "vis-simulate", "--metafits", &metafits, - "--source-list", &args.source_list.unwrap(), + "--source-list", &args.srclist_args.source_list.unwrap(), "--output-model-files", &format!("{}", output_path.display()), "--num-timesteps", &format!("{num_timesteps}"), "--num-fine-channels", &format!("{num_chans}"), - "--no-progress-bars" ]); // Run vis-simulate and check that it succeeds diff --git a/src/cli/vis_subtract/mod.rs b/src/cli/vis_subtract/mod.rs new file mode 100644 index 00000000..192f506f --- /dev/null +++ b/src/cli/vis_subtract/mod.rs @@ -0,0 +1,405 @@ +// This Source Code Form is subject to the terms of the Mozilla Public +// License, v. 2.0. If a copy of the MPL was not distributed with this +// file, You can obtain one at http://mozilla.org/MPL/2.0/. + +use std::{borrow::Cow, path::PathBuf, str::FromStr}; + +use clap::Parser; +use log::{debug, info, trace}; +use marlu::{precession::precess_time, LatLngHeight}; +use serde::{Deserialize, Serialize}; + +use super::common::{ + display_warnings, BeamArgs, InputVisArgs, ModellingArgs, OutputVisArgs, SkyModelWithVetoArgs, + ARG_FILE_HELP, +}; +use crate::{ + cli::common::InfoPrinter, + constants::{DEFAULT_CUTOFF_DISTANCE, DEFAULT_VETO_THRESHOLD}, + io::{get_single_match_from_glob, write::VIS_OUTPUT_EXTENSIONS}, + params::{ModellingParams, VisSubtractParams}, + srclist::{ + read::read_source_list_file, veto_sources, ComponentCounts, ReadSourceListError, + SourceList, SourceListType, + }, + HyperdriveError, +}; + +const DEFAULT_OUTPUT_VIS_FILENAME: &str = "hyp_subtracted.uvfits"; + +lazy_static::lazy_static! { + static ref OUTPUTS_HELP: String = + format!("Paths to the subtracted visibility files. Supported formats: {}. Default: {}", *VIS_OUTPUT_EXTENSIONS, DEFAULT_OUTPUT_VIS_FILENAME); +} + +#[derive(Parser, Debug, Clone, Default, Serialize, Deserialize)] +struct VisSubtractCliArgs { + /// Invert the subtraction; sources *not* specified in sources-to-subtract + /// will be subtracted from the input data. 
+ #[clap(short, long, help_heading = "SKY-MODEL SOURCES")] + #[serde(default)] + invert: bool, + + /// The names of the sources in the sky-model source list that will be + /// subtracted from the input data. + #[clap(long, multiple_values(true), help_heading = "SKY-MODEL SOURCES")] + sources_to_subtract: Option>, + + #[clap( + short = 'o', + long, + multiple_values(true), + help = OUTPUTS_HELP.as_str(), + help_heading = "OUTPUT FILES" + )] + outputs: Option>, + + /// When writing out visibilities, average this many timesteps together. + /// Also supports a target time resolution (e.g. 8s). The value must be a + /// multiple of the input data's time resolution. The default is no + /// averaging, i.e. a value of 1. Examples: If the input data is in 0.5s + /// resolution and this variable is 4, then we average 2s worth of data + /// together before writing the data out. If the variable is instead 4s, + /// then 8 timesteps are averaged together before writing the data out. + #[clap(long, help_heading = "OUTPUT FILES")] + output_vis_time_average: Option, + + /// When writing out visibilities, average this many fine freq. channels + /// together. Also supports a target freq. resolution (e.g. 80kHz). The + /// value must be a multiple of the input data's freq. resolution. The + /// default is no averaging, i.e. a value of 1. Examples: If the input data + /// is in 40kHz resolution and this variable is 4, then we average 160kHz + /// worth of data together before writing the data out. If the variable is + /// instead 80kHz, then 2 fine freq. channels are averaged together before + /// writing the data out. + #[clap(long, help_heading = "OUTPUT FILES")] + output_vis_freq_average: Option, + + /// Rather than writing out the entire input bandwidth, write out only the + /// smallest contiguous band. e.g. Typical 40 kHz MWA data has 768 channels, + /// but the first 2 and last 2 channels are usually flagged. Turning this + /// option on means that 764 channels would be written out instead of 768. + /// Note that other flagged channels in the band are unaffected, because the + /// data written out must be contiguous. + #[clap(long, help_heading = "OUTPUT FILES")] + #[serde(default)] + output_smallest_contiguous_band: bool, +} + +#[derive(Parser, Debug, Clone, Default, Serialize, Deserialize)] +pub(super) struct VisSubtractArgs { + #[clap(name = "ARGUMENTS_FILE", help = ARG_FILE_HELP.as_str(), parse(from_os_str))] + args_file: Option, + + #[clap(flatten)] + #[serde(rename = "data")] + #[serde(default)] + data_args: InputVisArgs, + + #[clap(flatten)] + #[serde(rename = "sky-model")] + #[serde(default)] + srclist_args: SkyModelWithVetoArgs, + + #[clap(flatten)] + #[serde(rename = "model")] + #[serde(default)] + modelling_args: ModellingArgs, + + #[clap(flatten)] + #[serde(rename = "beam")] + #[serde(default)] + beam_args: BeamArgs, + + #[clap(flatten)] + #[serde(rename = "vis-subtract")] + #[serde(default)] + vis_subtract_args: VisSubtractCliArgs, +} + +impl VisSubtractArgs { + /// Both command-line and file arguments overlap in terms of what is + /// available; this function consolidates everything that was specified into + /// a single struct. Where applicable, it will prefer CLI parameters over + /// those in the file. + /// + /// The argument to this function is the path to the arguments file. + /// + /// This function should only ever merge arguments, and not try to make + /// sense of them. 
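+    // e.g. if the arguments file sets `outputs` but `-o` is also given on the
+    // command line, the command-line value is the one that is kept.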
+ pub(super) fn merge(self) -> Result { + debug!("Merging command-line arguments with the argument file"); + + let cli_args = self; + + if let Some(arg_file) = cli_args.args_file { + // Read in the file arguments. Ensure all of the file args are + // accounted for by pattern matching. + let VisSubtractArgs { + args_file: _, + data_args, + srclist_args, + modelling_args, + beam_args, + vis_subtract_args, + } = unpack_arg_file!(arg_file); + + // Merge all the arguments, preferring the CLI args when available. + Ok(VisSubtractArgs { + args_file: None, + data_args: cli_args.data_args.merge(data_args), + srclist_args: cli_args.srclist_args.merge(srclist_args), + modelling_args: cli_args.modelling_args.merge(modelling_args), + beam_args: cli_args.beam_args.merge(beam_args), + vis_subtract_args: cli_args.vis_subtract_args.merge(vis_subtract_args), + }) + } else { + Ok(cli_args) + } + } + + fn parse(self) -> Result { + debug!("{:#?}", self); + + let Self { + args_file: _, + data_args, + srclist_args, + modelling_args, + beam_args, + vis_subtract_args: + VisSubtractCliArgs { + invert, + sources_to_subtract, + outputs, + output_vis_time_average, + output_vis_freq_average, + output_smallest_contiguous_band, + }, + } = self; + + let input_vis_params = data_args.parse("Vis subtracting")?; + let obs_context = input_vis_params.get_obs_context(); + let total_num_tiles = obs_context.get_total_num_tiles(); + + let beam = beam_args.parse( + total_num_tiles, + obs_context.dipole_delays.clone(), + obs_context.dipole_gains.clone(), + Some(obs_context.input_data_type), + )?; + let modelling_params @ ModellingParams { + apply_precession, .. + } = modelling_args.parse(); + + let LatLngHeight { + longitude_rad, + latitude_rad, + height_metres: _, + } = obs_context.array_position; + let precession_info = precess_time( + longitude_rad, + latitude_rad, + obs_context.phase_centre, + input_vis_params.timeblocks.first().median, + input_vis_params.dut1, + ); + let (lmst, latitude) = if apply_precession { + ( + precession_info.lmst_j2000, + precession_info.array_latitude_j2000, + ) + } else { + (precession_info.lmst, latitude_rad) + }; + + // If we're not inverted but `sources_to_subtract` is empty, then there's + // nothing to do. + let sources_to_subtract = sources_to_subtract.unwrap_or_default(); + if !invert && sources_to_subtract.is_empty() { + return Err(VisSubtractArgsError::NoSources.into()); + } + + // Read in the source list and remove all but the specified sources. We + // have to parse the arguments manually as we're doing custom stuff here + // in vis-subtract. + let SkyModelWithVetoArgs { + source_list, + source_list_type, + num_sources, + source_dist_cutoff, + veto_threshold, + } = srclist_args; + + let source_list: SourceList = { + let source_list = source_list.ok_or(ReadSourceListError::NoSourceList)?; + // If the specified source list file can't be found, treat it as a glob + // and expand it to find a match. + let pb = PathBuf::from(&source_list); + let pb = if pb.exists() { + pb + } else { + get_single_match_from_glob(&source_list) + .map_err(|e| HyperdriveError::Generic(e.to_string()))? + }; + + // Read the source list file. If the type was manually specified, + // use that, otherwise the reading code will try all available + // kinds. 
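(Editorial note: the `pb.exists()` / glob fallback just above means the source-list argument can be a literal path or a pattern matching exactly one file. Below is a rough stand-alone equivalent using the `glob` crate, purely to illustrate the behaviour; hyperdrive's own helper is `get_single_match_from_glob`.)

```rust
use std::path::PathBuf;

/// Resolve a literal path, or treat the string as a glob and require exactly
/// one match.
fn resolve_one(path_or_glob: &str) -> Result<PathBuf, String> {
    let pb = PathBuf::from(path_or_glob);
    if pb.exists() {
        return Ok(pb);
    }
    // Not a real file; expand the string as a glob pattern instead.
    let mut matches: Vec<PathBuf> = glob::glob(path_or_glob)
        .map_err(|e| e.to_string())?
        .filter_map(Result::ok)
        .collect();
    match matches.len() {
        1 => Ok(matches.swap_remove(0)),
        0 => Err(format!("no files matched '{path_or_glob}'")),
        n => Err(format!("'{path_or_glob}' matched {n} files; expected exactly one")),
    }
}
```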
+ let sl_type_not_specified = source_list_type.is_none(); + let sl_type = source_list_type + .as_ref() + .and_then(|t| SourceListType::from_str(t.as_ref()).ok()); + let (sl, sl_type) = read_source_list_file(pb, sl_type)?; + + // If the user didn't specify the source list type, then print out + // what we found. + if sl_type_not_specified { + trace!("Successfully parsed {}-style source list", sl_type); + } + if num_sources == Some(0) || sl.is_empty() { + return Err(ReadSourceListError::NoSources.into()); + } + sl + }; + debug!("Found {} sources in the source list", source_list.len()); + let ComponentCounts { + num_points, + num_gaussians, + num_shapelets, + .. + } = source_list.get_counts(); + let mut sl_printer = InfoPrinter::new("Sky model info".into()); + sl_printer.push_block(vec![ + format!("Source list contains {} sources", source_list.len()).into(), + format!("({} components, {num_points} points, {num_gaussians} Gaussians, {num_shapelets} shapelets)", num_points + num_gaussians + num_shapelets).into() + ]); + + // Ensure that all specified sources are actually in the source list. + for name in &sources_to_subtract { + if !source_list.contains_key(name) { + return Err(HyperdriveError::from(VisSubtractArgsError::MissingSource { + name: name.to_string().into(), + })); + } + } + // Handle the invert option. + let source_list: SourceList = if invert { + let mut sl: SourceList = source_list + .into_iter() + .filter(|(name, _)| !sources_to_subtract.contains(name)) + .collect(); + if sl.is_empty() { + // Nothing to do. + return Err(VisSubtractArgsError::AllSourcesFiltered.into()); + } + veto_sources( + &mut sl, + obs_context.phase_centre, + lmst, + latitude, + &obs_context.get_veto_freqs(), + &*beam, + num_sources, + source_dist_cutoff.unwrap_or(DEFAULT_CUTOFF_DISTANCE), + veto_threshold.unwrap_or(DEFAULT_VETO_THRESHOLD), + )?; + if sl.is_empty() { + return Err(ReadSourceListError::NoSourcesAfterVeto.into()); + } + sl + } else { + source_list + .into_iter() + .filter(|(name, _)| sources_to_subtract.contains(name)) + .collect() + }; + let ComponentCounts { + num_points, + num_gaussians, + num_shapelets, + num_power_laws: _, + num_curved_power_laws: _, + num_lists: _, + } = source_list.get_counts(); + sl_printer.push_block(vec![ + format!( + "Subtracting {} sources with a total of {} components", + source_list.len(), + num_points + num_gaussians + num_shapelets + ) + .into(), + format!("{num_points} points, {num_gaussians} Gaussians, {num_shapelets} shapelets") + .into(), + ]); + sl_printer.display(); + + let output_vis_params = OutputVisArgs { + outputs, + output_vis_time_average, + output_vis_freq_average, + } + .parse( + input_vis_params.time_res, + input_vis_params.spw.freq_res, + &input_vis_params.timeblocks.mapped_ref(|tb| tb.median), + output_smallest_contiguous_band, + DEFAULT_OUTPUT_VIS_FILENAME, + Some("subtracted"), + )?; + + display_warnings(); + + Ok(VisSubtractParams { + input_vis_params, + output_vis_params, + beam, + source_list, + modelling_params, + }) + } + + pub(super) fn run(self, dry_run: bool) -> Result<(), HyperdriveError> { + debug!("Converting arguments into parameters"); + trace!("{:#?}", self); + let params = self.parse()?; + + if dry_run { + info!("Dry run -- exiting now."); + return Ok(()); + } + + params.run()?; + Ok(()) + } +} + +#[derive(thiserror::Error, Debug)] +pub(super) enum VisSubtractArgsError { + #[error("Specified source {name} is not in the input source list; can't subtract it")] + MissingSource { name: Cow<'static, str> }, + + #[error("No sources 
were specified for subtraction. Did you want to subtract all sources? See the \"invert\" option.")] + NoSources, + + #[error("No sources were left after removing specified sources from the source list.")] + AllSourcesFiltered, +} + +impl VisSubtractCliArgs { + fn merge(self, other: Self) -> Self { + Self { + invert: self.invert || other.invert, + sources_to_subtract: self.sources_to_subtract.or(other.sources_to_subtract), + outputs: self.outputs.or(other.outputs), + output_vis_time_average: self + .output_vis_time_average + .or(other.output_vis_time_average), + output_vis_freq_average: self + .output_vis_freq_average + .or(other.output_vis_freq_average), + output_smallest_contiguous_band: self.output_smallest_contiguous_band + || other.output_smallest_contiguous_band, + } + } +} diff --git a/src/cli/vis_utils/subtract/tests.rs b/src/cli/vis_subtract/tests.rs similarity index 97% rename from src/cli/vis_utils/subtract/tests.rs rename to src/cli/vis_subtract/tests.rs index 2834afe1..fd44e913 100644 --- a/src/cli/vis_utils/subtract/tests.rs +++ b/src/cli/vis_subtract/tests.rs @@ -15,7 +15,7 @@ use vec1::vec1; use super::*; use crate::{ - cli::vis_utils::simulate::VisSimulateArgs, + cli::vis_simulate::VisSimulateArgs, io::read::fits::{fits_open, fits_open_hdu}, srclist::{ComponentType, FluxDensity, FluxDensityType, Source, SourceComponent, SourceList}, tests::reduced_obsids::get_reduced_1090008640, @@ -32,8 +32,8 @@ fn test_1090008640_vis_subtract() { let subtracted = temp_dir.path().join("subtracted.uvfits"); let mut args = get_reduced_1090008640(false, false); - args.no_beam = true; - let metafits = args.data.as_ref().unwrap()[0].as_str(); + args.beam_args.no_beam = true; + let metafits = args.data_args.files.as_ref().unwrap()[0].as_str(); let mut srclist = SourceList::new(); srclist.insert( "src1".to_string(), @@ -87,7 +87,6 @@ fn test_1090008640_vis_subtract() { "--output-model-files", &format!("{}", model_1.display()), "--num-timesteps", &format!("{num_timesteps}"), "--num-fine-channels", &format!("{num_chans}"), - "--no-progress-bars" ]); let result = sim_args.run(false); assert!(result.is_ok(), "result={:?} not ok", result.err().unwrap()); @@ -100,7 +99,6 @@ fn test_1090008640_vis_subtract() { "--output-model-files", &format!("{}", model_2.display()), "--num-timesteps", &format!("{num_timesteps}"), "--num-fine-channels", &format!("{num_chans}"), - "--no-progress-bars" ]); let result = sim_args.run(false); assert!(result.is_ok(), "result={:?} not ok", result.err().unwrap()); @@ -115,7 +113,6 @@ fn test_1090008640_vis_subtract() { "--outputs", &format!("{}", subtracted.display()), "--source-list", &format!("{}", source_list_2.display()), "--sources-to-subtract", "src2", - "--no-progress-bars", ]); let result = sub_args.run(false); assert!(result.is_ok(), "result={:?} not ok", result.err().unwrap()); @@ -190,7 +187,6 @@ fn test_1090008640_vis_subtract() { "--outputs", &format!("{}", subtracted.display()), "--source-list", &format!("{}", source_list_2.display()), "--sources-to-subtract", "src1", "src2", - "--no-progress-bars", ]); let result = sub_args.run(false); assert!(result.is_ok(), "result={:?} not ok", result.err().unwrap()); diff --git a/src/cli/vis_utils/mod.rs b/src/cli/vis_utils/mod.rs deleted file mode 100644 index bde84caa..00000000 --- a/src/cli/vis_utils/mod.rs +++ /dev/null @@ -1,8 +0,0 @@ -// This Source Code Form is subject to the terms of the Mozilla Public -// License, v. 2.0. 
If a copy of the MPL was not distributed with this -// file, You can obtain one at http://mozilla.org/MPL/2.0/. - -//! Utilities surrounding visibilities. - -pub(crate) mod simulate; -pub(crate) mod subtract; diff --git a/src/cli/vis_utils/simulate/error.rs b/src/cli/vis_utils/simulate/error.rs deleted file mode 100644 index 4ec4fb37..00000000 --- a/src/cli/vis_utils/simulate/error.rs +++ /dev/null @@ -1,83 +0,0 @@ -// This Source Code Form is subject to the terms of the Mozilla Public -// License, v. 2.0. If a copy of the MPL was not distributed with this -// file, You can obtain one at http://mozilla.org/MPL/2.0/. - -//! Error type for all errors related to vis-simulate. - -use std::path::PathBuf; - -use thiserror::Error; - -use crate::io::write::VIS_OUTPUT_EXTENSIONS; - -#[derive(Error, Debug)] -pub(crate) enum VisSimulateError { - #[error("Right Ascension was not within 0 to 360!")] - RaInvalid, - - #[error("Declination was not within -90 to 90!")] - DecInvalid, - - #[error("One of RA and Dec was specified, but none or both are required!")] - OnlyOneRAOrDec, - - #[error("Number of fine channels cannot be 0!")] - FineChansZero, - - #[error("The fine channel resolution cannot be 0 or negative!")] - FineChansWidthTooSmall, - - #[error("Number of timesteps cannot be 0!")] - ZeroTimeSteps, - - #[error( - "The specified MWA dipole delays aren't valid; there should be 16 values between 0 and 32" - )] - BadDelays, - - #[error( - "An invalid output format was specified ({0}). Supported:\n{}", - *VIS_OUTPUT_EXTENSIONS, - )] - InvalidOutputFormat(PathBuf), - - #[error("Array position specified as {pos:?}, not [, , ]")] - BadArrayPosition { pos: Vec }, - - #[error("After vetoing sources, none were left. Decrease the veto threshold, or supply more sources")] - NoSourcesAfterVeto, - - #[error(transparent)] - FileWrite(#[from] crate::io::write::FileWriteError), - - #[error(transparent)] - AverageFactor(#[from] crate::averaging::AverageFactorError), - - #[error(transparent)] - SourceList(#[from] crate::srclist::ReadSourceListError), - - #[error(transparent)] - Veto(#[from] crate::srclist::VetoError), - - #[error(transparent)] - Beam(#[from] crate::beam::BeamError), - - #[error(transparent)] - VisWrite(#[from] crate::io::write::VisWriteError), - - #[error(transparent)] - Glob(#[from] crate::io::GlobError), - - #[error(transparent)] - Mwalib(#[from] mwalib::MwalibError), - - #[error(transparent)] - Model(#[from] crate::model::ModelError), - - #[error(transparent)] - IO(#[from] std::io::Error), - - #[cfg(feature = "cuda")] - #[error(transparent)] - Cuda(#[from] crate::cuda::CudaError), -} diff --git a/src/cli/vis_utils/subtract/error.rs b/src/cli/vis_utils/subtract/error.rs deleted file mode 100644 index 56a784bd..00000000 --- a/src/cli/vis_utils/subtract/error.rs +++ /dev/null @@ -1,127 +0,0 @@ -// This Source Code Form is subject to the terms of the Mozilla Public -// License, v. 2.0. If a copy of the MPL was not distributed with this -// file, You can obtain one at http://mozilla.org/MPL/2.0/. - -//! Error type for all errors related to vis-subtract. - -use std::path::PathBuf; - -use thiserror::Error; -use vec1::Vec1; - -use crate::{ - filenames::SUPPORTED_CALIBRATED_INPUT_FILE_COMBINATIONS, io::write::VIS_OUTPUT_EXTENSIONS, -}; - -#[derive(Error, Debug)] -pub(crate) enum VisSubtractError { - #[error("Specified source {name} is not in the input source list; can't subtract it")] - MissingSource { name: String }, - - #[error("No sources were specified for subtraction. 
Did you want to subtract all sources? See the \"invert\" option.")] - NoSources, - - #[error("No sources were left after removing specified sources from the source list.")] - AllSourcesFiltered, - - #[error("After vetoing sources, none were left. Decrease the veto threshold, or supply more sources")] - NoSourcesAfterVeto, - - #[error("Tried to create a beam object, but MWA dipole delay information isn't available!")] - NoDelays, - - #[error( - "The specified MWA dipole delays aren't valid; there should be 16 values between 0 and 32" - )] - BadDelays, - - #[error("No input data was given!")] - NoInputData, - - #[error( - "{0}\n\nSupported combinations of file formats:\n{SUPPORTED_CALIBRATED_INPUT_FILE_COMBINATIONS}", - )] - InvalidDataInput(&'static str), - - #[error("The data either contains no timesteps or no timesteps are being used")] - NoTimesteps, - - #[error("Duplicate timesteps were specified; this is invalid")] - DuplicateTimesteps, - - #[error("Timestep {got} was specified but it isn't available; the last timestep is {last}")] - UnavailableTimestep { got: usize, last: usize }, - - #[error( - "An invalid output format was specified ({0}). Supported:\n{}", - *VIS_OUTPUT_EXTENSIONS, - )] - InvalidOutputFormat(PathBuf), - - #[error("Error when parsing output vis. time average factor: {0}")] - ParseOutputVisTimeAverageFactor(crate::unit_parsing::UnitParseError), - - #[error("Error when parsing output vis. freq. average factor: {0}")] - ParseOutputVisFreqAverageFactor(crate::unit_parsing::UnitParseError), - - #[error("Output vis. time average factor isn't an integer")] - OutputVisTimeFactorNotInteger, - - #[error("Output vis. freq. average factor isn't an integer")] - OutputVisFreqFactorNotInteger, - - #[error("Output vis. time average factor cannot be 0")] - OutputVisTimeAverageFactorZero, - - #[error("Output vis. freq. average factor cannot be 0")] - OutputVisFreqAverageFactorZero, - - #[error("Output vis. time resolution isn't a multiple of input data's: {out} seconds vs {inp} seconds")] - OutputVisTimeResNotMultiple { out: f64, inp: f64 }, - - #[error("Output vis. freq. 
resolution isn't a multiple of input data's: {out} Hz vs {inp} Hz")] - OutputVisFreqResNotMultiple { out: f64, inp: f64 }, - - #[error("Multiple metafits files were specified: {0:?}\nThis is unsupported.")] - MultipleMetafits(Vec1), - - #[error("Multiple measurement sets were specified: {0:?}\nThis is unsupported.")] - MultipleMeasurementSets(Vec1), - - #[error("Multiple uvfits files were specified: {0:?}\nThis is unsupported.")] - MultipleUvfits(Vec1), - - #[error("Array position specified as {pos:?}, not [, , ]")] - BadArrayPosition { pos: Vec }, - - #[error(transparent)] - Veto(#[from] crate::srclist::VetoError), - - #[error(transparent)] - VisRead(#[from] crate::io::read::VisReadError), - - #[error(transparent)] - Glob(#[from] crate::io::GlobError), - - #[error(transparent)] - VisWrite(#[from] crate::io::write::VisWriteError), - - #[error(transparent)] - FileWrite(#[from] crate::io::write::FileWriteError), - - #[error(transparent)] - SourceList(#[from] crate::srclist::ReadSourceListError), - - #[error(transparent)] - Beam(#[from] crate::beam::BeamError), - - #[error(transparent)] - Model(#[from] crate::model::ModelError), - - #[error(transparent)] - IO(#[from] std::io::Error), - - #[cfg(feature = "cuda")] - #[error(transparent)] - Cuda(#[from] crate::cuda::CudaError), -} diff --git a/src/cli/vis_utils/subtract/mod.rs b/src/cli/vis_utils/subtract/mod.rs deleted file mode 100644 index 1f67dfd9..00000000 --- a/src/cli/vis_utils/subtract/mod.rs +++ /dev/null @@ -1,1015 +0,0 @@ -// This Source Code Form is subject to the terms of the Mozilla Public -// License, v. 2.0. If a copy of the MPL was not distributed with this -// file, You can obtain one at http://mozilla.org/MPL/2.0/. - -//! Given input data, a sky model and specific sources, subtract those specific -//! sources from the input data and write them out. - -mod error; - -pub(crate) use error::VisSubtractError; -#[cfg(test)] -mod tests; - -use std::{ - collections::HashSet, - path::{Path, PathBuf}, - str::FromStr, - thread, -}; - -use clap::Parser; -use crossbeam_channel::{bounded, Receiver, Sender}; -use crossbeam_utils::atomic::AtomicCell; -use hifitime::Duration; -use indicatif::{MultiProgress, ProgressBar, ProgressDrawTarget, ProgressStyle}; -use itertools::Itertools; -use log::{debug, info, warn}; -use marlu::{precession::precess_time, Jones, LatLngHeight}; -use ndarray::{prelude::*, ArcArray2}; -use scopeguard::defer_on_unwind; -use vec1::{vec1, Vec1}; - -use crate::{ - averaging::{ - parse_freq_average_factor, parse_time_average_factor, timesteps_to_timeblocks, - AverageFactorError, - }, - beam::{create_fee_beam_object, create_no_beam_object, Beam, Delays}, - constants::{DEFAULT_CUTOFF_DISTANCE, DEFAULT_VETO_THRESHOLD}, - context::ObsContext, - filenames::InputDataTypes, - help_texts::{ - ARRAY_POSITION_HELP, DIPOLE_DELAYS_HELP, MS_DATA_COL_NAME_HELP, - SOURCE_DIST_CUTOFF_HELP as sdc_help, SOURCE_LIST_TYPE_HELP, VETO_THRESHOLD_HELP as vt_help, - }, - io::{ - get_single_match_from_glob, - read::{MsReader, UvfitsReader, VisInputType, VisRead, VisReadError}, - write::{can_write_to_file, write_vis, VisOutputType, VisTimestep, VIS_OUTPUT_EXTENSIONS}, - }, - math::TileBaselineFlags, - messages, - model::ModellerInfo, - srclist::{read::read_source_list_file, veto_sources, SourceList, SourceListType}, - HyperdriveError, -}; - -pub(crate) const DEFAULT_OUTPUT_VIS_FILENAME: &str = "hyp_subtracted.uvfits"; - -lazy_static::lazy_static! { - static ref OUTPUTS_HELP: String = - format!("Paths to the output visibility files. 
Supported formats: {}. Default: {}", *VIS_OUTPUT_EXTENSIONS, DEFAULT_OUTPUT_VIS_FILENAME); - - static ref SOURCE_DIST_CUTOFF_HELP: String = - format!("{}. Only useful if subtraction is inverted.", *sdc_help); - - static ref VETO_THRESHOLD_HELP: String = - format!("{}. Only useful if subtraction is inverted.", *vt_help); -} - -#[derive(Parser, Debug, Default)] -pub struct VisSubtractArgs { - /// Paths to the input data files to have visibilities subtracted. These can - /// include a metafits file, a measurement set and/or uvfits files. - #[clap(short, long, multiple_values(true), help_heading = "INPUT FILES")] - data: Vec, - - /// Path to the sky-model source list used for simulation. - #[clap(short, long, help_heading = "INPUT FILES")] - source_list: String, - - #[clap(long, help = SOURCE_LIST_TYPE_HELP.as_str(), help_heading = "INPUT FILES")] - source_list_type: Option, - - /// The timesteps to use from the input data. The default is to use all - /// timesteps, including flagged ones. - #[clap(long, multiple_values(true), help_heading = "INPUT FILES")] - timesteps: Option>, - - #[clap(long, help = MS_DATA_COL_NAME_HELP, help_heading = "INPUT FILES")] - ms_data_column_name: Option, - - /// Use a DUT1 value of 0 seconds rather than what is in the input data. - #[clap(long, help_heading = "INPUT FILES")] - ignore_dut1: bool, - - #[clap( - short = 'o', - long, - multiple_values(true), - help = OUTPUTS_HELP.as_str(), - help_heading = "OUTPUT FILES" - )] - outputs: Vec, - - /// When writing out subtracted visibilities, average this many timesteps - /// together. Also supports a target time resolution (e.g. 8s). The value - /// must be a multiple of the input data's time resolution. The default is - /// to preserve the input data's time resolution. e.g. If the input data is - /// in 0.5s resolution and this variable is 4, then we average 2s worth of - /// subtracted data together before writing the data out. If the variable is - /// instead 4s, then 8 subtracted timesteps are averaged together before - /// writing the data out. - #[clap(long, help_heading = "OUTPUT FILES")] - time_average: Option, - - /// When writing out subtracted visibilities, average this many fine freq. - /// channels together. Also supports a target freq. resolution (e.g. 80kHz). - /// The value must be a multiple of the input data's freq. resolution. The - /// default is to preserve the input data's freq. resolution. e.g. If the - /// input data is in 40kHz resolution and this variable is 4, then we - /// average 160kHz worth of subtracted data together before writing the data - /// out. If the variable is instead 80kHz, then 2 subtracted fine freq. - /// channels are averaged together before writing the data out. - #[clap(long, help_heading = "OUTPUT FILES")] - freq_average: Option, - - /// The names of the sources in the sky-model source list that will be - /// subtracted from the input data. - #[clap(long, multiple_values(true), help_heading = "SKY-MODEL SOURCES")] - sources_to_subtract: Vec, - - /// Invert the subtraction; sources *not* specified in sources-to-subtract - /// will be subtracted from the input data. - #[clap(short, long, help_heading = "SKY-MODEL SOURCES")] - invert: bool, - - /// The number of sources to use in the source list. Only useful if - /// subtraction is inverted. The default is to use all sources in the source - /// list. Example: If 1000 sources are specified here, then the top 1000 - /// sources *after* removing specified sources are subtracted. 
Standard veto - /// rules apply (sources are ranked based on their flux densities after the - /// beam attenuation, must be within the specified source distance cutoff - /// and above the horizon). - #[clap(short, long, help_heading = "SKY-MODEL SOURCES")] - num_sources: Option, - - #[clap(long, help = SOURCE_DIST_CUTOFF_HELP.as_str(), help_heading = "SKY-MODEL SOURCES")] - source_dist_cutoff: Option, - - #[clap(long, help = VETO_THRESHOLD_HELP.as_str(), help_heading = "SKY-MODEL SOURCES")] - veto_threshold: Option, - - /// Should we use a beam? Default is to use the FEE beam. - #[clap(long, help_heading = "MODEL PARAMETERS")] - no_beam: bool, - - /// The path to the HDF5 MWA FEE beam file. If not specified, this must be - /// provided by the MWA_BEAM_FILE environment variable. - #[clap(long, help_heading = "MODEL PARAMETERS")] - beam_file: Option, - - /// Pretend that all MWA dipoles are alive and well, ignoring whatever is in - /// the metafits file. - #[clap(long, help_heading = "MODEL PARAMETERS")] - unity_dipole_gains: bool, - - #[clap(long, multiple_values(true), help = DIPOLE_DELAYS_HELP.as_str(), help_heading = "MODEL PARAMETERS")] - delays: Option>, - - #[clap( - long, help = ARRAY_POSITION_HELP.as_str(), help_heading = "MODEL PARAMETERS", - number_of_values = 3, - allow_hyphen_values = true, - value_names = &["LONG_DEG", "LAT_DEG", "HEIGHT_M"] - )] - array_position: Option>, - - /// If specified, don't precess the array to J2000. We assume that sky-model - /// sources are specified in the J2000 epoch. - #[clap(long, help_heading = "MODEL PARAMETERS")] - no_precession: bool, - - /// Use the CPU for visibility generation. This is deliberately made - /// non-default because using a GPU is much faster. - #[cfg(feature = "cuda")] - #[clap(long, help_heading = "MODEL PARAMETERS")] - cpu: bool, - - /// Don't draw progress bars. - #[clap(long, help_heading = "USER INTERFACE")] - no_progress_bars: bool, -} - -impl VisSubtractArgs { - pub fn run(self, dry_run: bool) -> Result<(), HyperdriveError> { - vis_subtract(self, dry_run)?; - Ok(()) - } -} - -fn vis_subtract(args: VisSubtractArgs, dry_run: bool) -> Result<(), VisSubtractError> { - debug!("{:#?}", args); - - // Expose all the struct fields to ensure they're all used. - let VisSubtractArgs { - data, - source_list, - source_list_type, - timesteps, - ms_data_column_name, - ignore_dut1, - outputs, - time_average, - freq_average, - sources_to_subtract, - invert, - num_sources, - source_dist_cutoff, - veto_threshold, - no_beam, - beam_file, - unity_dipole_gains, - delays, - array_position, - no_precession, - #[cfg(feature = "cuda")] - cpu: use_cpu_for_modelling, - no_progress_bars, - } = args; - - // If we're going to use a GPU for modelling, get the device info so we - // can ensure a CUDA-capable device is available, and so we can report - // it to the user later. - #[cfg(feature = "cuda")] - let modeller_info = if use_cpu_for_modelling { - ModellerInfo::Cpu - } else { - let (device_info, driver_info) = crate::cuda::get_device_info()?; - ModellerInfo::Cuda { - device_info, - driver_info, - } - }; - #[cfg(not(feature = "cuda"))] - let modeller_info = ModellerInfo::Cpu; - - // If we're not inverted but `sources_to_subtract` is empty, then there's - // nothing to do. - if !invert && sources_to_subtract.is_empty() { - return Err(VisSubtractError::NoSources); - } - - // Read in the source list and remove all but the specified sources. 
- let source_list: SourceList = { - // If the specified source list file can't be found, treat it as a glob - // and expand it to find a match. - let pb = PathBuf::from(&source_list); - let pb = if pb.exists() { - pb - } else { - get_single_match_from_glob(&source_list)? - }; - - // Read the source list file. If the type was manually specified, - // use that, otherwise the reading code will try all available - // kinds. - let sl_type = source_list_type - .as_ref() - .and_then(|t| SourceListType::from_str(t.as_ref()).ok()); - let (sl, _) = match crate::misc::expensive_op( - || read_source_list_file(pb, sl_type), - "Still reading source list file", - ) { - Ok((sl, sl_type)) => (sl, sl_type), - Err(e) => return Err(VisSubtractError::from(e)), - }; - - sl - }; - debug!("Found {} sources in the source list", source_list.len()); - - // Ensure that all specified sources are actually in the source list. - for name in &sources_to_subtract { - if !source_list.contains_key(name) { - return Err(VisSubtractError::MissingSource { - name: name.to_owned(), - }); - } - } - - // If the user supplied the array position, unpack it here. - let array_position = match array_position { - Some(pos) => { - if pos.len() != 3 { - return Err(VisSubtractError::BadArrayPosition { pos }); - } - Some(LatLngHeight { - longitude_rad: pos[0].to_radians(), - latitude_rad: pos[1].to_radians(), - height_metres: pos[2], - }) - } - None => None, - }; - - // Prepare an input data reader. - let input_data_types = InputDataTypes::new(&data)?; - let input_data: Box = match ( - input_data_types.metafits, - input_data_types.gpuboxes, - input_data_types.mwafs, - input_data_types.ms, - input_data_types.uvfits, - ) { - // Valid input for reading a measurement set. - (meta, None, None, Some(ms), None) => { - // Only one MS is supported at the moment. - let ms: PathBuf = if ms.len() > 1 { - return Err(VisSubtractError::MultipleMeasurementSets(ms)); - } else { - ms.first().clone() - }; - - // Ensure that there's only one metafits. - let meta: Option<&Path> = match meta.as_ref() { - None => None, - Some(m) => { - if m.len() > 1 { - return Err(VisSubtractError::MultipleMetafits(m.clone())); - } else { - Some(m.first().as_path()) - } - } - }; - - let input_data = MsReader::new(ms, ms_data_column_name, meta, array_position) - .map_err(VisReadError::from)?; - match input_data.get_obs_context().obsid { - Some(o) => info!( - "Reading obsid {} from measurement set {}", - o, - input_data.ms.canonicalize()?.display() - ), - None => info!( - "Reading measurement set {}", - input_data.ms.canonicalize()?.display() - ), - } - Box::new(input_data) - } - - // Valid input for reading uvfits files. - (meta, None, None, None, Some(uvfits)) => { - // Only one uvfits is supported at the moment. - let uvfits: PathBuf = if uvfits.len() > 1 { - return Err(VisSubtractError::MultipleUvfits(uvfits)); - } else { - uvfits.first().clone() - }; - - // Ensure that there's only one metafits. 
- let meta: Option<&Path> = match meta.as_ref() { - None => None, - Some(m) => { - if m.len() > 1 { - return Err(VisSubtractError::MultipleMetafits(m.clone())); - } else { - Some(m.first()) - } - } - }; - - let input_data = - UvfitsReader::new(uvfits, meta, array_position).map_err(VisReadError::from)?; - match input_data.get_obs_context().obsid { - Some(o) => info!( - "Reading obsid {} from uvfits {}", - o, - input_data.uvfits.canonicalize()?.display() - ), - None => info!( - "Reading uvfits {}", - input_data.uvfits.canonicalize()?.display() - ), - } - Box::new(input_data) - } - - // The following matches are for invalid combinations of input - // files. Make an error message for the user. - (_, Some(_), _, _, _) => { - let msg = "Received gpubox files, but these are not supported by vis-subtract."; - return Err(VisSubtractError::InvalidDataInput(msg)); - } - (_, _, Some(_), _, _) => { - let msg = "Received mwaf files, but these are not supported by vis-subtract."; - return Err(VisSubtractError::InvalidDataInput(msg)); - } - (Some(_), None, None, None, None) => { - let msg = "Received only a metafits file; a calibrated uvfits file or calibrated measurement set is required."; - return Err(VisSubtractError::InvalidDataInput(msg)); - } - (_, _, _, Some(_), Some(_)) => { - let msg = "Received uvfits and measurement set files; this is not supported."; - return Err(VisSubtractError::InvalidDataInput(msg)); - } - (None, None, None, None, None) => return Err(VisSubtractError::NoInputData), - }; - - let obs_context = input_data.get_obs_context(); - let total_num_tiles = obs_context.get_total_num_tiles(); - let num_unflagged_tiles = obs_context.get_num_unflagged_tiles(); - let num_unflagged_cross_baselines = (num_unflagged_tiles * (num_unflagged_tiles - 1)) / 2; - let flagged_tiles = obs_context - .get_tile_flags(false, None) - .expect("can't fail; no additional flags"); - let tile_baseline_flags = TileBaselineFlags::new(total_num_tiles, flagged_tiles); - let vis_shape = ( - obs_context.fine_chan_freqs.len(), - num_unflagged_cross_baselines, - ); - - // Set up the beam for modelling. - let dipole_delays = match delays { - // We have user-provided delays; check that they're are sensible, - // regardless of whether we actually need them. - Some(d) => { - if d.len() != 16 || d.iter().any(|&v| v > 32) { - return Err(VisSubtractError::BadDelays); - } - Some(Delays::Partial(d)) - } - - // No delays were provided; use whatever was in the input data. - None => obs_context.dipole_delays.clone(), - }; - - let beam: Box = if no_beam { - create_no_beam_object(obs_context.tile_xyzs.len()) - } else { - let mut dipole_delays = dipole_delays.ok_or(VisSubtractError::NoDelays)?; - let dipole_gains = if unity_dipole_gains { - None - } else { - // If we don't have dipole gains from the input data, then - // we issue a warning that we must assume no dead dipoles. 
- if obs_context.dipole_gains.is_none() { - match input_data.get_input_data_type() { - VisInputType::MeasurementSet => { - warn!("Measurement sets cannot supply dead dipole information."); - warn!("Without a metafits file, we must assume all dipoles are alive."); - warn!("This will make beam Jones matrices inaccurate in sky-model generation."); - } - VisInputType::Uvfits => { - warn!("uvfits files cannot supply dead dipole information."); - warn!("Without a metafits file, we must assume all dipoles are alive."); - warn!("This will make beam Jones matrices inaccurate in sky-model generation."); - } - VisInputType::Raw => unreachable!(), - } - } - obs_context.dipole_gains.clone() - }; - if dipole_gains.is_none() { - // If we don't have dipole gains, we must assume all dipoles are - // "alive". But, if any dipole delays are 32, then the beam code - // will still ignore those dipoles. So use ideal dipole delays for - // all tiles. - let ideal_delays = dipole_delays.get_ideal_delays(); - - // Warn the user if they wanted unity dipole gains but the ideal - // dipole delays contain 32. - if unity_dipole_gains && ideal_delays.iter().any(|&v| v == 32) { - warn!("Some ideal dipole delays are 32; these dipoles will not have unity gains"); - } - dipole_delays.set_to_ideal_delays(); - } - - create_fee_beam_object( - beam_file.as_deref(), - total_num_tiles, - dipole_delays, - dipole_gains, - )? - }; - let beam_file = beam.get_beam_file(); - debug!("Beam file: {beam_file:?}"); - - let array_position = obs_context.array_position; - - let timesteps = match timesteps { - None => Vec1::try_from(obs_context.all_timesteps.as_slice()), - Some(mut ts) => { - // Make sure there are no duplicates. - let timesteps_hashset: HashSet<&usize> = ts.iter().collect(); - if timesteps_hashset.len() != ts.len() { - return Err(VisSubtractError::DuplicateTimesteps); - } - - // Ensure that all specified timesteps are actually available. 
- for &t in &ts { - if obs_context.timestamps.get(t).is_none() { - return Err(VisSubtractError::UnavailableTimestep { - got: t, - last: obs_context.timestamps.len() - 1, - }); - } - } - - ts.sort_unstable(); - Vec1::try_from_vec(ts) - } - } - .map_err(|_| VisSubtractError::NoTimesteps)?; - - let dut1 = if ignore_dut1 { None } else { obs_context.dut1 }; - - let precession_info = precess_time( - array_position.longitude_rad, - array_position.latitude_rad, - obs_context.phase_centre, - obs_context.timestamps[*timesteps.first()], - dut1.unwrap_or_else(|| Duration::from_seconds(0.0)), - ); - let (lmst, latitude) = if no_precession { - (precession_info.lmst, array_position.latitude_rad) - } else { - ( - precession_info.lmst_j2000, - precession_info.array_latitude_j2000, - ) - }; - - messages::ArrayDetails { - array_position: Some(array_position), - array_latitude_j2000: if no_precession { - None - } else { - Some(precession_info.array_latitude_j2000) - }, - total_num_tiles, - num_unflagged_tiles, - flagged_tiles: &tile_baseline_flags - .flagged_tiles - .iter() - .cloned() - .sorted() - .map(|i| (obs_context.tile_names[i].as_str(), i)) - .collect::>(), - } - .print(); - - let time_res = obs_context.guess_time_res(); - let freq_res = obs_context.guess_freq_res(); - - messages::ObservationDetails { - dipole_delays: beam.get_ideal_dipole_delays(), - beam_file, - num_tiles_with_dead_dipoles: if unity_dipole_gains { - None - } else { - obs_context.dipole_gains.as_ref().map(|array| { - array - .outer_iter() - .filter(|tile_dipole_gains| { - tile_dipole_gains.iter().any(|g| g.abs() < f64::EPSILON) - }) - .count() - }) - }, - phase_centre: obs_context.phase_centre, - pointing_centre: None, - dut1, - lmst: Some(precession_info.lmst), - lmst_j2000: if no_precession { - None - } else { - Some(precession_info.lmst_j2000) - }, - available_timesteps: Some(obs_context.all_timesteps.as_slice()), - unflagged_timesteps: Some(obs_context.unflagged_timesteps.as_slice()), - using_timesteps: Some(timesteps.as_slice()), - first_timestamp: Some(obs_context.timestamps[*timesteps.first()]), - last_timestamp: Some(obs_context.timestamps[*timesteps.last()]), - time_res: Some(time_res), - total_num_channels: obs_context.fine_chan_freqs.len(), - num_unflagged_channels: None, - flagged_chans_per_coarse_chan: None, - first_freq_hz: Some(*obs_context.fine_chan_freqs.first() as f64), - last_freq_hz: Some(*obs_context.fine_chan_freqs.last() as f64), - first_unflagged_freq_hz: None, - last_unflagged_freq_hz: None, - freq_res_hz: Some(freq_res), - } - .print(); - - // Handle the invert option. - let source_list: SourceList = if invert { - let mut sl: SourceList = source_list - .into_iter() - .filter(|(name, _)| !sources_to_subtract.contains(name)) - .collect(); - if sl.is_empty() { - // Nothing to do. 
- return Err(VisSubtractError::AllSourcesFiltered); - } - veto_sources( - &mut sl, - obs_context.phase_centre, - lmst, - latitude, - &obs_context.get_veto_freqs(), - &*beam, - num_sources, - source_dist_cutoff.unwrap_or(DEFAULT_CUTOFF_DISTANCE), - veto_threshold.unwrap_or(DEFAULT_VETO_THRESHOLD), - )?; - if sl.is_empty() { - return Err(VisSubtractError::NoSourcesAfterVeto); - } - info!("Subtracting {} sources", sl.len()); - sl - } else { - let sl = source_list - .into_iter() - .filter(|(name, _)| sources_to_subtract.contains(name)) - .collect(); - info!( - "Subtracting {} specified sources", - sources_to_subtract.len() - ); - sl - }; - - messages::SkyModelDetails { - source_list: &source_list, - } - .print(); - - messages::print_modeller_info(&modeller_info); - - // Handle output visibility arguments. - let (time_average_factor, freq_average_factor) = { - // Parse and verify user input (specified resolutions must - // evenly divide the input data's resolutions). - let time_factor = parse_time_average_factor( - obs_context.time_res, - time_average.as_deref(), - 1, - ) - .map_err(|e| match e { - AverageFactorError::Zero => VisSubtractError::OutputVisTimeAverageFactorZero, - AverageFactorError::NotInteger => VisSubtractError::OutputVisTimeFactorNotInteger, - AverageFactorError::NotIntegerMultiple { out, inp } => { - VisSubtractError::OutputVisTimeResNotMultiple { out, inp } - } - AverageFactorError::Parse(e) => VisSubtractError::ParseOutputVisTimeAverageFactor(e), - })?; - let freq_factor = parse_freq_average_factor( - obs_context.freq_res, - freq_average.as_deref(), - 1, - ) - .map_err(|e| match e { - AverageFactorError::Zero => VisSubtractError::OutputVisFreqAverageFactorZero, - AverageFactorError::NotInteger => VisSubtractError::OutputVisFreqFactorNotInteger, - AverageFactorError::NotIntegerMultiple { out, inp } => { - VisSubtractError::OutputVisFreqResNotMultiple { out, inp } - } - AverageFactorError::Parse(e) => VisSubtractError::ParseOutputVisFreqAverageFactor(e), - })?; - - (time_factor, freq_factor) - }; - - let outputs = { - if outputs.is_empty() { - vec1![( - PathBuf::from(DEFAULT_OUTPUT_VIS_FILENAME), - VisOutputType::Uvfits - )] - } else { - let mut valid_outputs = Vec::with_capacity(outputs.len()); - for file in outputs { - // Is the output file type supported? - let ext = file.extension().and_then(|os_str| os_str.to_str()); - match ext.and_then(|s| VisOutputType::from_str(s).ok()) { - Some(t) => { - can_write_to_file(&file)?; - valid_outputs.push((file.to_owned(), t)); - } - None => return Err(VisSubtractError::InvalidOutputFormat(file.clone())), - } - } - Vec1::try_from_vec(valid_outputs).unwrap() - } - }; - - messages::OutputFileDetails { - output_solutions: &[], - vis_type: "subtracted", - output_vis: Some(&outputs), - input_vis_time_res: Some(time_res), - input_vis_freq_res: Some(freq_res), - output_vis_time_average_factor: time_average_factor, - output_vis_freq_average_factor: freq_average_factor, - } - .print(); - - let timeblocks = - timesteps_to_timeblocks(&obs_context.timestamps, time_average_factor, ×teps); - - if dry_run { - info!("Dry run -- exiting now."); - return Ok(()); - } - - // Channel for modelling and subtracting. - let (tx_model, rx_model) = bounded(5); - // Channel for writing subtracted visibilities. - let (tx_write, rx_write) = bounded(5); - - // Progress bars. 
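(Editorial note: the `parse_time_average_factor` / `parse_freq_average_factor` calls a little way above implement the averaging rule spelled out in the option help text: the argument is either a bare integer factor, or a target resolution that must be a positive integer multiple of the input resolution. The sketch below is a simplified stand-alone version of that rule for the time axis only; the real parsers handle unit suffixes and dedicated error variants.)

```rust
/// Turn "4" or "8s" into an averaging factor, given the input time
/// resolution in seconds. Simplified for illustration.
fn time_average_factor(input_res_s: f64, arg: &str) -> Result<usize, String> {
    let arg = arg.trim();

    // A bare integer is already an averaging factor.
    if let Ok(factor) = arg.parse::<usize>() {
        return if factor == 0 {
            Err("averaging factor cannot be 0".to_string())
        } else {
            Ok(factor)
        };
    }

    // Otherwise expect a target resolution like "8s".
    let target_s: f64 = arg
        .strip_suffix('s')
        .ok_or_else(|| format!("'{arg}' is neither an integer nor a time like \"8s\""))?
        .trim()
        .parse()
        .map_err(|e| format!("couldn't parse '{arg}': {e}"))?;
    let factor = target_s / input_res_s;
    if factor < 1.0 || (factor - factor.round()).abs() > 1e-6 {
        return Err(format!(
            "{target_s} s isn't an integer multiple of the input resolution ({input_res_s} s)"
        ));
    }
    Ok(factor.round() as usize)
}
```

For example, with 0.5 s input data, both `time_average_factor(0.5, "4")` and `time_average_factor(0.5, "2s")` would give a factor of 4, matching the examples in the help text.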
- let multi_progress = MultiProgress::with_draw_target(if no_progress_bars { - ProgressDrawTarget::hidden() - } else { - ProgressDrawTarget::stdout() - }); - let read_progress = multi_progress.add( - ProgressBar::new(timesteps.len() as _) - .with_style( - ProgressStyle::default_bar() - .template("{msg:17}: [{wide_bar:.blue}] {pos:2}/{len:2} timesteps ({elapsed_precise}<{eta_precise})").unwrap() - .progress_chars("=> "), - ) - .with_position(0) - .with_message("Reading data"), -); - let model_progress = multi_progress.add( - ProgressBar::new(timesteps.len() as _) - .with_style( - ProgressStyle::default_bar() - .template("{msg:17}: [{wide_bar:.blue}] {pos:2}/{len:2} timesteps ({elapsed_precise}<{eta_precise})").unwrap() - .progress_chars("=> "), - ) - .with_position(0) - .with_message("Sky modelling"), -); - let write_progress = multi_progress.add( - ProgressBar::new(timeblocks.len() as _) - .with_style( - ProgressStyle::default_bar() - .template("{msg:17}: [{wide_bar:.blue}] {pos:2}/{len:2} timeblocks ({elapsed_precise}<{eta_precise})").unwrap() - .progress_chars("=> "), - ) - .with_position(0) - .with_message("Subtracted writing"), - ); - - // Use a variable to track whether any threads have an issue. - let error = AtomicCell::new(false); - - info!("Reading input data, sky modelling, and writing"); - let scoped_threads_result = thread::scope(|s| { - // Input visibility-data reading thread. - let data_handle = s.spawn(|| { - // If a panic happens, update our atomic error. - defer_on_unwind! { error.store(true); } - read_progress.tick(); - - let result = read_vis( - obs_context, - &tile_baseline_flags, - &*input_data, - ×teps, - vis_shape, - tx_model, - &error, - read_progress, - ); - // If the result of reading data was an error, allow the other - // threads to see this so they can abandon their work early. - if result.is_err() { - error.store(true); - } - result - }); - - // Sky-model generation and subtraction thread. - let model_handle = s.spawn(|| { - defer_on_unwind! { error.store(true); } - model_progress.tick(); - - let result = model_vis_and_subtract( - &*beam, - &source_list, - obs_context, - array_position, - vis_shape, - dut1.unwrap_or_else(|| Duration::from_seconds(0.0)), - !no_precession, - rx_model, - tx_write, - &error, - model_progress, - #[cfg(feature = "cuda")] - use_cpu_for_modelling, - ); - if result.is_err() { - error.store(true); - } - result - }); - - // Subtracted vis writing thread. - let write_handle = s.spawn(|| { - defer_on_unwind! { error.store(true); } - write_progress.tick(); - - let result = write_vis( - &outputs, - array_position, - obs_context.phase_centre, - obs_context.pointing_centre, - &obs_context.tile_xyzs, - &obs_context.tile_names, - obs_context.obsid, - &obs_context.timestamps, - ×teps, - &timeblocks, - time_res, - dut1.unwrap_or_else(|| Duration::from_seconds(0.0)), - freq_res, - &obs_context.fine_chan_freqs.mapped_ref(|&f| f as f64), - &tile_baseline_flags - .unflagged_cross_baseline_to_tile_map - .values() - .copied() - .sorted() - .collect::>(), - // TODO: Provide CLI options - &HashSet::new(), - time_average_factor, - freq_average_factor, - input_data.get_marlu_mwa_info().as_ref(), - rx_write, - &error, - Some(write_progress), - ); - if result.is_err() { - error.store(true); - } - result - }); - - // Join all thread handles. This propagates any errors and lets us know - // if any threads panicked, if panics aren't aborting as per the - // Cargo.toml. 
(It would be nice to capture the panic information, if - // it's possible, but I don't know how, so panics are currently - // aborting.) - let result = data_handle.join().unwrap(); - let result = result.and_then(|_| model_handle.join().unwrap()); - result.and_then(|_| write_handle.join().unwrap().map_err(VisSubtractError::from)) - }); - - // Propagate errors and print out the write message. - let s = scoped_threads_result?; - info!("{s}"); - - Ok(()) -} - -#[allow(clippy::too_many_arguments)] -fn read_vis( - obs_context: &ObsContext, - tile_baseline_flags: &TileBaselineFlags, - input_data: &dyn VisRead, - timesteps: &Vec1, - vis_shape: (usize, usize), - tx: Sender, - error: &AtomicCell, - progress_bar: ProgressBar, -) -> Result<(), VisSubtractError> { - let flagged_fine_chans = HashSet::new(); - - // Read data to fill the buffer, pausing when the buffer is full to - // write it all out. - for ×tep in timesteps { - let timestamp = obs_context.timestamps[timestep]; - debug!("Reading timestamp {}", timestamp.to_gpst_seconds()); - - let mut cross_data_fb: ArcArray2> = ArcArray2::zeros(vis_shape); - let mut cross_weights_fb: ArcArray2 = ArcArray2::zeros(vis_shape); - input_data.read_crosses( - cross_data_fb.view_mut(), - cross_weights_fb.view_mut(), - timestep, - tile_baseline_flags, - &flagged_fine_chans, - )?; - - // Should we continue? - if error.load() { - return Ok(()); - } - - match tx.send(VisTimestep { - cross_data_fb, - cross_weights_fb, - autos: None, - timestamp, - }) { - Ok(()) => (), - // If we can't send the message, it's because the channel - // has been closed on the other side. That should only - // happen because the writer has exited due to error; in - // that case, just exit this thread. - Err(_) => return Ok(()), - } - progress_bar.inc(1); - } - debug!("Finished reading"); - progress_bar.abandon_with_message("Finished reading visibilities"); - - Ok(()) -} - -#[allow(clippy::too_many_arguments)] -fn model_vis_and_subtract( - beam: &dyn Beam, - source_list: &SourceList, - obs_context: &ObsContext, - array_position: LatLngHeight, - vis_shape: (usize, usize), - dut1: Duration, - apply_precession: bool, - rx: Receiver, - tx: Sender, - error: &AtomicCell, - progress_bar: ProgressBar, - #[cfg(feature = "cuda")] use_cpu_for_modelling: bool, -) -> Result<(), VisSubtractError> { - let flagged_tiles = obs_context - .get_tile_flags(false, None) - .expect("can't fail; no additional flags"); - let unflagged_tile_xyzs = obs_context - .tile_xyzs - .iter() - .enumerate() - .filter(|(i, _)| !flagged_tiles.contains(i)) - .map(|(_, xyz)| *xyz) - .collect::>(); - let freqs = obs_context - .fine_chan_freqs - .iter() - .map(|&i| i as f64) - .collect::>(); - let modeller = crate::model::new_sky_modeller( - #[cfg(feature = "cuda")] - use_cpu_for_modelling, - beam, - source_list, - obs_context.polarisations, - &unflagged_tile_xyzs, - &freqs, - &flagged_tiles, - obs_context.phase_centre, - array_position.longitude_rad, - array_position.latitude_rad, - dut1, - apply_precession, - )?; - - // Recycle an array for model visibilities. - let mut cross_model_fb = Array2::zeros(vis_shape); - - // Iterate over the incoming data. 
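(Editorial note: the loop below does the per-visibility subtraction; pulled out on its own, the essential operation is a promote-subtract-demote, so the single-precision data and model are subtracted in double precision.)

```rust
use marlu::Jones;

/// Subtract a model visibility from a data visibility, doing the arithmetic
/// in f64 before storing the result back as f32. This mirrors the
/// element-wise operation in the loop below.
fn subtract_vis(data: Jones<f32>, model: Jones<f32>) -> Jones<f32> {
    Jones::<f32>::from(Jones::<f64>::from(data) - Jones::<f64>::from(model))
}
```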
- for VisTimestep { - mut cross_data_fb, - cross_weights_fb, - autos, - timestamp, - } in rx.iter() - { - debug!("Modelling timestamp {}", timestamp.to_gpst_seconds()); - modeller.model_timestep_with(timestamp, cross_model_fb.view_mut())?; - cross_data_fb - .iter_mut() - .zip(cross_model_fb.iter()) - .for_each(|(data, model)| { - *data = Jones::from(Jones::::from(*data) - Jones::::from(*model)); - }); - cross_model_fb.fill(Jones::default()); - - // Should we continue? - if error.load() { - return Ok(()); - } - - match tx.send(VisTimestep { - cross_data_fb, - cross_weights_fb, - autos, - timestamp, - }) { - Ok(()) => (), - Err(_) => return Ok(()), - } - progress_bar.inc(1); - } - debug!("Finished modelling"); - progress_bar.abandon_with_message("Finished subtracting sky model"); - Ok(()) -} diff --git a/src/constants.rs b/src/constants.rs index 26c367a8..db20d39a 100644 --- a/src/constants.rs +++ b/src/constants.rs @@ -27,3 +27,6 @@ pub(crate) const DEFAULT_CUTOFF_DISTANCE: f64 = 50.0; pub(crate) const SQRT_FRAC_PI_SQ_2_LN_2: f64 = 2.6682231283184983; pub(crate) use marlu::constants::*; + +/// The default column to use when reading visibilities from a measurement set. +pub(crate) const DEFAULT_MS_DATA_COL_NAME: &str = "DATA"; diff --git a/src/context/mod.rs b/src/context/mod.rs index d25dc7cb..cfc58e68 100644 --- a/src/context/mod.rs +++ b/src/context/mod.rs @@ -8,22 +8,17 @@ mod tests; use std::{ - collections::HashSet, fmt::{Display, Write}, - num::NonZeroUsize, + num::NonZeroU16, }; use hifitime::{Duration, Epoch}; use log::{debug, error, info, trace, warn}; -use marlu::{ - constants::{FREQ_WEIGHT_FACTOR, TIME_WEIGHT_FACTOR}, - LatLngHeight, RADec, XyzGeodetic, -}; +use marlu::{LatLngHeight, RADec, XyzGeodetic}; use ndarray::Array2; -use thiserror::Error; use vec1::Vec1; -use crate::beam::Delays; +use crate::{beam::Delays, io::read::VisInputType}; /// Currently supported polarisations. #[derive(Clone, Copy, Debug, PartialEq)] @@ -79,6 +74,9 @@ impl Polarisations { /// Tile information is ordered according to the "Antenna" column in HDU 1 of /// the observation's metafits file. pub(crate) struct ObsContext { + /// The format of the file containing the visibilities (e.g. uvfits). + pub(crate) input_data_type: VisInputType, + /// The observation ID, which is also the observation's scheduled start GPS /// time (but shouldn't be used for this purpose). pub(crate) obsid: Option, @@ -121,7 +119,7 @@ pub(crate) struct ObsContext { /// The Earth position of the instrumental array *that is described by the /// input data*. It is provided here to distinguish from `array_position`, /// which may be different. - pub(crate) _supplied_array_position: LatLngHeight, + pub(crate) supplied_array_position: LatLngHeight, /// The difference between UT1 and UTC. If this is 0 seconds, then LSTs are /// wrong by up to 0.9 seconds. The code will assume that 0 seconds means @@ -137,7 +135,8 @@ pub(crate) struct ObsContext { /// The [`XyzGeodetic`] coordinates of all tiles in the array (all /// coordinates are specified in \[metres\]). This includes flagged and - /// unavailable tiles. + /// unavailable tiles. The values described here may be affected by a + /// user-supplied `array_position`. pub(crate) tile_xyzs: Vec1, /// The flagged tiles, i.e. what the observation data suggests to be @@ -186,7 +185,7 @@ pub(crate) struct ObsContext { /// The number of fine-frequency channels per coarse channel. For 40 kHz /// legacy MWA data, this is 32. 
- pub(crate) num_fine_chans_per_coarse_chan: Option, + pub(crate) num_fine_chans_per_coarse_chan: Option, /// The fine-channel resolution of the supplied data \[Hz\]. This is not /// necessarily the fine-channel resolution of the original observation's @@ -202,11 +201,11 @@ pub(crate) struct ObsContext { /// The flagged fine channels for each baseline in the supplied data. Zero /// indexed. - pub(crate) flagged_fine_chans: Vec, + pub(crate) flagged_fine_chans: Vec, /// The fine channels per coarse channel already flagged in the supplied /// data. Zero indexed. - pub(crate) flagged_fine_chans_per_coarse_chan: Option>, + pub(crate) flagged_fine_chans_per_coarse_chan: Option>, /// The polarisations included in the data. Any combinations not listed are /// not supported. @@ -220,111 +219,6 @@ impl ObsContext { self.tile_xyzs.len() } - /// Get the number of unflagged tiles in the observation, i.e. total - - /// flagged. - pub(crate) fn get_num_unflagged_tiles(&self) -> usize { - self.get_total_num_tiles() - self.flagged_tiles.len() - } - - /// Attempt to get time resolution using heuristics if it is not present. - /// - /// If `time_res` is `None`, then attempt to determine it from the minimum - /// distance between timestamps. If there is no more than 1 timestamp, then - /// return 1s, since the time resolution of single-timestep observations is - /// not important anyway. - pub(crate) fn guess_time_res(&self) -> Duration { - match self.time_res { - Some(t) => t, - None => { - warn!("No integration time specified; assuming {TIME_WEIGHT_FACTOR} second"); - Duration::from_seconds(TIME_WEIGHT_FACTOR) - } - } - } - - pub(crate) fn guess_freq_res(&self) -> f64 { - match self.freq_res { - Some(f) => f, - None => { - warn!( - "No frequency resolution specified; assuming {} kHz", - FREQ_WEIGHT_FACTOR / 1e3 - ); - FREQ_WEIGHT_FACTOR - } - } - } - - /// Given whether to use the [ObsContext]'s tile flags and additional tile - /// flags (as strings or indices), return de-duplicated and sorted tile flag - /// indices. - pub(crate) fn get_tile_flags( - &self, - ignore_input_data_tile_flags: bool, - additional_flags: Option<&[String]>, - ) -> Result, InvalidTileFlag> { - let mut flagged_tiles = HashSet::new(); - - if !ignore_input_data_tile_flags { - // Add tiles that have already been flagged by the input data. - for &obs_tile_flag in &self.flagged_tiles { - flagged_tiles.insert(obs_tile_flag); - } - } - // Unavailable tiles must be regarded as flagged. - for i in &self.unavailable_tiles { - flagged_tiles.insert(*i); - } - - if let Some(flag_strings) = additional_flags { - // We need to convert the strings into antenna indices. The strings - // are either indices themselves or antenna names. - for flag_string in flag_strings { - // Try to parse a naked number. - let result = match flag_string.trim().parse().ok() { - Some(i) => { - let total_num_tiles = self.get_total_num_tiles(); - if i >= total_num_tiles { - Err(InvalidTileFlag::Index { - got: i, - max: total_num_tiles - 1, - }) - } else { - flagged_tiles.insert(i); - Ok(()) - } - } - None => { - // Check if this is an antenna name. - match self - .tile_names - .iter() - .enumerate() - .find(|(_, name)| name.to_lowercase() == flag_string.to_lowercase()) - { - // If there are no matches, complain that the user input - // is no good. 
- None => Err(InvalidTileFlag::BadTileFlag(flag_string.to_string())), - Some((i, _)) => { - flagged_tiles.insert(i); - Ok(()) - } - } - } - }; - if result.is_err() { - // If there's a problem, show all the tile names and their - // indices to help out the user. - self.print_info_tile_statuses(); - // Propagate the error. - result?; - } - } - } - - Ok(flagged_tiles) - } - /// Return all frequencies within the fine frequency channel range that are /// multiples of 1.28 MHz. pub(crate) fn get_veto_freqs(&self) -> Vec { @@ -340,7 +234,7 @@ impl ObsContext { /// Print information on the indices, names and statuses of all of the tiles /// in this observation at the indicated log level. - fn print_tile_statuses(&self, level: log::Level) { + pub(crate) fn print_tile_statuses(&self, level: log::Level) { let s = "All tile indices, names and default statuses:"; match level { log::Level::Error => error!("{}", s), @@ -372,25 +266,4 @@ impl ObsContext { s.clear(); }); } - - /// At info level, print information on the indices, names and statuses of - /// all of the tiles in this observation. - pub(crate) fn print_info_tile_statuses(&self) { - self.print_tile_statuses(log::Level::Info) - } - - /// At info level, print information on the indices, names and statuses of - /// all of the tiles in this observation. - pub(crate) fn print_debug_tile_statuses(&self) { - self.print_tile_statuses(log::Level::Debug) - } -} - -#[derive(Error, Debug)] -pub(crate) enum InvalidTileFlag { - #[error("Got a tile flag {got}, but the biggest possible antenna index is {max}")] - Index { got: usize, max: usize }, - - #[error("Bad flag value: '{0}' is neither an integer or an available antenna name. Run with extra verbosity to see all tile statuses.")] - BadTileFlag(String), } diff --git a/src/context/tests.rs b/src/context/tests.rs index e277f930..c7b2ac9c 100644 --- a/src/context/tests.rs +++ b/src/context/tests.rs @@ -2,15 +2,16 @@ // License, v. 2.0. If a copy of the MPL was not distributed with this // file, You can obtain one at http://mozilla.org/MPL/2.0/. 
-use hifitime::{Duration, Epoch}; +use hifitime::Epoch; use marlu::{LatLngHeight, RADec, XyzGeodetic}; use vec1::vec1; use super::Polarisations; -use crate::{beam::Delays, context::ObsContext}; +use crate::{beam::Delays, context::ObsContext, io::read::VisInputType}; fn get_minimal_obs_context() -> ObsContext { ObsContext { + input_data_type: VisInputType::Raw, obsid: None, timestamps: vec1![Epoch::from_gpst_seconds(1090008640.0)], all_timesteps: vec1![0], @@ -18,7 +19,7 @@ fn get_minimal_obs_context() -> ObsContext { phase_centre: RADec::default(), pointing_centre: Some(RADec::default()), array_position: LatLngHeight::mwa(), - _supplied_array_position: LatLngHeight::mwa(), + supplied_array_position: LatLngHeight::mwa(), dut1: None, tile_names: vec1!["Tile00".into()], tile_xyzs: vec1![XyzGeodetic::default()], @@ -38,44 +39,6 @@ fn get_minimal_obs_context() -> ObsContext { } } -#[test] -fn test_guess_time_res() { - let mut obs_ctx = get_minimal_obs_context(); - - // test fallback to 1s - obs_ctx.time_res = None; - obs_ctx.timestamps = vec1![Epoch::from_gpst_seconds(1090000000.0)]; - - assert_eq!(obs_ctx.guess_time_res(), Duration::from_seconds(1.)); - - // test use direct time_res over min_delta - obs_ctx.time_res = Some(Duration::from_seconds(2.)); - obs_ctx.timestamps = vec1![ - Epoch::from_gpst_seconds(1090000000.0), - Epoch::from_gpst_seconds(1090000010.0), - Epoch::from_gpst_seconds(1090000013.0), - ]; - - assert_eq!(obs_ctx.guess_time_res(), Duration::from_seconds(2.)); -} - -#[test] -fn test_guess_freq_res() { - let mut obs_ctx = get_minimal_obs_context(); - - // test fallback to 1s - obs_ctx.freq_res = None; - obs_ctx.fine_chan_freqs = vec1![128_000_000]; - - assert_eq!(obs_ctx.guess_freq_res(), 10_000.); - - // test use direct freq_res over min_delta - obs_ctx.freq_res = Some(30_000.); - obs_ctx.fine_chan_freqs = vec1![128_000_000, 128_100_000, 128_120_000]; - - assert_eq!(obs_ctx.guess_freq_res(), 30_000.); -} - #[test] fn test_veto_freqs() { let mut obs_ctx = get_minimal_obs_context(); diff --git a/src/cuda/mod.rs b/src/cuda/mod.rs index 9437df8f..b1d7f087 100644 --- a/src/cuda/mod.rs +++ b/src/cuda/mod.rs @@ -19,7 +19,7 @@ use std::{ use thiserror::Error; -pub(crate) use utils::{get_device_info, CudaDeviceInfo, CudaDriverInfo}; +pub(crate) use utils::get_device_info; // Import Rust bindings to the CUDA code specific to the precision we're using, // and set corresponding compile-time types. diff --git a/src/di_calibrate/error.rs b/src/di_calibrate/error.rs deleted file mode 100644 index be8abae3..00000000 --- a/src/di_calibrate/error.rs +++ /dev/null @@ -1,42 +0,0 @@ -// This Source Code Form is subject to the terms of the Mozilla Public -// License, v. 2.0. If a copy of the MPL was not distributed with this -// file, You can obtain one at http://mozilla.org/MPL/2.0/. - -//! Error type for all calibration-related errors. 
- -use thiserror::Error; - -#[derive(Error, Debug)] -pub(crate) enum DiCalibrateError { - #[error("Insufficient memory available to perform calibration; need {need_gib} of memory.\nYou could try using fewer timesteps and channels.")] - InsufficientMemory { need_gib: indicatif::HumanBytes }, - - #[error( - "Timestep {timestep} wasn't available in the timestamps list; this is a programmer error" - )] - TimestepUnavailable { timestep: usize }, - - #[error(transparent)] - DiCalArgs(#[from] crate::cli::di_calibrate::DiCalArgsError), - - #[error(transparent)] - SolutionsRead(#[from] crate::solutions::SolutionsReadError), - - #[error(transparent)] - SolutionsWrite(#[from] crate::solutions::SolutionsWriteError), - - #[error(transparent)] - Model(#[from] crate::model::ModelError), - - #[error(transparent)] - VisRead(#[from] crate::io::read::VisReadError), - - #[error(transparent)] - VisWrite(#[from] crate::io::write::VisWriteError), - - #[error(transparent)] - Fitsio(#[from] fitsio::errors::Error), - - #[error(transparent)] - IO(#[from] std::io::Error), -} diff --git a/src/di_calibrate/mod.rs b/src/di_calibrate/mod.rs index 5a9a1f89..42d66f58 100644 --- a/src/di_calibrate/mod.rs +++ b/src/di_calibrate/mod.rs @@ -7,457 +7,26 @@ //! This code borrows heavily from Torrance Hodgson's excellent Julia code at //! -mod error; #[cfg(test)] pub(crate) mod tests; -pub(crate) use error::DiCalibrateError; - -use std::thread; - -use crossbeam_channel::{unbounded, Sender}; -use crossbeam_utils::atomic::AtomicCell; -use hifitime::Duration; -use indicatif::{MultiProgress, ProgressBar, ProgressDrawTarget, ProgressStyle}; +use indicatif::{ProgressBar, ProgressDrawTarget, ProgressStyle}; use itertools::Itertools; -use log::{debug, info}; -use marlu::{ - c64, - constants::{FREQ_WEIGHT_FACTOR, TIME_WEIGHT_FACTOR}, - math::num_tiles_from_num_cross_correlation_baselines, - Jones, -}; -use ndarray::{iter::AxisIterMut, prelude::*}; +use log::info; +use marlu::{c64, math::num_tiles_from_num_cross_correlation_baselines, Jones}; +use ndarray::prelude::*; use rayon::prelude::*; -use scopeguard::defer_on_unwind; use vec1::Vec1; use crate::{ - averaging::{timesteps_to_timeblocks, Chanblock, Timeblock}, - cli::di_calibrate::DiCalParams, + averaging::{Chanblock, Timeblock}, context::Polarisations, - io::write::{write_vis, VisTimestep}, math::average_epoch, - misc::expensive_op, - model::{self, ModellerInfo}, + params::DiCalParams, solutions::CalibrationSolutions, + MODEL_DEVICE, PROGRESS_BARS, }; -pub(crate) struct CalVis { - /// Visibilites read from input data. - pub(crate) vis_data_tfb: Array3>, - - /// The weights on the visibilites read from input data. - pub(crate) vis_weights_tfb: Array3, - - /// Visibilites generated from the sky-model source list. - pub(crate) vis_model_tfb: Array3>, - - /// The available polarisations within the data. - pub(crate) pols: Polarisations, -} - -/// For calibration, read in unflagged visibilities and generate sky-model -/// visibilities. -pub(crate) fn get_cal_vis( - params: &DiCalParams, - draw_progress_bar: bool, -) -> Result { - // TODO: Use all fences, not just the first. - let fence = params.fences.first(); - - // Get the time and frequency resolutions once; these functions issue - // warnings if they have to guess, so doing this once means we aren't - // issuing too many warnings. 
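// A minimal sketch of the `thiserror` pattern used by the removed
// `DiCalibrateError` above: `#[error(transparent)]` plus `#[from]` lets `?`
// convert sub-errors into the top-level enum without boilerplate. The error
// names below are illustrative, not hyperdrive's.
use thiserror::Error;

#[derive(Error, Debug)]
enum SubError {
    #[error("insufficient memory; need {need_gib} GiB")]
    InsufficientMemory { need_gib: u64 },
}

#[derive(Error, Debug)]
enum TopError {
    #[error(transparent)]
    Sub(#[from] SubError),

    #[error(transparent)]
    Io(#[from] std::io::Error),
}

fn allocate() -> Result<(), SubError> {
    Err(SubError::InsufficientMemory { need_gib: 16 })
}

fn run() -> Result<(), TopError> {
    // `?` converts the `SubError` into `TopError::Sub` via the `From` impl
    // that `#[from]` generates.
    allocate()?;
    Ok(())
}

fn main() {
    if let Err(e) = run() {
        eprintln!("{e}");
    }
}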
- let obs_context = params.get_obs_context(); - let time_res = obs_context.guess_time_res(); - let freq_res = obs_context.guess_freq_res(); - - let vis_shape = ( - params.get_num_timesteps(), - fence.chanblocks.len(), - params.get_num_unflagged_cross_baselines(), - ); - let num_elems = vis_shape.0 * vis_shape.1 * vis_shape.2; - // We need this many bytes for each of the data and model arrays to do - // calibration. - let size = indicatif::HumanBytes((num_elems * std::mem::size_of::>()) as u64); - debug!( - "Shape of data and model arrays: ({} timesteps, {} channels, {} baselines; {size} each)", - vis_shape.0, vis_shape.1, vis_shape.2 - ); - - macro_rules! fallible_allocator { - ($default:expr) => {{ - let mut v = Vec::new(); - match v.try_reserve_exact(num_elems) { - Ok(()) => { - v.resize(num_elems, $default); - Ok(Array3::from_shape_vec(vis_shape, v).unwrap()) - } - Err(_) => { - // We need this many gibibytes to do calibration (two - // visibility arrays and one weights array). - let need_gib = indicatif::HumanBytes( - (num_elems - * (2 * std::mem::size_of::>() + std::mem::size_of::())) - as u64, - ); - - Err(DiCalibrateError::InsufficientMemory { - // Instead of erroring out with how many bytes we need - // for the array we just tried to allocate, error out - // with how many bytes we need total. - need_gib, - }) - } - } - }}; - } - - debug!("Allocating memory for input data visibilities and model visibilities"); - let CalVis { - mut vis_data_tfb, - mut vis_model_tfb, - mut vis_weights_tfb, - pols: _, - } = expensive_op( - || -> Result<_, DiCalibrateError> { - let vis_data_tfb: Array3> = fallible_allocator!(Jones::default())?; - let vis_model_tfb: Array3> = fallible_allocator!(Jones::default())?; - let vis_weights_tfb: Array3 = fallible_allocator!(0.0)?; - Ok(CalVis { - vis_data_tfb, - vis_weights_tfb, - vis_model_tfb, - pols: Polarisations::default(), - }) - }, - "Still waiting to allocate visibility memory", - )?; - - // Sky-modelling communication channel. Used to tell the model writer when - // visibilities have been generated and they're ready to be written. - let (tx_model, rx_model) = unbounded(); - - // Progress bars. Courtesy Dev. - let multi_progress = MultiProgress::with_draw_target(if draw_progress_bar { - ProgressDrawTarget::stdout() - } else { - ProgressDrawTarget::hidden() - }); - let read_progress = multi_progress.add( - ProgressBar::new(vis_shape.0 as _) - .with_style( - ProgressStyle::default_bar() - .template("{msg:17}: [{wide_bar:.blue}] {pos:2}/{len:2} timesteps ({elapsed_precise}<{eta_precise})").unwrap() - .progress_chars("=> "), - ) - .with_position(0) - .with_message("Reading data"), - ); - let model_progress = multi_progress.add( - ProgressBar::new(vis_shape.0 as _) - .with_style( - ProgressStyle::default_bar() - .template("{msg:17}: [{wide_bar:.blue}] {pos:2}/{len:2} timesteps ({elapsed_precise}<{eta_precise})").unwrap() - .progress_chars("=> "), - ) - .with_position(0) - .with_message("Sky modelling"), - ); - // Only add a model writing progress bar if we need it. - let model_write_progress = params.model_files.as_ref().map(|_| { - multi_progress.add( - ProgressBar::new(vis_shape.0 as _) - .with_style( - ProgressStyle::default_bar() - .template("{msg:17}: [{wide_bar:.blue}] {pos:2}/{len:2} timeblocks ({elapsed_precise}<{eta_precise})").unwrap() - .progress_chars("=> "), - ) - .with_position(0) - .with_message("Model writing"), - ) - }); - - // Use a variable to track whether any threads have an issue. 
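// A rough sketch of what the removed `fallible_allocator!` macro above does:
// estimate the calibration memory (two Jones<f32> arrays plus one f32 weights
// array) and use `try_reserve_exact` so an over-large observation yields an
// "insufficient memory" error instead of an abort. `Jones32` is a stand-in
// for marlu's `Jones<f32>` (four complex f32 values, 32 bytes); the numbers
// in `main` are illustrative.
#[derive(Clone, Copy, Default)]
struct Jones32([f32; 8]);

fn try_alloc_jones(num_elems: usize) -> Result<Vec<Jones32>, String> {
    let mut v = Vec::new();
    match v.try_reserve_exact(num_elems) {
        Ok(()) => {
            v.resize(num_elems, Jones32::default());
            Ok(v)
        }
        Err(_) => {
            // Report the *total* requirement (data + model + weights), not
            // just the one array that failed to allocate.
            let total_bytes =
                num_elems * (2 * std::mem::size_of::<Jones32>() + std::mem::size_of::<f32>());
            Err(format!(
                "Insufficient memory available to perform calibration; need {:.1} GiB of memory",
                total_bytes as f64 / 1024f64.powi(3)
            ))
        }
    }
}

fn main() {
    // e.g. 2 timesteps x 32 chanblocks x 8128 baselines.
    let num_elems = 2 * 32 * 8128;
    let data = try_alloc_jones(num_elems).expect("should fit in memory");
    assert_eq!(data.len(), num_elems);
    assert_eq!(std::mem::size_of::<Jones32>(), 32);
}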
- let error = AtomicCell::new(false); - info!("Reading input data and sky modelling"); - let scoped_threads_result = thread::scope(|s| { - // Mutable slices of the "global" arrays. These allow threads to mutate - // the global arrays in parallel (using the Arc> pattern would - // kill performance here). - let vis_data_slices = vis_data_tfb.outer_iter_mut(); - let vis_model_slices = vis_model_tfb.outer_iter_mut(); - let vis_weight_slices = vis_weights_tfb.outer_iter_mut(); - - // Input visibility-data reading thread. - let data_handle = s.spawn(|| { - // If a panic happens, update our atomic error. - defer_on_unwind! { error.store(true); } - read_progress.tick(); - - let result = read_vis_data( - params, - vis_data_slices, - vis_weight_slices, - &error, - read_progress, - ); - // If the result of reading data was an error, allow the other - // threads to see this so they can abandon their work early. - if result.is_err() { - error.store(true); - } - result - }); - - // Sky-model generation thread. - let model_handle = s.spawn(|| { - defer_on_unwind! { error.store(true); } - model_progress.tick(); - - let result = model_vis( - params, - vis_model_slices, - time_res, - freq_res, - tx_model, - &error, - model_progress, - #[cfg(feature = "cuda")] - matches!(params.modeller_info, ModellerInfo::Cpu), - ); - if result.is_err() { - error.store(true); - } - result - }); - - // Model writing thread. If the user hasn't specified to write the model - // to a file, then this thread just consumes messages from the modeller. - let writer_handle = s.spawn(|| { - defer_on_unwind! { error.store(true); } - - // If the user wants the sky model written out, `model_file` is - // populated. - if let Some(model_files) = ¶ms.model_files { - if let Some(pb) = model_write_progress.as_ref() { - pb.tick(); - } - - let fine_chan_freqs = obs_context.fine_chan_freqs.mapped_ref(|&f| f as f64); - let unflagged_baseline_tile_pairs = params - .tile_baseline_flags - .tile_to_unflagged_cross_baseline_map - .keys() - .copied() - .sorted() - .collect::>(); - // These timeblocks are distinct from `params.timeblocks`; the - // latter are for calibration time averaging, whereas we want - // timeblocks for model visibility averaging. - let timeblocks = timesteps_to_timeblocks( - &obs_context.timestamps, - params.output_model_time_average_factor, - ¶ms.timesteps, - ); - - let result = write_vis( - model_files, - params.array_position, - obs_context.phase_centre, - obs_context.pointing_centre, - &obs_context.tile_xyzs, - &obs_context.tile_names, - obs_context.obsid, - &obs_context.timestamps, - ¶ms.timesteps, - &timeblocks, - time_res, - params.dut1, - freq_res, - &fine_chan_freqs, - &unflagged_baseline_tile_pairs, - ¶ms.flagged_fine_chans, - params.output_model_time_average_factor, - params.output_model_freq_average_factor, - params.input_data.get_marlu_mwa_info().as_ref(), - rx_model, - &error, - model_write_progress, - ); - if result.is_err() { - error.store(true); - } - match result { - // Discard the result string. - Ok(_) => Ok(()), - Err(e) => Err(DiCalibrateError::from(e)), - } - } else { - // There's no model to write out, but we still need to handle all of the - // incoming messages. - for _ in rx_model.iter() {} - Ok(()) - } - }); - - // Join all thread handles. This propagates any errors and lets us know - // if any threads panicked, if panics aren't aborting as per the - // Cargo.toml. (It would be nice to capture the panic information, if - // it's possible, but I don't know how, so panics are currently - // aborting.) 
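// A condensed sketch of the reader/modeller/writer layout shown above: scoped
// threads share a plain `AtomicCell<bool>` so any worker can tell the others
// to stop early, and a crossbeam channel hands finished timesteps to the
// consumer. The payload type and thread bodies are placeholders, not
// hyperdrive's.
use crossbeam_channel::unbounded;
use crossbeam_utils::atomic::AtomicCell;

struct Timestep(usize);

fn main() {
    let (tx, rx) = unbounded::<Timestep>();
    let error = AtomicCell::new(false);

    std::thread::scope(|s| {
        let error_ref = &error;

        // Producer: one message per timestep, bailing out early if another
        // thread has flagged an error.
        let producer = s.spawn(move || {
            for i in 0..4 {
                if error_ref.load() {
                    return;
                }
                if tx.send(Timestep(i)).is_err() {
                    // The receiver has hung up (e.g. after its own error);
                    // just stop quietly.
                    return;
                }
            }
            // `tx` drops here, closing the channel so the consumer's
            // iterator can finish.
        });

        // Consumer: drain every message even when there is nothing to write,
        // mirroring the model-writer thread above.
        let consumer = s.spawn(|| {
            for Timestep(i) in rx.iter() {
                println!("handling timestep {i}");
            }
        });

        producer.join().unwrap();
        consumer.join().unwrap();
    });

    if error.load() {
        eprintln!("a worker thread reported an error");
    }
}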
- let result = data_handle.join().unwrap(); - let result = result.and_then(|_| model_handle.join().unwrap()); - result.and_then(|_| writer_handle.join().unwrap()) - }); - - // Propagate errors. - scoped_threads_result?; - - debug!("Multiplying visibilities by weights"); - - // Multiply the visibilities by the weights (and baseline weights based on - // UVW cuts). If a weight is negative, it means the corresponding visibility - // should be flagged, so that visibility is set to 0; this means it does not - // affect calibration. Not iterating over weights during calibration makes - // makes calibration run significantly faster. - vis_data_tfb - .outer_iter_mut() - .into_par_iter() - .zip(vis_model_tfb.outer_iter_mut()) - .zip(vis_weights_tfb.outer_iter()) - .for_each(|((mut vis_data_fb, mut vis_model_fb), vis_weights_fb)| { - vis_data_fb - .outer_iter_mut() - .zip(vis_model_fb.outer_iter_mut()) - .zip(vis_weights_fb.outer_iter()) - .for_each(|((mut vis_data_b, mut vis_model_b), vis_weights_b)| { - vis_data_b - .iter_mut() - .zip(vis_model_b.iter_mut()) - .zip(vis_weights_b.iter()) - .zip(params.baseline_weights.iter()) - .for_each(|(((data, model), &weight), baseline_weight)| { - let weight = f64::from(weight) * *baseline_weight; - if weight <= 0.0 { - *data = Jones::default(); - *model = Jones::default(); - } else { - *data = Jones::::from(Jones::::from(*data) * weight); - *model = Jones::::from(Jones::::from(*model) * weight); - } - }); - }); - }); - - info!("Finished reading input data and sky modelling"); - - Ok(CalVis { - vis_data_tfb, - vis_weights_tfb, - vis_model_tfb, - pols: obs_context.polarisations, - }) -} - -fn read_vis_data( - params: &DiCalParams, - vis_data_slices_fb: AxisIterMut, Dim<[usize; 2]>>, - vis_weight_slices_fb: AxisIterMut>, - error: &AtomicCell, - progress_bar: ProgressBar, -) -> Result<(), DiCalibrateError> { - for ((×tep, vis_data_slice_fb), vis_weight_slice_fb) in params - .timesteps - .iter() - .zip(vis_data_slices_fb) - .zip(vis_weight_slices_fb) - { - params.read_crosses(vis_data_slice_fb, vis_weight_slice_fb, timestep)?; - - // Should we continue? - if error.load() { - return Ok(()); - } - - progress_bar.inc(1); - } - - progress_bar.abandon_with_message("Finished reading input data"); - Ok(()) -} - -#[allow(clippy::too_many_arguments)] -fn model_vis( - params: &DiCalParams, - vis_model_slices_fb: AxisIterMut, Dim<[usize; 2]>>, - time_res: Duration, - freq_res: f64, - tx_model: Sender, - error: &AtomicCell, - progress_bar: ProgressBar, - #[cfg(feature = "cuda")] use_cpu_for_modelling: bool, -) -> Result<(), DiCalibrateError> { - let obs_context = params.get_obs_context(); - let modeller = model::new_sky_modeller( - #[cfg(feature = "cuda")] - use_cpu_for_modelling, - &*params.beam, - ¶ms.source_list, - obs_context.polarisations, - ¶ms.unflagged_tile_xyzs, - ¶ms.unflagged_fine_chan_freqs, - ¶ms.tile_baseline_flags.flagged_tiles, - obs_context.phase_centre, - params.array_position.longitude_rad, - params.array_position.latitude_rad, - params.dut1, - params.apply_precession, - )?; - - let weight_factor = - ((freq_res / FREQ_WEIGHT_FACTOR) * (time_res.to_seconds() / TIME_WEIGHT_FACTOR)) as f32; - - // Iterate over all calibration timesteps and write to the model slices. - for (×tep, mut vis_model_slice) in params.timesteps.iter().zip(vis_model_slices_fb) { - // If for some reason the timestamp isn't there for this timestep, a - // programmer stuffed up, but emit a decent error message. 
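// A stripped-down version of the weighting step described above: data and
// model visibilities are multiplied by their weights once up front, and
// anything with a non-positive weight (i.e. flagged) is zeroed so it cannot
// influence calibration. Scalars stand in for the Jones matrices used in
// hyperdrive.
fn apply_weights(data: &mut [f64], model: &mut [f64], weights: &[f32]) {
    for ((d, m), &w) in data.iter_mut().zip(model.iter_mut()).zip(weights) {
        let w = f64::from(w);
        if w <= 0.0 {
            // A negative (or zero) weight means "flagged": zero the visibility.
            *d = 0.0;
            *m = 0.0;
        } else {
            *d *= w;
            *m *= w;
        }
    }
}

fn main() {
    let mut data = vec![4.0, 4.0, 4.0];
    let mut model = vec![1.0, 1.0, 1.0];
    let weights = vec![2.0_f32, -1.0, 0.5];
    apply_weights(&mut data, &mut model, &weights);
    assert_eq!(data, vec![8.0, 0.0, 2.0]);
    assert_eq!(model, vec![2.0, 0.0, 0.5]);
}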
- let timestamp = obs_context - .timestamps - .get(timestep) - .ok_or(DiCalibrateError::TimestepUnavailable { timestep })?; - match modeller.model_timestep_with(*timestamp, vis_model_slice.view_mut()) { - // Send the model information to the writer. - Ok(_) => match tx_model.send(VisTimestep { - cross_data_fb: vis_model_slice.to_shared(), - cross_weights_fb: ArcArray::from_elem(vis_model_slice.dim(), weight_factor), - autos: None, - timestamp: *timestamp, - }) { - Ok(()) => (), - // If we can't send the message, it's because the channel has - // been closed on the other side. That should only happen - // because the writer has exited due to error; in that case, - // just exit this thread. - Err(_) => return Ok(()), - }, - Err(e) => return Err(DiCalibrateError::from(e)), - } - - // Should we continue? - if error.load() { - return Ok(()); - } - - progress_bar.inc(1); - } - - progress_bar.abandon_with_message("Finished generating sky model"); - Ok(()) -} - /// (Possibly) incomplete calibration solutions. /// /// hyperdrive only reads in the data it needs for DI calibration; it ignores @@ -531,15 +100,12 @@ impl<'a> IncompleteSolutions<'a> { min_threshold, } = self; - let obs_context = params.get_obs_context(); - let total_num_tiles = params.get_total_num_tiles(); - // TODO: Picket fences. - let flagged_chanblock_indices = ¶ms.fences.first().flagged_chanblock_indices; - // TODO: Don't use the obs_context here. This needs to be the centroid - // frequencies of the chanblocks. This only works because frequency - // averaging (i.e. more than one channel per chanblock) isn't possible - // right now. - let chanblock_freqs = obs_context.fine_chan_freqs.mapped_ref(|&u| u as f64); + let input_vis_params = ¶ms.input_vis_params; + let obs_context = input_vis_params.get_obs_context(); + let total_num_tiles = obs_context.get_total_num_tiles(); + let flagged_chanblock_indices = &input_vis_params.spw.flagged_chanblock_indices; + let flagged_tiles = &input_vis_params.tile_baseline_flags.flagged_tiles; + let chanblock_freqs = input_vis_params.spw.get_all_freqs(); let (num_timeblocks, num_unflagged_tiles, num_unflagged_chanblocks) = di_jones.dim(); let total_num_chanblocks = chanblocks.len() + flagged_chanblock_indices.len(); @@ -549,10 +115,7 @@ impl<'a> IncompleteSolutions<'a> { assert!(!timeblocks.is_empty()); assert_eq!(num_timeblocks, timeblocks.len()); assert!(num_unflagged_tiles <= total_num_tiles); - assert_eq!( - num_unflagged_tiles + params.tile_baseline_flags.flagged_tiles.len(), - total_num_tiles - ); + assert_eq!(num_unflagged_tiles + flagged_tiles.len(), total_num_tiles); assert_eq!(num_unflagged_chanblocks, chanblocks.len()); assert_eq!( num_unflagged_chanblocks + flagged_chanblock_indices.len(), @@ -599,7 +162,7 @@ impl<'a> IncompleteSolutions<'a> { .enumerate() .for_each(|(i_tile, mut out_di_jones)| { // Nothing needs to be done if this tile is flagged. - if !params.tile_baseline_flags.flagged_tiles.contains(&i_tile) { + if !flagged_tiles.contains(&i_tile) { // Iterate over the chanblocks. 
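// A small sketch of the "incomplete -> complete" expansion happening here:
// solutions only exist for unflagged tiles and chanblocks, and they are
// scattered back into a full-sized array that stays NaN everywhere else.
// f64 gains stand in for the Jones matrices; the flag sets are illustrative.
use std::collections::HashSet;

use ndarray::Array2;

fn expand(
    incomplete: &Array2<f64>, // (unflagged tiles, unflagged chanblocks)
    total_num_tiles: usize,
    total_num_chanblocks: usize,
    flagged_tiles: &HashSet<usize>,
    flagged_chanblocks: &HashSet<usize>,
) -> Array2<f64> {
    let mut full = Array2::from_elem((total_num_tiles, total_num_chanblocks), f64::NAN);
    let mut i_unflagged_tile = 0;
    for i_tile in 0..total_num_tiles {
        if flagged_tiles.contains(&i_tile) {
            continue; // Flagged tiles stay NaN.
        }
        let mut i_unflagged_chanblock = 0;
        for i_chanblock in 0..total_num_chanblocks {
            if flagged_chanblocks.contains(&i_chanblock) {
                continue; // Flagged chanblocks stay NaN.
            }
            full[(i_tile, i_chanblock)] = incomplete[(i_unflagged_tile, i_unflagged_chanblock)];
            i_unflagged_chanblock += 1;
        }
        i_unflagged_tile += 1;
    }
    full
}

fn main() {
    let incomplete = Array2::from_elem((2, 2), 1.0);
    let full = expand(&incomplete, 3, 3, &HashSet::from([1]), &HashSet::from([0]));
    assert!(full[(1, 1)].is_nan()); // flagged tile
    assert_eq!(full[(0, 1)], 1.0); // first unflagged tile/chanblock
}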
let mut i_unflagged_chanblock = 0; out_di_jones.iter_mut().enumerate().for_each( @@ -655,9 +218,7 @@ impl<'a> IncompleteSolutions<'a> { let mut i_baseline = 0; for i_tile_1 in 0..total_num_tiles { for i_tile_2 in i_tile_1 + 1..total_num_tiles { - if params.tile_baseline_flags.flagged_tiles.contains(&i_tile_1) - || params.tile_baseline_flags.flagged_tiles.contains(&i_tile_2) - { + if flagged_tiles.contains(&i_tile_1) || flagged_tiles.contains(&i_tile_2) { i_baseline += 1; continue; } else { @@ -672,23 +233,19 @@ impl<'a> IncompleteSolutions<'a> { CalibrationSolutions { di_jones: out_di_jones, - flagged_tiles: params - .tile_baseline_flags - .flagged_tiles - .iter() - .copied() - .sorted() - .collect(), - flagged_chanblocks: flagged_chanblock_indices.clone(), + flagged_tiles: flagged_tiles.iter().copied().sorted().collect(), + flagged_chanblocks: flagged_chanblock_indices.iter().cloned().collect(), chanblock_freqs: Some(chanblock_freqs), obsid: obs_context.obsid, start_timestamps: Some(timeblocks.mapped_ref(|tb| *tb.timestamps.first())), end_timestamps: Some(timeblocks.mapped_ref(|tb| *tb.timestamps.last())), - average_timestamps: Some(timeblocks.mapped_ref(|tb| average_epoch(&tb.timestamps))), + average_timestamps: Some( + timeblocks.mapped_ref(|tb| average_epoch(tb.timestamps.iter().copied())), + ), max_iterations: Some(max_iterations), stop_threshold: Some(stop_threshold), min_threshold: Some(min_threshold), - raw_data_corrections: params.raw_data_corrections, + raw_data_corrections: input_vis_params.vis_reader.get_raw_data_corrections(), tile_names: Some(obs_context.tile_names.clone()), dipole_gains: Some(params.beam.get_dipole_gains()), dipole_delays: params.beam.get_dipole_delays(), @@ -698,22 +255,9 @@ impl<'a> IncompleteSolutions<'a> { uvw_min: Some(params.uvw_min), uvw_max: Some(params.uvw_max), freq_centroid: Some(params.freq_centroid), - modeller: match ¶ms.modeller_info { - ModellerInfo::Cpu => Some("CPU".to_string()), - - #[cfg(feature = "cuda")] - ModellerInfo::Cuda { - device_info, - driver_info, - } => Some(format!( - "{} (capability {}, {} MiB), CUDA driver {}, runtime {}", - device_info.name, - device_info.capability, - device_info.total_global_mem, - driver_info.driver_version, - driver_info.runtime_version - )), - }, + modeller: Some(MODEL_DEVICE.load().get_device_info().expect( + "unlikely to fail as device info should've successfully been retrieved earlier", + )), } } } @@ -738,7 +282,6 @@ pub fn calibrate_timeblocks<'a>( stop_threshold: f64, min_threshold: f64, pols: Polarisations, - draw_progress_bar: bool, print_convergence_messages: bool, ) -> (IncompleteSolutions<'a>, Array2) { let num_unflagged_tiles = num_tiles_from_num_cross_correlation_baselines(vis_data_tfb.dim().2); @@ -749,11 +292,7 @@ pub fn calibrate_timeblocks<'a>( let cal_results = if num_timeblocks == 1 { // Calibrate all timesteps together. - let pb = make_calibration_progress_bar( - num_chanblocks, - "Calibrating".to_string(), - draw_progress_bar, - ); + let pb = make_calibration_progress_bar(num_chanblocks, "Calibrating".to_string()); let cal_results = calibrate_timeblock( vis_data_tfb.view(), vis_model_tfb.view(), @@ -781,7 +320,6 @@ pub fn calibrate_timeblocks<'a>( let pb = make_calibration_progress_bar( num_chanblocks, "Calibrating all timeblocks together".to_string(), - draw_progress_bar, ); // This timeblock represents all timeblocks. 
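// The relationship behind marlu's `num_tiles_from_num_cross_correlation_baselines`
// call used above: n tiles give n(n-1)/2 cross-correlation baselines, so the
// tile count can be recovered from the baseline count by solving the
// quadratic. This re-derivation is for illustration only.
fn num_tiles_from_num_baselines(num_baselines: usize) -> usize {
    // n_bl = n (n - 1) / 2  =>  n = (1 + sqrt(1 + 8 n_bl)) / 2
    ((1.0 + (1.0 + 8.0 * num_baselines as f64).sqrt()) / 2.0) as usize
}

fn main() {
    // 128 MWA tiles give 128 * 127 / 2 = 8128 cross-correlation baselines.
    assert_eq!(num_tiles_from_num_baselines(8128), 128);

    // The baseline ordering matches the nested i_tile_1 < i_tile_2 loop above:
    // (0,1), (0,2), ..., (n-2,n-1).
    let n = 5;
    let mut count = 0;
    for i in 0..n {
        for _j in (i + 1)..n {
            count += 1;
        }
    }
    assert_eq!(count, n * (n - 1) / 2);
}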
let timeblock = { @@ -826,7 +364,6 @@ pub fn calibrate_timeblocks<'a>( i_timeblock + 1, num_timeblocks ), - draw_progress_bar, ); let mut cal_results = calibrate_timeblock( vis_data_tfb.view(), @@ -867,16 +404,11 @@ pub fn calibrate_timeblocks<'a>( ) } -/// Convenience function to make a progress bar while calibrating. `draw` -/// determines if the progress bar is actually displayed. -fn make_calibration_progress_bar( - num_chanblocks: usize, - message: String, - draw: bool, -) -> ProgressBar { +/// Convenience function to make a progress bar while calibrating. +fn make_calibration_progress_bar(num_chanblocks: usize, message: String) -> ProgressBar { ProgressBar::with_draw_target( Some(num_chanblocks as _), - if draw { + if PROGRESS_BARS.load() { // Use stdout, not stderr, because the messages printed by the // progress bar are valuable. ProgressDrawTarget::stdout() @@ -1003,7 +535,7 @@ fn calibrate_timeblock( // that failed. Then find the next that succeeded. With a // converged solution on both sides (or either side) of the // failures, use a weighted average for a guess of what the - // Jones matrices should be, then re-run MitchCal. + // Jones matrices should be, then re-run calibration. let mut left = None; let mut pairs = vec![]; let mut in_failures = false; diff --git a/src/di_calibrate/tests.rs b/src/di_calibrate/tests.rs index 40c39f9f..22858ecf 100644 --- a/src/di_calibrate/tests.rs +++ b/src/di_calibrate/tests.rs @@ -6,71 +6,30 @@ use std::{ collections::{HashMap, HashSet}, + num::NonZeroUsize, path::PathBuf, }; use approx::{assert_abs_diff_eq, assert_abs_diff_ne}; -use clap::Parser; use hifitime::{Duration, Epoch}; use indicatif::{ProgressBar, ProgressDrawTarget}; -use marlu::{ - constants::{MWA_HEIGHT_M, MWA_LAT_DEG, MWA_LONG_DEG}, - Jones, LatLngHeight, -}; +use marlu::Jones; use ndarray::prelude::*; -use serial_test::serial; -use tempfile::TempDir; use vec1::{vec1, Vec1}; -use super::{calibrate, calibrate_timeblocks, CalVis, IncompleteSolutions}; +use super::{calibrate, calibrate_timeblocks, DiCalParams, IncompleteSolutions}; use crate::{ - averaging::{channels_to_chanblocks, timesteps_to_timeblocks, Chanblock, Fence, Timeblock}, + averaging::{channels_to_chanblocks, timesteps_to_timeblocks, Chanblock, Spw, Timeblock}, beam::create_no_beam_object, - cli::di_calibrate::DiCalParams, context::Polarisations, di_calibrate::calibrate_timeblock, - io::read::{ - fits::{fits_get_col, fits_get_required_key, fits_open, fits_open_hdu}, - MsReader, RawDataCorrections, RawDataReader, VisRead, - }, + io::read::{RawDataCorrections, RawDataReader}, math::{is_prime, TileBaselineFlags}, + params::{InputVisParams, ModellingParams}, solutions::CalSolutionType, srclist::SourceList, - tests::reduced_obsids::get_reduced_1090008640, - CalibrationSolutions, DiCalArgs, VisSimulateArgs, }; -#[test] -#[serial] -fn test_1090008640_di_calibrate_writes_solutions() { - let tmp_dir = TempDir::new().expect("couldn't make tmp dir"); - let args = get_reduced_1090008640(true, false); - let data = args.data.unwrap(); - let metafits = &data[0]; - let gpufits = &data[1]; - let sols = tmp_dir.path().join("sols.fits"); - let cal_model = tmp_dir.path().join("hyp_model.uvfits"); - - #[rustfmt::skip] - let cal_args = DiCalArgs::parse_from([ - "di-calibrate", - "--data", metafits, gpufits, - "--source-list", &args.source_list.unwrap(), - "--outputs", &format!("{}", sols.display()), - "--model-filenames", &format!("{}", cal_model.display()), - "--no-progress-bars", - ]); - - // Run di-cal and check that it succeeds - let 
result = cal_args.run(false); - assert!(result.is_ok(), "result={:?} not ok", result.err().unwrap()); - - // check solutions file has been created, is readable - assert!(sols.exists(), "sols file not written"); - let sol_data = CalibrationSolutions::read_solutions_from_ext(sols, metafits.into()).unwrap(); - assert_eq!(sol_data.obsid, Some(1090008640)); -} - /// Make some data "four times as bright as the model". The solutions should /// then be all "twos". As data and model visibilities are given per baseline /// and solutions are given per tile, the per tile values should be the sqrt of @@ -279,71 +238,72 @@ fn test_calibrate_trivial_with_flags() { fn get_default_params() -> DiCalParams { let e = Epoch::from_gpst_seconds(1090008640.0); DiCalParams { - input_data: Box::new( - RawDataReader::new( - &"test_files/1090008640/1090008640.metafits", - &["test_files/1090008640/1090008640_20140721201027_gpubox01_00.fits"], - None, - RawDataCorrections::default(), - Some(LatLngHeight { - longitude_rad: 0.0, - latitude_rad: 0.0, - height_metres: 0.0, - }), - ) - .unwrap(), - ), - raw_data_corrections: None, + input_vis_params: InputVisParams { + vis_reader: Box::new( + RawDataReader::new( + &PathBuf::from("test_files/1090008640/1090008640.metafits"), + &[PathBuf::from( + "test_files/1090008640/1090008640_20140721201027_gpubox01_00.fits", + )], + None, + RawDataCorrections::default(), + None, + ) + .unwrap(), + ), + solutions: None, + timeblocks: vec1![Timeblock { + index: 0, + range: 0..1, + timestamps: vec1![e], + timesteps: vec1![0], + median: e, + }], + time_res: Duration::from_seconds(2.0), + spw: Spw { + chanblocks: vec![Chanblock { + chanblock_index: 0, + unflagged_index: 0, + freq: 10.0, + }], + flagged_chan_indices: HashSet::new(), + flagged_chanblock_indices: HashSet::new(), + chans_per_chanblock: NonZeroUsize::new(1).unwrap(), + freq_res: 1.0, + first_freq: 10.0, + }, + tile_baseline_flags: TileBaselineFlags { + tile_to_unflagged_cross_baseline_map: HashMap::new(), + tile_to_unflagged_auto_index_map: HashMap::new(), + unflagged_cross_baseline_to_tile_map: HashMap::new(), + unflagged_auto_index_to_tile_map: HashMap::new(), + flagged_tiles: HashSet::new(), + }, + using_autos: false, + ignore_weights: false, + dut1: Duration::default(), + }, beam: create_no_beam_object(1), source_list: SourceList::new(), - uvw_min: 0.0, - uvw_max: f64::INFINITY, - freq_centroid: 150e6, - baseline_weights: Vec1::try_from_vec(vec![1.0; 8128]).unwrap(), - timeblocks: vec1![Timeblock { + cal_timeblocks: vec1![Timeblock { index: 0, range: 0..1, timestamps: vec1![e], - median: e - }], - timesteps: vec1![0], - freq_average_factor: 1, - fences: vec1![Fence { - chanblocks: vec![Chanblock { - chanblock_index: 0, - unflagged_index: 0, - _freq: 10.0 - }], - flagged_chanblock_indices: vec![], - _first_freq: 10.0, - _freq_res: Some(1.0) + timesteps: vec1![0], + median: e, }], - unflagged_fine_chan_freqs: vec![0.0], - flagged_fine_chans: HashSet::new(), - tile_baseline_flags: TileBaselineFlags { - tile_to_unflagged_cross_baseline_map: HashMap::new(), - tile_to_unflagged_auto_index_map: HashMap::new(), - unflagged_cross_baseline_to_tile_map: HashMap::new(), - unflagged_auto_index_to_tile_map: HashMap::new(), - flagged_tiles: HashSet::new(), - }, - unflagged_tile_xyzs: vec![], - array_position: LatLngHeight { - longitude_rad: 0.0, - latitude_rad: 0.0, - height_metres: 0.0, - }, - dut1: Duration::from_seconds(0.0), - apply_precession: false, + uvw_min: 0.0, + uvw_max: f64::INFINITY, + freq_centroid: 150e6, + baseline_weights: 
Vec1::try_from_vec(vec![1.0; 8128]).unwrap(), max_iterations: 50, stop_threshold: 1e-6, min_threshold: 1e-3, - output_solutions_filenames: vec![(CalSolutionType::Fits, PathBuf::from("asdf.fits"))], - model_files: None, - output_model_time_average_factor: 1, - output_model_freq_average_factor: 1, - no_progress_bars: true, - modeller_info: crate::model::ModellerInfo::Cpu, + output_solution_files: vec1![(PathBuf::from("asdf.fits"), CalSolutionType::Fits)], + output_model_vis_params: None, + modelling_params: ModellingParams { + apply_precession: true, + }, } } @@ -352,34 +312,38 @@ fn get_default_params() -> DiCalParams { #[test] fn incomplete_to_complete_trivial() { let mut params = get_default_params(); - params.timeblocks = vec1![Timeblock { + params.input_vis_params.timeblocks = vec1![Timeblock { index: 0, range: 0..1, timestamps: vec1![Epoch::from_gpst_seconds(1065880128.0)], + timesteps: vec1![0], median: Epoch::from_gpst_seconds(1065880128.0), }]; - params.fences.first_mut().chanblocks = vec![ + params.input_vis_params.spw.chanblocks = vec![ Chanblock { chanblock_index: 0, unflagged_index: 0, - _freq: 150e6, + freq: 150e6, }, Chanblock { chanblock_index: 1, unflagged_index: 1, - _freq: 151e6, + freq: 151e6, }, Chanblock { chanblock_index: 2, unflagged_index: 2, - _freq: 152e6, + freq: 152e6, }, ]; - params.fences.first_mut().flagged_chanblock_indices = vec![]; - params.tile_baseline_flags.flagged_tiles = HashSet::new(); - let num_timeblocks = params.timeblocks.len(); - let num_tiles = params.get_total_num_tiles(); - let num_chanblocks = params.fences.first().chanblocks.len(); + params.input_vis_params.spw.flagged_chanblock_indices = HashSet::new(); + params.input_vis_params.tile_baseline_flags.flagged_tiles = HashSet::new(); + let num_timeblocks = params.input_vis_params.timeblocks.len(); + let num_tiles = params + .input_vis_params + .get_obs_context() + .get_total_num_tiles(); + let num_chanblocks = params.input_vis_params.spw.chanblocks.len(); let incomplete_di_jones: Vec> = (0..num_tiles * num_chanblocks) .map(|i| Jones::identity() * (i + 1) as f64 * if is_prime(i) { 1.0 } else { 0.5 }) @@ -391,8 +355,8 @@ fn incomplete_to_complete_trivial() { .unwrap(); let incomplete = IncompleteSolutions { di_jones: incomplete_di_jones.clone(), - timeblocks: ¶ms.timeblocks, - chanblocks: ¶ms.fences.first().chanblocks, + timeblocks: ¶ms.input_vis_params.timeblocks, + chanblocks: ¶ms.input_vis_params.spw.chanblocks, max_iterations: 50, stop_threshold: 1e-8, min_threshold: 1e-4, @@ -414,35 +378,44 @@ fn incomplete_to_complete_trivial() { #[test] fn incomplete_to_complete_flags_simple() { let mut params = get_default_params(); - params.timeblocks = vec1![Timeblock { + params.input_vis_params.timeblocks = vec1![Timeblock { index: 0, range: 0..1, timestamps: vec1![Epoch::from_gpst_seconds(1065880128.0)], + timesteps: vec1![0], median: Epoch::from_gpst_seconds(1065880128.0) }]; - params.fences.first_mut().chanblocks = vec![ + params.input_vis_params.spw.chanblocks = vec![ Chanblock { chanblock_index: 1, unflagged_index: 0, - _freq: 151e6, + freq: 151e6, }, Chanblock { chanblock_index: 2, unflagged_index: 1, - _freq: 152e6, + freq: 152e6, }, Chanblock { chanblock_index: 3, unflagged_index: 2, - _freq: 153e6, + freq: 153e6, }, ]; - params.fences.first_mut().flagged_chanblock_indices = vec![0]; - params.tile_baseline_flags.flagged_tiles = HashSet::new(); - let num_timeblocks = params.timeblocks.len(); - let total_num_tiles = params.get_total_num_tiles(); - let num_tiles = total_num_tiles - 
params.tile_baseline_flags.flagged_tiles.len(); - let num_chanblocks = params.fences.first().chanblocks.len(); + params.input_vis_params.spw.flagged_chanblock_indices = HashSet::from([0]); + params.input_vis_params.tile_baseline_flags.flagged_tiles = HashSet::new(); + let num_timeblocks = params.input_vis_params.timeblocks.len(); + let total_num_tiles = params + .input_vis_params + .get_obs_context() + .get_total_num_tiles(); + let num_tiles = total_num_tiles + - params + .input_vis_params + .tile_baseline_flags + .flagged_tiles + .len(); + let num_chanblocks = params.input_vis_params.spw.chanblocks.len(); let di_jones: Vec> = (0..num_tiles * num_chanblocks) .map(|i| Jones::identity() * (i + 1) as f64 * if is_prime(i) { 1.0 } else { 0.5 }) @@ -451,8 +424,8 @@ fn incomplete_to_complete_flags_simple() { Array3::from_shape_vec((num_timeblocks, num_tiles, num_chanblocks), di_jones).unwrap(); let incomplete = IncompleteSolutions { di_jones: incomplete_di_jones.clone(), - timeblocks: ¶ms.timeblocks, - chanblocks: ¶ms.fences.first().chanblocks, + timeblocks: ¶ms.input_vis_params.timeblocks, + chanblocks: ¶ms.input_vis_params.spw.chanblocks, max_iterations: 50, stop_threshold: 1e-8, min_threshold: 1e-4, @@ -486,35 +459,43 @@ fn incomplete_to_complete_flags_simple() { #[test] fn incomplete_to_complete_flags_simple2() { let mut params = get_default_params(); - params.timeblocks = vec1![Timeblock { + params.input_vis_params.timeblocks = vec1![Timeblock { index: 0, range: 0..1, timestamps: vec1![Epoch::from_gpst_seconds(1065880128.0)], + timesteps: vec1![0], median: Epoch::from_gpst_seconds(1065880128.0) }]; - params.fences.first_mut().chanblocks = vec![ + params.input_vis_params.spw.chanblocks = vec![ Chanblock { chanblock_index: 0, unflagged_index: 0, - _freq: 151e6, + freq: 151e6, }, Chanblock { chanblock_index: 1, unflagged_index: 1, - _freq: 152e6, + freq: 152e6, }, Chanblock { chanblock_index: 2, unflagged_index: 2, - _freq: 153e6, + freq: 153e6, }, ]; - params.tile_baseline_flags.flagged_tiles = HashSet::new(); - params.fences.first_mut().flagged_chanblock_indices = vec![3]; - let num_timeblocks = params.timeblocks.len(); - let num_chanblocks = params.fences.first().chanblocks.len(); - let total_num_tiles = params.get_total_num_tiles(); - let num_tiles = total_num_tiles - params.tile_baseline_flags.flagged_tiles.len(); + params.input_vis_params.spw.flagged_chanblock_indices = HashSet::from([3]); + let num_timeblocks = params.input_vis_params.timeblocks.len(); + let num_chanblocks = params.input_vis_params.spw.chanblocks.len(); + let total_num_tiles = params + .input_vis_params + .get_obs_context() + .get_total_num_tiles(); + let num_tiles = total_num_tiles + - params + .input_vis_params + .tile_baseline_flags + .flagged_tiles + .len(); let incomplete_di_jones: Vec> = (0..num_tiles * num_chanblocks) .map(|i| Jones::identity() * (i + 1) as f64 * if is_prime(i) { 1.0 } else { 0.5 }) @@ -526,8 +507,8 @@ fn incomplete_to_complete_flags_simple2() { .unwrap(); let incomplete = IncompleteSolutions { di_jones: incomplete_di_jones.clone(), - timeblocks: ¶ms.timeblocks, - chanblocks: ¶ms.fences.first().chanblocks, + timeblocks: ¶ms.input_vis_params.timeblocks, + chanblocks: ¶ms.input_vis_params.spw.chanblocks, max_iterations: 50, stop_threshold: 1e-8, min_threshold: 1e-4, @@ -560,37 +541,46 @@ fn incomplete_to_complete_flags_simple2() { #[test] fn incomplete_to_complete_flags_complex() { let mut params = get_default_params(); - params.timeblocks = vec1![Timeblock { + params.input_vis_params.timeblocks = 
vec1![Timeblock { index: 0, range: 0..1, timestamps: vec1![Epoch::from_gpst_seconds(1065880128.0)], + timesteps: vec1![0], median: Epoch::from_gpst_seconds(1065880128.0) }]; - params.fences.first_mut().chanblocks = vec![ + params.input_vis_params.spw.chanblocks = vec![ Chanblock { chanblock_index: 0, unflagged_index: 0, - _freq: 150e6, + freq: 150e6, }, Chanblock { chanblock_index: 2, unflagged_index: 1, - _freq: 152e6, + freq: 152e6, }, Chanblock { chanblock_index: 3, unflagged_index: 2, - _freq: 153e6, + freq: 153e6, }, ]; - params.fences.first_mut().flagged_chanblock_indices = vec![1]; - params.tile_baseline_flags.flagged_tiles = HashSet::from([2]); - let num_timeblocks = params.timeblocks.len(); - let num_chanblocks = params.fences.first().chanblocks.len(); - let total_num_tiles = params.get_total_num_tiles(); - let num_tiles = total_num_tiles - params.tile_baseline_flags.flagged_tiles.len(); + params.input_vis_params.spw.flagged_chanblock_indices = HashSet::from([1]); + params.input_vis_params.tile_baseline_flags.flagged_tiles = HashSet::from([2]); + let num_timeblocks = params.input_vis_params.timeblocks.len(); + let num_chanblocks = params.input_vis_params.spw.chanblocks.len(); + let total_num_tiles = params + .input_vis_params + .get_obs_context() + .get_total_num_tiles(); + let num_tiles = total_num_tiles + - params + .input_vis_params + .tile_baseline_flags + .flagged_tiles + .len(); let total_num_chanblocks = - num_chanblocks + params.fences.first().flagged_chanblock_indices.len(); + num_chanblocks + params.input_vis_params.spw.flagged_chanblock_indices.len(); // Cower at my evil, awful code. let mut primes = vec1![2]; @@ -608,8 +598,8 @@ fn incomplete_to_complete_flags_complex() { .unwrap(); let incomplete = IncompleteSolutions { di_jones: incomplete_di_jones, - timeblocks: ¶ms.timeblocks, - chanblocks: ¶ms.fences.first().chanblocks, + timeblocks: ¶ms.input_vis_params.timeblocks, + chanblocks: ¶ms.input_vis_params.spw.chanblocks, max_iterations: 50, stop_threshold: 1e-8, min_threshold: 1e-4, @@ -626,13 +616,18 @@ fn incomplete_to_complete_flags_complex() { let sub_array = complete.di_jones.slice(s![0, i_tile, ..]); let mut i_unflagged_chanblock = 0; - if params.tile_baseline_flags.flagged_tiles.contains(&i_tile) { + if params + .input_vis_params + .tile_baseline_flags + .flagged_tiles + .contains(&i_tile) + { assert!(sub_array.iter().all(|j| j.any_nan())); } else { for i_chan in 0..total_num_chanblocks { if params - .fences - .first() + .input_vis_params + .spw .flagged_chanblock_indices .contains(&(i_chan as u16)) { @@ -659,373 +654,6 @@ fn incomplete_to_complete_flags_complex() { assert!(complete.flagged_chanblocks.contains(&1)); } -#[test] -fn test_1090008640_di_calibrate_uses_array_position() { - let tmp_dir = TempDir::new().expect("couldn't make tmp dir"); - let args = get_reduced_1090008640(true, false); - let data = args.data.unwrap(); - let metafits = &data[0]; - let gpufits = &data[1]; - let sols = tmp_dir.path().join("sols.fits"); - let cal_model = tmp_dir.path().join("hyp_model.uvfits"); - - // with non-default array position - let exp_lat_deg = MWA_LAT_DEG - 1.; - let exp_long_deg = MWA_LONG_DEG - 1.; - let exp_height_m = MWA_HEIGHT_M - 1.; - - #[rustfmt::skip] - let cal_args = DiCalArgs::parse_from([ - "di-calibrate", - "--data", metafits, gpufits, - "--source-list", args.source_list.as_ref().unwrap(), - "--outputs", &format!("{}", sols.display()), - "--model-filenames", &format!("{}", cal_model.display()), - "--array-position", - &format!("{exp_long_deg}"), - 
&format!("{exp_lat_deg}"), - &format!("{exp_height_m}"), - "--no-progress-bars", - ]); - - let pos = cal_args.array_position.unwrap(); - - assert_abs_diff_eq!(pos[0], exp_long_deg); - assert_abs_diff_eq!(pos[1], exp_lat_deg); - assert_abs_diff_eq!(pos[2], exp_height_m); -} - -#[test] -fn test_1090008640_di_calibrate_array_pos_requires_3_args() { - let tmp_dir = TempDir::new().expect("couldn't make tmp dir"); - let args = get_reduced_1090008640(true, false); - let data = args.data.unwrap(); - let metafits = &data[0]; - let gpufits = &data[1]; - let sols = tmp_dir.path().join("sols.fits"); - let cal_model = tmp_dir.path().join("hyp_model.uvfits"); - - // no height specified - let exp_lat_deg = MWA_LAT_DEG - 1.; - let exp_long_deg = MWA_LONG_DEG - 1.; - - #[rustfmt::skip] - let result = DiCalArgs::try_parse_from([ - "di-calibrate", - "--data", metafits, gpufits, - "--source-list", args.source_list.as_ref().unwrap(), - "--outputs", &format!("{}", sols.display()), - "--model-filenames", &format!("{}", cal_model.display()), - "--array-position", - &format!("{exp_long_deg}"), - &format!("{exp_lat_deg}"), - ]); - - assert!(result.is_err()); - assert!(matches!( - result.err().unwrap().kind(), - clap::ErrorKind::WrongNumberOfValues - )); -} - -#[test] -/// Generate a model with "vis-simulate" (in uvfits), then feed it to -/// "di-calibrate" and write out the model used for calibration (as uvfits). The -/// visibilities should be exactly the same. -fn test_1090008640_calibrate_model_uvfits() { - let num_timesteps = 2; - let num_chans = 10; - - let temp_dir = TempDir::new().expect("couldn't make tmp dir"); - let model = temp_dir.path().join("model.uvfits"); - let args = get_reduced_1090008640(false, false); - let metafits = &args.data.as_ref().unwrap()[0]; - let srclist = args.source_list.unwrap(); - #[rustfmt::skip] - let sim_args = VisSimulateArgs::parse_from([ - "vis-simulate", - "--metafits", metafits, - "--source-list", &srclist, - "--output-model-files", &format!("{}", model.display()), - "--num-timesteps", &format!("{num_timesteps}"), - "--num-fine-channels", &format!("{num_chans}"), - "--veto-threshold", "0.0", // Don't complicate things with vetoing - // The array position is needed because, if not specified, it's read - // slightly different out of the uvfits. 
- "--array-position", "116.67081523611111", "-26.703319405555554", "377.827", - "--no-progress-bars", - ]); - - // Run vis-simulate and check that it succeeds - let result = sim_args.run(false); - assert!(result.is_ok(), "result={:?} not ok", result.err().unwrap()); - - let sols = temp_dir.path().join("sols.fits"); - let cal_model = temp_dir.path().join("cal_model.uvfits"); - - #[rustfmt::skip] - let cal_args = DiCalArgs::parse_from([ - "di-calibrate", - "--data", &format!("{}", model.display()), metafits, - "--source-list", &srclist, - "--outputs", &format!("{}", sols.display()), - "--model-filenames", &format!("{}", cal_model.display()), - "--veto-threshold", "0.0", // Don't complicate things with vetoing - "--array-position", "116.67081523611111", "-26.703319405555554", "377.827", - "--no-progress-bars", - ]); - - // Run di-cal and check that it succeeds - let result = cal_args.into_params().unwrap().calibrate(); - assert!(result.is_ok(), "result={:?} not ok", result.err().unwrap()); - let sols = result.unwrap(); - - let mut uvfits_m = fits_open(&model).unwrap(); - let hdu_m = fits_open_hdu(&mut uvfits_m, 0).unwrap(); - let gcount_m: String = fits_get_required_key(&mut uvfits_m, &hdu_m, "GCOUNT").unwrap(); - let pcount_m: String = fits_get_required_key(&mut uvfits_m, &hdu_m, "PCOUNT").unwrap(); - let floats_per_pol_m: String = fits_get_required_key(&mut uvfits_m, &hdu_m, "NAXIS2").unwrap(); - let num_pols_m: String = fits_get_required_key(&mut uvfits_m, &hdu_m, "NAXIS3").unwrap(); - let num_fine_freq_chans_m: String = - fits_get_required_key(&mut uvfits_m, &hdu_m, "NAXIS4").unwrap(); - let jd_zero_m: String = fits_get_required_key(&mut uvfits_m, &hdu_m, "PZERO5").unwrap(); - let ptype4_m: String = fits_get_required_key(&mut uvfits_m, &hdu_m, "PTYPE4").unwrap(); - - let mut uvfits_c = fits_open(&cal_model).unwrap(); - let hdu_c = fits_open_hdu(&mut uvfits_c, 0).unwrap(); - let gcount_c: String = fits_get_required_key(&mut uvfits_c, &hdu_c, "GCOUNT").unwrap(); - let pcount_c: String = fits_get_required_key(&mut uvfits_c, &hdu_c, "PCOUNT").unwrap(); - let floats_per_pol_c: String = fits_get_required_key(&mut uvfits_c, &hdu_c, "NAXIS2").unwrap(); - let num_pols_c: String = fits_get_required_key(&mut uvfits_c, &hdu_c, "NAXIS3").unwrap(); - let num_fine_freq_chans_c: String = - fits_get_required_key(&mut uvfits_c, &hdu_c, "NAXIS4").unwrap(); - let jd_zero_c: String = fits_get_required_key(&mut uvfits_c, &hdu_c, "PZERO5").unwrap(); - let ptype4_c: String = fits_get_required_key(&mut uvfits_c, &hdu_c, "PTYPE4").unwrap(); - - let pcount: usize = pcount_m.parse().unwrap(); - assert_eq!(pcount, 7); - assert_eq!(gcount_m, gcount_c); - assert_eq!(pcount_m, pcount_c); - assert_eq!(floats_per_pol_m, floats_per_pol_c); - assert_eq!(num_pols_m, num_pols_c); - assert_eq!(num_fine_freq_chans_m, num_fine_freq_chans_c); - assert_eq!(jd_zero_m, jd_zero_c); - assert_eq!(ptype4_m, ptype4_c); - - let hdu_m = fits_open_hdu(&mut uvfits_m, 1).unwrap(); - let tile_names_m: Vec = fits_get_col(&mut uvfits_m, &hdu_m, "ANNAME").unwrap(); - let hdu_c = fits_open_hdu(&mut uvfits_c, 1).unwrap(); - let tile_names_c: Vec = fits_get_col(&mut uvfits_c, &hdu_c, "ANNAME").unwrap(); - for (tile_m, tile_c) in tile_names_m.into_iter().zip(tile_names_c.into_iter()) { - assert_eq!(tile_m, tile_c); - } - - // Test visibility values. 
- fits_open_hdu(&mut uvfits_m, 0).unwrap(); - let mut group_params_m = Array1::zeros(pcount); - let mut vis_m = Array1::zeros(num_chans * 4 * 3); - fits_open_hdu(&mut uvfits_c, 0).unwrap(); - let mut group_params_c = group_params_m.clone(); - let mut vis_c = vis_m.clone(); - - let mut status = 0; - for i_row in 0..gcount_m.parse::().unwrap() { - unsafe { - // ffggpe = fits_read_grppar_flt - fitsio_sys::ffggpe( - uvfits_m.as_raw(), /* I - FITS file pointer */ - 1 + i_row, /* I - group to read (1 = 1st group) */ - 1, /* I - first vector element to read (1 = 1st) */ - group_params_m.len() as i64, /* I - number of values to read */ - group_params_m.as_mut_ptr(), /* O - array of values that are returned */ - &mut status, /* IO - error status */ - ); - assert_eq!(status, 0, "Status wasn't 0"); - assert_abs_diff_ne!(group_params_m, group_params_c); - // ffggpe = fits_read_grppar_flt - fitsio_sys::ffggpe( - uvfits_c.as_raw(), /* I - FITS file pointer */ - 1 + i_row, /* I - group to read (1 = 1st group) */ - 1, /* I - first vector element to read (1 = 1st) */ - group_params_c.len() as i64, /* I - number of values to read */ - group_params_c.as_mut_ptr(), /* O - array of values that are returned */ - &mut status, /* IO - error status */ - ); - assert_eq!(status, 0, "Status wasn't 0"); - assert_abs_diff_eq!(group_params_m, group_params_c); - - // ffgpve = fits_read_img_flt - fitsio_sys::ffgpve( - uvfits_m.as_raw(), /* I - FITS file pointer */ - 1 + i_row, /* I - group to read (1 = 1st group) */ - 1, /* I - first vector element to read (1 = 1st) */ - vis_m.len() as i64, /* I - number of values to read */ - 0.0, /* I - value for undefined pixels */ - vis_m.as_mut_ptr(), /* O - array of values that are returned */ - &mut 0, /* O - set to 1 if any values are null; else 0 */ - &mut status, /* IO - error status */ - ); - assert_abs_diff_ne!(vis_m, vis_c); - // ffgpve = fits_read_img_flt - fitsio_sys::ffgpve( - uvfits_c.as_raw(), /* I - FITS file pointer */ - 1 + i_row, /* I - group to read (1 = 1st group) */ - 1, /* I - first vector element to read (1 = 1st) */ - vis_c.len() as i64, /* I - number of values to read */ - 0.0, /* I - value for undefined pixels */ - vis_c.as_mut_ptr(), /* O - array of values that are returned */ - &mut 0, /* O - set to 1 if any values are null; else 0 */ - &mut status, /* IO - error status */ - ); - assert_eq!(status, 0, "Status wasn't 0"); - assert_abs_diff_eq!(vis_m, vis_c); - }; - } - - // Inspect the solutions; they should all be close to identity. - assert_abs_diff_eq!( - sols.di_jones, - Array3::from_elem(sols.di_jones.dim(), Jones::identity()), - epsilon = 1e-15 - ); -} - -#[test] -#[serial] -/// Generate a model with "vis-simulate" (in a measurement set), then feed it to -/// "di-calibrate" and write out the model used for calibration (into a -/// measurement set). The visibilities should be exactly the same. 
-fn test_1090008640_calibrate_model_ms() { - let num_timesteps = 2; - let num_chans = 10; - - let temp_dir = TempDir::new().expect("couldn't make tmp dir"); - let model = temp_dir.path().join("model.ms"); - let args = get_reduced_1090008640(false, false); - let metafits = &args.data.as_ref().unwrap()[0]; - let srclist = args.source_list.unwrap(); - - // Non-default array position - let lat_deg = MWA_LAT_DEG - 1.; - let long_deg = MWA_LONG_DEG - 1.; - let height_m = MWA_HEIGHT_M - 1.; - - #[rustfmt::skip] - let sim_args = VisSimulateArgs::parse_from([ - "vis-simulate", - "--metafits", metafits, - "--source-list", &srclist, - "--output-model-files", &format!("{}", model.display()), - "--num-timesteps", &format!("{num_timesteps}"), - "--num-fine-channels", &format!("{num_chans}"), - "--array-position", - &format!("{long_deg}"), - &format!("{lat_deg}"), - &format!("{height_m}"), - "--no-progress-bars" - ]); - - // Run vis-simulate and check that it succeeds - let result = sim_args.run(false); - assert!(result.is_ok(), "result={:?} not ok", result.err().unwrap()); - - let sols = temp_dir.path().join("sols.fits"); - let cal_model = temp_dir.path().join("cal_model.ms"); - #[rustfmt::skip] - let cal_args = DiCalArgs::parse_from([ - "di-calibrate", - "--data", &format!("{}", model.display()), metafits, - "--source-list", &srclist, - "--outputs", &format!("{}", sols.display()), - "--model-filenames", &format!("{}", cal_model.display()), - "--array-position", - &format!("{long_deg}"), - &format!("{lat_deg}"), - &format!("{height_m}"), - "--no-progress-bars" - ]); - - // Run di-cal and check that it succeeds - let result = cal_args.into_params().unwrap().calibrate(); - assert!(result.is_ok(), "result={:?} not ok", result.err().unwrap()); - let sols = result.unwrap(); - - let array_pos = LatLngHeight::mwa(); - let ms_m = MsReader::new(model, None, Some(&PathBuf::from(metafits)), Some(array_pos)).unwrap(); - let ctx_m = ms_m.get_obs_context(); - let ms_c = MsReader::new( - cal_model, - None, - Some(&PathBuf::from(metafits)), - Some(array_pos), - ) - .unwrap(); - let ctx_c = ms_c.get_obs_context(); - assert_eq!(ctx_m.all_timesteps, ctx_c.all_timesteps); - assert_eq!(ctx_m.all_timesteps.len(), num_timesteps); - assert_eq!(ctx_m.timestamps, ctx_c.timestamps); - assert_eq!(ctx_m.fine_chan_freqs, ctx_c.fine_chan_freqs); - let m_flags = ctx_m.get_tile_flags(false, None).unwrap(); - let c_flags = ctx_c.get_tile_flags(false, None).unwrap(); - for m in &m_flags { - assert!(c_flags.contains(m)); - } - assert_eq!(ctx_m.tile_xyzs, ctx_c.tile_xyzs); - assert_eq!(ctx_m.flagged_fine_chans, ctx_c.flagged_fine_chans); - - let flagged_fine_chans_set: HashSet = ctx_m.flagged_fine_chans.iter().cloned().collect(); - let tile_baseline_flags = TileBaselineFlags::new(ctx_m.tile_xyzs.len(), m_flags); - let max_baseline_idx = tile_baseline_flags - .tile_to_unflagged_cross_baseline_map - .values() - .max() - .unwrap(); - let data_shape = ( - ctx_m.fine_chan_freqs.len() - ctx_m.flagged_fine_chans.len(), - max_baseline_idx + 1, - ); - let mut vis_m = Array2::>::zeros(data_shape); - let mut vis_c = Array2::>::zeros(data_shape); - let mut weight_m = Array2::::zeros(data_shape); - let mut weight_c = Array2::::zeros(data_shape); - - for ×tep in &ctx_m.all_timesteps { - ms_m.read_crosses( - vis_m.view_mut(), - weight_m.view_mut(), - timestep, - &tile_baseline_flags, - &flagged_fine_chans_set, - ) - .unwrap(); - ms_c.read_crosses( - vis_c.view_mut(), - weight_c.view_mut(), - timestep, - &tile_baseline_flags, - &flagged_fine_chans_set, - 
) - .unwrap(); - - // Unlike the equivalent uvfits test, we have to use an epsilon here. - // This is due to the MS antenna positions being in geocentric - // coordinates and not geodetic like uvfits; in the process of - // converting from geocentric to geodetic, small float errors are - // introduced. If a metafits' positions are used instead, the results - // are *exactly* the same, but we should trust the MS's positions, so - // these errors must remain. - assert_abs_diff_eq!(vis_m, vis_c, epsilon = 5e-6); - assert_abs_diff_eq!(weight_m, weight_c); - } - - // Inspect the solutions; they should all be close to identity. - assert_abs_diff_eq!( - sols.di_jones, - Array3::from_elem(sols.di_jones.dim(), Jones::identity()), - epsilon = 5e-9 - ); -} - #[test] fn test_multiple_timeblocks_behave() { let timestamps = vec1![ @@ -1034,7 +662,6 @@ fn test_multiple_timeblocks_behave() { Epoch::from_gpst_seconds(1090008644.0), ]; let num_timesteps = timestamps.len(); - let timesteps_to_use = Vec1::try_from_vec((0..num_timesteps).collect()).unwrap(); let num_tiles = 5; let num_baselines = num_tiles * (num_tiles - 1) / 2; let num_chanblocks = 1; @@ -1043,20 +670,29 @@ fn test_multiple_timeblocks_behave() { let vis_data: Array3> = Array3::from_elem(vis_shape, Jones::identity() * 4.0); let vis_model: Array3> = Array3::from_elem(vis_shape, Jones::identity()); - let timeblocks = timesteps_to_timeblocks(×tamps, 1, ×teps_to_use); - let fences = channels_to_chanblocks(&[150000000], Some(40e3), 1, &HashSet::new()); + let timeblocks = timesteps_to_timeblocks( + ×tamps, + Duration::from_seconds(2.0), + NonZeroUsize::new(1).unwrap(), + None, + ); + let spws = channels_to_chanblocks( + &[150000000], + 40e3 as u64, + NonZeroUsize::new(1).unwrap(), + &HashSet::new(), + ); let (incomplete_sols, _) = calibrate_timeblocks( vis_data.view(), vis_model.view(), &timeblocks, - &fences.first().unwrap().chanblocks, + &spws.first().unwrap().chanblocks, 10, 1e-8, 1e-4, Polarisations::default(), false, - false, ); // The solutions for all timeblocks should be the same. @@ -1075,7 +711,6 @@ fn test_chanblocks_without_data_have_nan_solutions() { Epoch::from_gpst_seconds(1090008644.0), ]; let num_timesteps = timestamps.len(); - let timesteps_to_use = Vec1::try_from_vec((0..num_timesteps).collect()).unwrap(); let num_tiles = 5; let num_baselines = num_tiles * (num_tiles - 1) / 2; let freqs = [150000000]; @@ -1085,8 +720,18 @@ fn test_chanblocks_without_data_have_nan_solutions() { let vis_data: Array3> = Array3::zeros(vis_shape); let vis_model: Array3> = Array3::zeros(vis_shape); - let timeblocks = timesteps_to_timeblocks(×tamps, 1, ×teps_to_use); - let fences = channels_to_chanblocks(&freqs, Some(40e3), 1, &HashSet::new()); + let timeblocks = timesteps_to_timeblocks( + ×tamps, + Duration::from_seconds(2.0), + NonZeroUsize::new(1).unwrap(), + None, + ); + let fences = channels_to_chanblocks( + &freqs, + 40e3 as u64, + NonZeroUsize::new(1).unwrap(), + &HashSet::new(), + ); let (incomplete_sols, results) = calibrate_timeblocks( vis_data.view(), @@ -1098,7 +743,6 @@ fn test_chanblocks_without_data_have_nan_solutions() { 1e-4, Polarisations::default(), false, - false, ); // All solutions are NaN, because all data and model Jones matrices were 0. 
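// A toy illustration of why these tests expect the values they do:
// calibration solves data_pq ≈ g_p * model_pq * conj(g_q) for each baseline
// (p, q). If every data visibility is 4x the model and all gains are equal,
// then 4 = g^2 and every per-tile solution is 2; if data and model are both
// zero, the ratio is undefined and the solutions come out NaN. Scalars stand
// in for the Jones matrices here.
fn equal_gain_from_ratio(data: f64, model: f64) -> f64 {
    (data / model).sqrt()
}

fn main() {
    assert_eq!(equal_gain_from_ratio(4.0, 1.0), 2.0);
    assert!(equal_gain_from_ratio(0.0, 0.0).is_nan());
}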
assert!(incomplete_sols.di_jones.into_iter().all(|j| j.any_nan())); @@ -1120,7 +764,6 @@ fn test_chanblocks_without_data_have_nan_solutions() { 1e-4, Polarisations::default(), false, - false, ); assert!(incomplete_sols.di_jones.into_iter().all(|j| !j.any_nan())); assert!(results.iter().all(|r| r.num_iterations == 1)); @@ -1130,7 +773,6 @@ fn test_chanblocks_without_data_have_nan_solutions() { fn test_recalibrating_failed_chanblocks() { let timestamps = vec1![Epoch::from_gpst_seconds(1090008640.0),]; let num_timesteps = timestamps.len(); - let timesteps_to_use = Vec1::try_from_vec((0..num_timesteps).collect()).unwrap(); let num_tiles = 5; let num_baselines = num_tiles * (num_tiles - 1) / 2; let freqs = [150000000, 150040000, 150080000]; @@ -1140,8 +782,18 @@ fn test_recalibrating_failed_chanblocks() { let vis_data: Array3> = Array3::from_elem(vis_shape, Jones::identity() * 4.0); let vis_model: Array3> = Array3::from_elem(vis_shape, Jones::identity()); - let timeblocks = timesteps_to_timeblocks(×tamps, 1, ×teps_to_use); - let fences = channels_to_chanblocks(&freqs, Some(40e3), 1, &HashSet::new()); + let timeblocks = timesteps_to_timeblocks( + ×tamps, + Duration::from_seconds(2.0), + NonZeroUsize::new(1).unwrap(), + None, + ); + let fences = channels_to_chanblocks( + &freqs, + 40000, + NonZeroUsize::new(1).unwrap(), + &HashSet::new(), + ); // Unlike `calibrate_timeblocks`, `calibrate_timeblock` takes in calibration // solutions. These are initially set to identity by `calibrate_timeblocks`; @@ -1190,77 +842,3 @@ fn test_recalibrating_failed_chanblocks() { assert!(!result.converged); } } - -/// Given calibration parameters and visibilities, this function tests that -/// everything matches an expected quality. The values may change over time but -/// they should be consistent with whatever tests use this test code. -pub(crate) fn test_1090008640_quality(params: DiCalParams, cal_vis: CalVis) { - let (_, cal_results) = calibrate_timeblocks( - cal_vis.vis_data_tfb.view(), - cal_vis.vis_model_tfb.view(), - ¶ms.timeblocks, - ¶ms.fences.first().chanblocks, - 50, - 1e-8, - 1e-4, - Polarisations::default(), - false, - false, - ); - - // Only one timeblock. - assert_eq!(cal_results.dim().0, 1); - - let mut count_50 = 0; - let mut count_42 = 0; - let mut chanblocks_42 = vec![]; - let mut fewest_iterations = u32::MAX; - for cal_result in cal_results { - match cal_result.num_iterations { - 50 => { - count_50 += 1; - fewest_iterations = fewest_iterations.min(cal_result.num_iterations); - } - 42 => { - count_42 += 1; - chanblocks_42.push(cal_result.chanblock.unwrap()); - fewest_iterations = fewest_iterations.min(cal_result.num_iterations); - } - 0 => panic!("0 iterations? Something is wrong."), - _ => { - if cal_result.num_iterations % 2 == 1 { - panic!("An odd number of iterations shouldn't be possible; at the time of writing, only even numbers are allowed."); - } - fewest_iterations = fewest_iterations.min(cal_result.num_iterations); - } - } - - assert!( - cal_result.converged, - "Chanblock {} did not converge", - cal_result.chanblock.unwrap() - ); - assert_eq!(cal_result.num_failed, 0); - assert!(cal_result.max_precision < 1e8); - } - - let expected_count_50 = 14; - let expected_count_42 = 1; - let expected_chanblocks_42 = vec![13]; - let expected_fewest_iterations = 40; - if count_50 != expected_count_50 - || count_42 != expected_count_42 - || chanblocks_42 != expected_chanblocks_42 - || fewest_iterations != expected_fewest_iterations - { - panic!( - r#" -Calibration quality has changed. 
This test expects: - {expected_count_50} chanblocks with 50 iterations (got {count_50}), - {expected_count_42} chanblocks with 42 iterations (got {count_42}), - chanblocks {expected_chanblocks_42:?} to need 42 iterations (got {chanblocks_42:?}), and - no chanblocks to finish in less than {expected_fewest_iterations} iterations (got {fewest_iterations}). -"# - ); - } -} diff --git a/src/error.rs b/src/error.rs deleted file mode 100644 index 92277b64..00000000 --- a/src/error.rs +++ /dev/null @@ -1,408 +0,0 @@ -// This Source Code Form is subject to the terms of the Mozilla Public -// License, v. 2.0. If a copy of the MPL was not distributed with this -// file, You can obtain one at http://mozilla.org/MPL/2.0/. - -//! Error type for all hyperdrive-related errors. This should be the *only* -//! error enum that is publicly visible. - -use thiserror::Error; - -use crate::{ - cli::{ - di_calibrate::DiCalArgsError, - solutions::{apply::SolutionsApplyError, plot::SolutionsPlotError}, - vis_utils::{simulate::VisSimulateError, subtract::VisSubtractError}, - }, - di_calibrate::DiCalibrateError, - filenames::InputFileError, - io::read::VisReadError, - solutions::{SolutionsReadError, SolutionsWriteError}, - srclist::SrclistError, -}; - -const URL: &str = "https://MWATelescope.github.io/mwa_hyperdrive"; - -/// The *only* publicly visible error from hyperdrive. Each error message should -/// include the URL, unless it's "generic". -#[derive(Error, Debug)] -pub enum HyperdriveError { - /// An error related to di-calibrate. - #[error("{0}\n\nSee for more info: {URL}/user/di_cal/intro.html")] - DiCalibrate(String), - - /// An error related to solutions-apply. - #[error("{0}\n\nSee for more info: {URL}/user/solutions_apply/intro.html")] - SolutionsApply(String), - - /// An error related to solutions-plot. - #[error("{0}\n\nSee for more info: {URL}/user/plotting.html")] - SolutionsPlot(String), - - /// An error related to vis-simulate. - #[error("{0}\n\nSee for more info: {URL}/user/vis_simulate/intro.html")] - VisSimulate(String), - - /// An error related to vis-subtract. - #[error("{0}\n\nSee for more info: {URL}/user/vis_subtract/intro.html")] - VisSubtract(String), - - /// Generic error surrounding source lists. - #[error("{0}\n\nSee for more info: {URL}/defs/source_lists.html")] - Srclist(String), - - /// Generic error surrounding calibration solutions. - #[error("{0}\n\nSee for more info: {URL}/defs/cal_sols.html")] - Solutions(String), - - /// Error specific to hyperdrive calibration solutions. - #[error("{0}\n\nSee for more info: {URL}/defs/cal_sols_hyp.html")] - SolutionsHyp(String), - - /// Error specific to AO calibration solutions. - #[error("{0}\n\nSee for more info: {URL}/defs/cal_sols_ao.html")] - SolutionsAO(String), - - /// Error specific to RTS calibration solutions. - #[error("{0}\n\nSee for more info: {URL}/defs/cal_sols_rts.html")] - SolutionsRts(String), - - /// An error related to reading visibilities. - #[error("{0}\n\nSee for more info: {URL}/defs/vis_formats_read.html")] - VisRead(String), - - /// An error related to reading visibilities. - #[error("{0}\n\nSee for more info: {URL}/defs/vis_formats_write.html")] - VisWrite(String), - - /// An error related to averaging. - #[error("{0}\n\nSee for more info: {URL}/defs/vis_formats_write.html#visibility-averaging")] - Averaging(String), - - /// An error related to raw MWA data corrections. - #[error("{0}\n\nSee for more info: {URL}/defs/mwa/corrections.html")] - RawDataCorrections(String), - - /// An error related to metafits files. 
- #[error("{0}\n\nSee for more info: {URL}/defs/mwa/metafits.html")] - Metafits(String), - - /// An error related to dipole delays. - #[error("{0}\n\nSee for more info: {URL}/defs/mwa/delays.html")] - Delays(String), - - /// An error related to mwaf files. - #[error("{0}\n\nSee for more info: {URL}/defs/mwa/mwaf.html")] - Mwaf(String), - - /// An error related to mwalib. - #[error("{0}\n\nSee for more info: {URL}/defs/mwa/mwalib.html")] - Mwalib(String), - - /// An error related to beam code. - #[error("{0}\n\nSee for more info: {URL}/defs/beam.html")] - Beam(String), - - /// A cfitsio error. Because these are usually quite spartan, some - /// suggestions are provided here. - #[error("cfitsio error: {0}\n\nIf you don't know what this means, try turning up verbosity (-v or -vv) and maybe disabling progress bars.")] - Cfitsio(String), - - /// A generic error that can't be clarified further with documentation, e.g. - /// IO errors. - #[error("{0}")] - Generic(String), -} - -// When changing the error propagation below, ensure `Self::from(e)` uses the -// correct `e`! - -// Binary sub-command errors. - -impl From for HyperdriveError { - fn from(e: DiCalibrateError) -> Self { - let s = e.to_string(); - match e { - DiCalibrateError::InsufficientMemory { .. } - | DiCalibrateError::TimestepUnavailable { .. } => Self::DiCalibrate(s), - DiCalibrateError::DiCalArgs(e) => Self::from(e), - DiCalibrateError::SolutionsRead(_) | DiCalibrateError::SolutionsWrite(_) => { - Self::Solutions(s) - } - DiCalibrateError::Fitsio(_) => Self::Cfitsio(s), - DiCalibrateError::VisRead(e) => Self::from(e), - DiCalibrateError::VisWrite(_) => Self::VisWrite(s), - DiCalibrateError::Model(_) | DiCalibrateError::IO(_) => Self::Generic(s), - } - } -} - -impl From for HyperdriveError { - fn from(e: SolutionsApplyError) -> Self { - let s = e.to_string(); - match e { - SolutionsApplyError::NoInputData | SolutionsApplyError::TileCountMismatch { .. } => { - Self::SolutionsApply(s) - } - SolutionsApplyError::MultipleMetafits(_) - | SolutionsApplyError::MultipleMeasurementSets(_) - | SolutionsApplyError::MultipleUvfits(_) - | SolutionsApplyError::InvalidDataInput(_) => Self::VisRead(s), - SolutionsApplyError::InvalidOutputFormat(_) | SolutionsApplyError::NoOutput => { - Self::VisWrite(s) - } - SolutionsApplyError::NoTiles - | SolutionsApplyError::TileFlag(_) - | SolutionsApplyError::NoTimesteps - | SolutionsApplyError::DuplicateTimesteps - | SolutionsApplyError::UnavailableTimestep { .. } - | SolutionsApplyError::BadArrayPosition { .. } => Self::Generic(s), - SolutionsApplyError::ParsePfbFlavour(_) => Self::RawDataCorrections(s), - SolutionsApplyError::ParseOutputVisTimeAverageFactor(_) - | SolutionsApplyError::ParseOutputVisFreqAverageFactor(_) - | SolutionsApplyError::OutputVisTimeFactorNotInteger - | SolutionsApplyError::OutputVisFreqFactorNotInteger - | SolutionsApplyError::OutputVisTimeAverageFactorZero - | SolutionsApplyError::OutputVisFreqAverageFactorZero - | SolutionsApplyError::OutputVisTimeResNotMultiple { .. } - | SolutionsApplyError::OutputVisFreqResNotMultiple { .. 
} => Self::Averaging(s), - SolutionsApplyError::SolutionsRead(_) => Self::Solutions(s), - SolutionsApplyError::VisRead(e) => Self::from(e), - SolutionsApplyError::FileWrite(_) | SolutionsApplyError::VisWrite(_) => { - Self::VisWrite(s) - } - SolutionsApplyError::IO(_) => Self::Generic(s), - } - } -} - -impl From for HyperdriveError { - fn from(e: SolutionsPlotError) -> Self { - let s = e.to_string(); - match e { - #[cfg(not(feature = "plotting"))] - SolutionsPlotError::NoPlottingFeature => Self::SolutionsPlot(s), - SolutionsPlotError::SolutionsRead(_) => Self::Solutions(s), - SolutionsPlotError::Mwalib(_) => Self::Mwalib(s), - SolutionsPlotError::IO(_) => Self::Generic(s), - #[cfg(feature = "plotting")] - SolutionsPlotError::MetafitsNoAntennaNames => Self::Metafits(s), - #[cfg(feature = "plotting")] - SolutionsPlotError::Draw(_) - | SolutionsPlotError::NoInputs - | SolutionsPlotError::InvalidSolsFormat(_) => Self::Generic(s), - } - } -} - -impl From for HyperdriveError { - fn from(e: VisSimulateError) -> Self { - let s = e.to_string(); - match e { - VisSimulateError::RaInvalid - | VisSimulateError::DecInvalid - | VisSimulateError::OnlyOneRAOrDec - | VisSimulateError::NoSourcesAfterVeto - | VisSimulateError::FineChansZero - | VisSimulateError::FineChansWidthTooSmall - | VisSimulateError::ZeroTimeSteps - | VisSimulateError::BadArrayPosition { .. } => Self::VisSimulate(s), - VisSimulateError::BadDelays => Self::Delays(s), - VisSimulateError::SourceList(_) | VisSimulateError::Veto(_) => Self::Srclist(s), - VisSimulateError::Beam(_) => Self::Beam(s), - VisSimulateError::Mwalib(_) => Self::Mwalib(s), - VisSimulateError::InvalidOutputFormat(_) | VisSimulateError::VisWrite(_) => { - Self::VisWrite(s) - } - VisSimulateError::AverageFactor(_) => Self::Averaging(s), - VisSimulateError::Glob(_) - | VisSimulateError::FileWrite(_) - | VisSimulateError::Model(_) - | VisSimulateError::IO(_) => Self::Generic(s), - #[cfg(feature = "cuda")] - VisSimulateError::Cuda(_) => Self::Generic(s), - } - } -} - -impl From for HyperdriveError { - fn from(e: VisSubtractError) -> Self { - let s = e.to_string(); - match e { - VisSubtractError::MissingSource { .. } - | VisSubtractError::NoSourcesAfterVeto - | VisSubtractError::NoSources - | VisSubtractError::AllSourcesFiltered - | VisSubtractError::NoTimesteps - | VisSubtractError::DuplicateTimesteps - | VisSubtractError::UnavailableTimestep { .. } - | VisSubtractError::NoInputData - | VisSubtractError::InvalidDataInput(_) - | VisSubtractError::BadArrayPosition { .. } - | VisSubtractError::MultipleMetafits(_) - | VisSubtractError::MultipleMeasurementSets(_) - | VisSubtractError::MultipleUvfits(_) => Self::VisSubtract(s), - VisSubtractError::NoDelays | VisSubtractError::BadDelays => Self::Delays(s), - VisSubtractError::VisWrite(_) | VisSubtractError::InvalidOutputFormat(_) => { - Self::VisWrite(s) - } - VisSubtractError::VisRead(e) => Self::from(e), - VisSubtractError::SourceList(_) | VisSubtractError::Veto(_) => Self::Srclist(s), - VisSubtractError::Beam(_) => Self::Beam(s), - VisSubtractError::ParseOutputVisTimeAverageFactor(_) - | VisSubtractError::ParseOutputVisFreqAverageFactor(_) - | VisSubtractError::OutputVisTimeFactorNotInteger - | VisSubtractError::OutputVisFreqFactorNotInteger - | VisSubtractError::OutputVisTimeAverageFactorZero - | VisSubtractError::OutputVisFreqAverageFactorZero - | VisSubtractError::OutputVisTimeResNotMultiple { .. } - | VisSubtractError::OutputVisFreqResNotMultiple { .. 
} => Self::Averaging(s), - VisSubtractError::Glob(_) - | VisSubtractError::FileWrite(_) - | VisSubtractError::Model(_) - | VisSubtractError::IO(_) => Self::Generic(s), - #[cfg(feature = "cuda")] - VisSubtractError::Cuda(_) => Self::Generic(s), - } - } -} - -// Library code errors. - -impl From for HyperdriveError { - fn from(e: SrclistError) -> Self { - let s = e.to_string(); - match e { - SrclistError::NoSourcesAfterVeto - | SrclistError::ReadSourceList(_) - | SrclistError::WriteSourceList(_) - | SrclistError::Veto(_) => Self::Srclist(s), - SrclistError::MissingMetafits => Self::Metafits(s), - SrclistError::Beam(_) => Self::Beam(s), - SrclistError::Mwalib(_) => Self::Mwalib(s), - SrclistError::IO(_) => Self::Generic(s), - } - } -} - -impl From for HyperdriveError { - fn from(e: SolutionsReadError) -> Self { - let s = e.to_string(); - match e { - SolutionsReadError::UnsupportedExt { .. } => Self::Solutions(s), - SolutionsReadError::BadShape { .. } | SolutionsReadError::ParsePfbFlavour(_) => { - Self::SolutionsHyp(s) - } - SolutionsReadError::AndreBinaryStr { .. } - | SolutionsReadError::AndreBinaryVal { .. } => Self::SolutionsAO(s), - SolutionsReadError::RtsMetafitsRequired | SolutionsReadError::Rts(_) => { - Self::SolutionsRts(s) - } - SolutionsReadError::Fits(_) | SolutionsReadError::Fitsio(_) => Self::Cfitsio(s), - SolutionsReadError::IO(_) => Self::Generic(s), - } - } -} - -impl From for HyperdriveError { - fn from(e: SolutionsWriteError) -> Self { - let s = e.to_string(); - match e { - SolutionsWriteError::UnsupportedExt { .. } => Self::Solutions(s), - SolutionsWriteError::Fits(_) | SolutionsWriteError::Fitsio(_) => Self::Cfitsio(s), - SolutionsWriteError::IO(_) => Self::Generic(s), - } - } -} - -impl From for HyperdriveError { - fn from(e: InputFileError) -> Self { - let s = e.to_string(); - match e { - InputFileError::PpdMetafitsUnsupported(_) => Self::Metafits(s), - InputFileError::NotRecognised(_) - | InputFileError::DoesNotExist(_) - | InputFileError::CouldNotRead(_) - | InputFileError::Glob(_) - | InputFileError::IO(_, _) => Self::VisRead(s), - } - } -} - -impl From for HyperdriveError { - fn from(e: VisReadError) -> Self { - let s = e.to_string(); - match e { - VisReadError::InputFile(e) => Self::from(e), - VisReadError::Raw(_) - | VisReadError::Birli(_) - | VisReadError::MS(_) - | VisReadError::Uvfits(_) => Self::VisRead(s), - VisReadError::MwafFlagsMissingForTimestep { .. } => Self::Mwaf(s), - VisReadError::BadArraySize { .. } | VisReadError::SelectionError(_) => Self::Generic(s), - } - } -} - -impl From for HyperdriveError { - fn from(e: DiCalArgsError) -> Self { - let s = e.to_string(); - match e { - DiCalArgsError::NoInputData - | DiCalArgsError::NoOutput - | DiCalArgsError::NoTiles - | DiCalArgsError::NoChannels - | DiCalArgsError::NoTimesteps - | DiCalArgsError::AllBaselinesFlaggedFromUvwCutoffs - | DiCalArgsError::UnavailableTimestep { .. } - | DiCalArgsError::DuplicateTimesteps - | DiCalArgsError::TileFlag(_) - | DiCalArgsError::NoSources - | DiCalArgsError::BadArrayPosition { .. } - | DiCalArgsError::ParseUvwMin(_) - | DiCalArgsError::ParseUvwMax(_) => Self::DiCalibrate(s), - DiCalArgsError::NoSourceList - | DiCalArgsError::NoSourcesAfterVeto - | DiCalArgsError::Veto(_) - | DiCalArgsError::SourceList(_) => Self::Srclist(s), - DiCalArgsError::NoDelays | DiCalArgsError::BadDelays => Self::Delays(s), - DiCalArgsError::CalibrationOutputFile { .. 
} => Self::Solutions(s), - DiCalArgsError::ParsePfbFlavour(_) => Self::RawDataCorrections(s), - DiCalArgsError::Beam(_) => Self::Beam(s), - DiCalArgsError::ParseCalTimeAverageFactor(_) - | DiCalArgsError::ParseCalFreqAverageFactor(_) - | DiCalArgsError::CalTimeFactorNotInteger - | DiCalArgsError::CalFreqFactorNotInteger - | DiCalArgsError::CalTimeResNotMultiple { .. } - | DiCalArgsError::CalFreqResNotMultiple { .. } - | DiCalArgsError::CalTimeFactorZero - | DiCalArgsError::CalFreqFactorZero - | DiCalArgsError::ParseOutputVisTimeAverageFactor(_) - | DiCalArgsError::ParseOutputVisFreqAverageFactor(_) - | DiCalArgsError::OutputVisTimeFactorNotInteger - | DiCalArgsError::OutputVisFreqFactorNotInteger - | DiCalArgsError::OutputVisTimeAverageFactorZero - | DiCalArgsError::OutputVisFreqAverageFactorZero - | DiCalArgsError::OutputVisTimeResNotMultiple { .. } - | DiCalArgsError::OutputVisFreqResNotMultiple { .. } => Self::Averaging(s), - DiCalArgsError::InvalidDataInput(_) - | DiCalArgsError::MultipleMetafits(_) - | DiCalArgsError::MultipleMeasurementSets(_) - | DiCalArgsError::MultipleUvfits(_) => Self::VisRead(s), - DiCalArgsError::VisRead(e) => Self::from(e), - DiCalArgsError::VisFileType { .. } | DiCalArgsError::FileWrite(_) => Self::VisWrite(s), - DiCalArgsError::UnrecognisedArgFileExt(_) - | DiCalArgsError::TomlDecode { .. } - | DiCalArgsError::JsonDecode { .. } - | DiCalArgsError::Glob(_) - | DiCalArgsError::IO(_) => Self::Generic(s), - #[cfg(feature = "cuda")] - DiCalArgsError::Cuda(_) => Self::Generic(s), - } - } -} - -impl From for HyperdriveError { - fn from(e: mwalib::MwalibError) -> Self { - Self::Mwalib(e.to_string()) - } -} diff --git a/src/flagging/error.rs b/src/flagging/error.rs index ba63e9ec..03a398cd 100644 --- a/src/flagging/error.rs +++ b/src/flagging/error.rs @@ -6,9 +6,7 @@ use std::path::PathBuf; -use thiserror::Error; - -#[derive(Error, Debug)] +#[derive(thiserror::Error, Debug)] /// Error type associated with mwaf files. pub enum MwafError { #[error("mwaf file '{file:?}' has an unhandled version '{version}'")] @@ -21,7 +19,7 @@ pub enum MwafError { FitsError(#[from] crate::io::read::fits::FitsError), } -#[derive(Error, Debug)] +#[derive(thiserror::Error, Debug)] /// Error type associated with merging the contents of mwaf files. pub enum MwafMergeError { /// Error to describe some kind of inconsistent state within an mwaf file. diff --git a/src/io/read/error.rs b/src/io/read/error.rs index 9457d4a4..a7fb394b 100644 --- a/src/io/read/error.rs +++ b/src/io/read/error.rs @@ -2,17 +2,8 @@ // License, v. 2.0. If a copy of the MPL was not distributed with this // file, You can obtain one at http://mozilla.org/MPL/2.0/. -//! Errors from building a `InputData` trait instance. 
- -use birli::BirliError; -use marlu::SelectionError; -use thiserror::Error; - -#[derive(Error, Debug)] +#[derive(thiserror::Error, Debug)] pub enum VisReadError { - #[error("The supplied mwaf files don't have flags for timestep {timestep} (GPS time {gps})")] - MwafFlagsMissingForTimestep { timestep: usize, gps: f64 }, - #[error("Output {array_type} array did not have expected {expected_len} elements on axis {axis_num}")] BadArraySize { array_type: &'static str, @@ -20,21 +11,12 @@ pub enum VisReadError { axis_num: usize, }, - #[error(transparent)] - InputFile(#[from] crate::filenames::InputFileError), - #[error(transparent)] Raw(#[from] super::raw::RawReadError), - #[error(transparent)] - Birli(#[from] BirliError), - #[error(transparent)] MS(#[from] super::ms::MsReadError), #[error(transparent)] Uvfits(#[from] super::uvfits::UvfitsReadError), - - #[error(transparent)] - SelectionError(#[from] SelectionError), } diff --git a/src/io/read/fits/error.rs b/src/io/read/fits/error.rs index 01327a50..d1fdd621 100644 --- a/src/io/read/fits/error.rs +++ b/src/io/read/fits/error.rs @@ -6,9 +6,7 @@ use std::path::Path; -use thiserror::Error; - -#[derive(Error, Debug)] +#[derive(thiserror::Error, Debug)] pub enum FitsError { /// Error when opening a fits file. #[error( diff --git a/src/io/read/mod.rs b/src/io/read/mod.rs index e25662b8..0f6dfc89 100644 --- a/src/io/read/mod.rs +++ b/src/io/read/mod.rs @@ -11,9 +11,11 @@ mod raw; mod uvfits; pub(crate) use error::VisReadError; +pub(crate) use ms::MsReadError; pub use ms::MsReader; -pub(crate) use raw::pfb_gains; +pub(crate) use raw::{pfb_gains, RawReadError}; pub use raw::{RawDataCorrections, RawDataReader}; +pub(crate) use uvfits::UvfitsReadError; pub use uvfits::UvfitsReader; use std::collections::HashSet; @@ -29,7 +31,7 @@ use vec1::Vec1; use crate::{context::ObsContext, flagging::MwafFlags, math::TileBaselineFlags}; -#[derive(Debug)] +#[derive(Debug, Clone, Copy)] pub(crate) enum VisInputType { Raw, MeasurementSet, @@ -49,6 +51,14 @@ pub(crate) trait VisRead: Sync + Send { /// this trait object. fn get_flags(&self) -> Option<&MwafFlags>; + /// Get the raw data corrections that will be applied to the visibilities as + /// they're read in. These may be distinct from what the user specified. + fn get_raw_data_corrections(&self) -> Option; + + /// Set the raw data corrections that will be applied to the visibilities as + /// they're read in. These are only applied to raw data. + fn set_raw_data_corrections(&mut self, corrections: RawDataCorrections); + /// Read cross- and auto-correlation visibilities for all frequencies and /// baselines in a single timestep into corresponding arrays. 
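The `VisRead` trait gains a getter/setter pair for raw-data corrections. A hedged sketch of how a caller might use the pair; the `Option<RawDataCorrections>` return type is inferred from the `MsReader` and `RawDataReader` impls later in this diff, and non-raw readers simply return `None` and ignore the setter:

```rust
// Sketch only: `reader` is any `dyn VisRead` trait object.
fn apply_user_corrections(
    reader: &mut dyn VisRead,
    user_corrections: Option<RawDataCorrections>,
) {
    // Only raw MWA data reports corrections; measurement sets and uvfits
    // return None, so there is nothing to override.
    if let (Some(_current), Some(wanted)) =
        (reader.get_raw_data_corrections(), user_corrections)
    {
        reader.set_raw_data_corrections(wanted);
    }
}
```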
#[allow(clippy::too_many_arguments)] @@ -60,7 +70,7 @@ pub(crate) trait VisRead: Sync + Send { auto_weights_fb: ArrayViewMut2, timestep: usize, tile_baseline_flags: &TileBaselineFlags, - flagged_fine_chans: &HashSet, + flagged_fine_chans: &HashSet, ) -> Result<(), VisReadError>; /// Read cross-correlation visibilities for all frequencies and baselines in @@ -71,7 +81,7 @@ pub(crate) trait VisRead: Sync + Send { weights_fb: ArrayViewMut2, timestep: usize, tile_baseline_flags: &TileBaselineFlags, - flagged_fine_chans: &HashSet, + flagged_fine_chans: &HashSet, ) -> Result<(), VisReadError>; /// Read auto-correlation visibilities for all frequencies and tiles in a @@ -82,7 +92,7 @@ pub(crate) trait VisRead: Sync + Send { weights_fb: ArrayViewMut2, timestep: usize, tile_baseline_flags: &TileBaselineFlags, - flagged_fine_chans: &HashSet, + flagged_fine_chans: &HashSet, ) -> Result<(), VisReadError>; /// Get optional MWA information to give to `Marlu` when writing out diff --git a/src/io/read/ms/mod.rs b/src/io/read/ms/mod.rs index d4f67ced..293549b5 100644 --- a/src/io/read/ms/mod.rs +++ b/src/io/read/ms/mod.rs @@ -19,12 +19,12 @@ pub(crate) use error::*; use std::{ collections::{BTreeSet, HashMap}, - num::NonZeroUsize, + num::NonZeroU16, path::{Path, PathBuf}, }; use hifitime::{Duration, Epoch, TimeUnits}; -use log::{debug, trace, warn}; +use log::{debug, trace}; use marlu::{c32, rubbl_casatables, Jones, LatLngHeight, RADec, XyzGeocentric, XyzGeodetic}; use ndarray::prelude::*; use rubbl_casatables::{Table, TableError, TableOpenMode}; @@ -32,6 +32,8 @@ use rubbl_casatables::{Table, TableError, TableOpenMode}; use super::*; use crate::{ beam::Delays, + cli::Warn, + constants::DEFAULT_MS_DATA_COL_NAME, context::{ObsContext, Polarisations}, metafits, }; @@ -87,7 +89,7 @@ pub struct MsReader { obs_context: ObsContext, /// The path to the measurement set on disk. - pub(crate) ms: PathBuf, + ms: PathBuf, /// The "stride" of the data, i.e. the number of rows (baselines) before the /// time index changes. @@ -152,7 +154,8 @@ impl MsReader { return Err(MsReadError::MainTableEmpty); } let col_names = main_table.column_names()?; - let data_col_name = data_column_name.unwrap_or_else(|| "DATA".to_string()); + let data_col_name = + data_column_name.unwrap_or_else(|| DEFAULT_MS_DATA_COL_NAME.to_string()); // Validate the data column name, specified or not. if !col_names.contains(&data_col_name) { return Err(MsReadError::NoDataCol { col: data_col_name }); @@ -604,15 +607,10 @@ impl MsReader { } }; - let num_coarse_chans = mwa_coarse_chan_nums.as_ref().map(|ccs| { - NonZeroUsize::new(ccs.len()) - .expect("length is always > 0 because collection cannot be empty") - }); - let num_fine_chans_per_coarse_chan = num_coarse_chans.and_then(|num_coarse_chans| { - NonZeroUsize::new( - (total_bandwidth_hz / num_coarse_chans.get() as f64 / freq_res).round() as usize, - ) - }); + let num_fine_chans_per_coarse_chan = { + let n = (1.28e6 / freq_res).round() as u16; + Some(NonZeroU16::new(n).expect("is not 0")) + }; match ( mwa_coarse_chan_nums.as_ref(), @@ -688,8 +686,12 @@ impl MsReader { dipole_gains = Some(gains2); } else { // We have no choice but to leave the order as is. - warn!("The MS antenna names are different to those supplied in the metafits."); - warn!("Dipole delays/gains may be incorrectly mapped to MS antennas."); + [ + "The MS antenna names are different to those supplied in the metafits." 
+ .into(), + "Dipole delays/gains may be incorrectly mapped to MS antennas.".into(), + ] + .warn(); dipole_delays = Some(Delays::Full(delays)); dipole_gains = Some(gains); } @@ -772,25 +774,44 @@ impl MsReader { let flagged_fine_chans_per_coarse_chan = { let mut flagged_fine_chans_per_coarse_chan = vec![]; - if let (Some(num_coarse_chans), Some(num_fine_chans_per_coarse_chan)) = ( - num_coarse_chans, - num_fine_chans_per_coarse_chan.map(|n| n.get()), - ) { - // Loop over all fine channels within a coarse channel. For each - // fine channel, check all coarse channels; is the fine channel - // flagged? If so, add it to our collection. + if let Some(num_fine_chans_per_coarse_chan) = + num_fine_chans_per_coarse_chan.map(|n| n.get()) + { + // Because we allow data to come in that might have channels + // missing at the edges, there might be a different number of + // fine channels in each coarse channel. Match the channel freqs + // with what we think their index should be inside a coarse + // channel. + let chan_info = fine_chan_freqs_f64 + .iter() + .enumerate() + .map(|(i_chan, &freq)| { + // Get the coarse channel number. + let cc_num = (freq / 1.28e6).round(); + // Find the offset from the coarse channel centre. + let offset_hz = freq - cc_num * 1.28e6; + let i_offset: u16 = ((offset_hz / freq_res).round() as i32 + + i32::from(num_fine_chans_per_coarse_chan) / 2) + .try_into() + .expect("smaller than u16::MAX"); + (i_chan, i_offset, cc_num as u32) + }) + .collect::>(); + for i_chan in 0..num_fine_chans_per_coarse_chan { let mut chan_is_flagged = true; - // Note that the coarse channel indices do not matter; the - // data in measurement sets is concatenated even if a coarse - // channel is missing. - for i_cc in 0..num_coarse_chans.get() { - if !flagged_fine_chans[i_cc * num_fine_chans_per_coarse_chan + i_chan] { + let mut chan_is_present = false; + for (full_index, _, _) in chan_info + .iter() + .filter(|(_, this_chan, _)| *this_chan == i_chan) + { + chan_is_present = true; + if !flagged_fine_chans[*full_index] { chan_is_flagged = false; break; } } - if chan_is_flagged { + if chan_is_flagged && chan_is_present { flagged_fine_chans_per_coarse_chan.push(i_chan); } } @@ -798,9 +819,8 @@ impl MsReader { Vec1::try_from_vec(flagged_fine_chans_per_coarse_chan).ok() }; - let flagged_fine_chans = flagged_fine_chans - .into_iter() - .enumerate() + let flagged_fine_chans = (0..) 
+ .zip(flagged_fine_chans) .filter(|(_, f)| *f) .map(|(i, _)| i) .collect(); @@ -879,7 +899,8 @@ impl MsReader { *timestamps.first(), dut1, ) { - warn!("uvfits UVWs use the other baseline convention; will conjugate incoming visibilities"); + "MS UVWs use the other baseline convention; will conjugate incoming visibilities" + .warn(); true } else { false @@ -887,6 +908,7 @@ impl MsReader { }; let obs_context = ObsContext { + input_data_type: VisInputType::MeasurementSet, obsid, timestamps, all_timesteps, @@ -894,7 +916,7 @@ impl MsReader { phase_centre, pointing_centre, array_position, - _supplied_array_position: supplied_array_position, + supplied_array_position, dut1, tile_names, tile_xyzs, @@ -936,7 +958,7 @@ impl MsReader { mut crosses: Option, mut autos: Option, timestep: usize, - flagged_fine_chans: &HashSet, + flagged_fine_chans: &HashSet, ) -> Result<(), VisReadError> { // When reading in a new timestep's data, these indices should be // multiplied by `step` to get the amount of rows to stride in the main @@ -947,7 +969,7 @@ impl MsReader { let mut main_table = read_table(&self.ms, None).map_err(MsReadError::from)?; let chan_flags = (0..self.obs_context.fine_chan_freqs.len()) - .map(|i_chan| flagged_fine_chans.contains(&i_chan)) + .map(|i_chan| flagged_fine_chans.contains(&(i_chan as u16))) .collect::>(); main_table .for_each_row_in_range(row_range, |row| { @@ -1267,6 +1289,12 @@ impl VisRead for MsReader { None } + fn get_raw_data_corrections(&self) -> Option { + None + } + + fn set_raw_data_corrections(&mut self, _: RawDataCorrections) {} + fn read_crosses_and_autos( &self, cross_vis_fb: ArrayViewMut2>, @@ -1275,7 +1303,7 @@ impl VisRead for MsReader { auto_weights_fb: ArrayViewMut2, timestep: usize, tile_baseline_flags: &TileBaselineFlags, - flagged_fine_chans: &HashSet, + flagged_fine_chans: &HashSet, ) -> Result<(), VisReadError> { let cross_data = Some(CrossData { vis_fb: cross_vis_fb, @@ -1305,7 +1333,7 @@ impl VisRead for MsReader { weights_fb: ArrayViewMut2, timestep: usize, tile_baseline_flags: &TileBaselineFlags, - flagged_fine_chans: &HashSet, + flagged_fine_chans: &HashSet, ) -> Result<(), VisReadError> { let cross_data = Some(CrossData { vis_fb, @@ -1330,7 +1358,7 @@ impl VisRead for MsReader { weights_fb: ArrayViewMut2, timestep: usize, tile_baseline_flags: &TileBaselineFlags, - flagged_fine_chans: &HashSet, + flagged_fine_chans: &HashSet, ) -> Result<(), VisReadError> { let auto_data = Some(AutoData { vis_fb, diff --git a/src/io/read/ms/tests.rs b/src/io/read/ms/tests.rs index 06344bbf..36616323 100644 --- a/src/io/read/ms/tests.rs +++ b/src/io/read/ms/tests.rs @@ -2,7 +2,7 @@ // License, v. 2.0. If a copy of the MPL was not distributed with this // file, You can obtain one at http://mozilla.org/MPL/2.0/. -use std::{collections::HashSet, ffi::CString}; +use std::ffi::CString; use approx::{assert_abs_diff_eq, assert_abs_diff_ne}; use fitsio::errors::check_status as fits_check_status; @@ -14,34 +14,20 @@ use serial_test::serial; // Need to test serially because casacore is a steaming use tempfile::tempdir; use super::*; -use crate::{ - di_calibrate::{get_cal_vis, tests::test_1090008640_quality}, - math::TileBaselineFlags, - tests::{deflate_gz_into_tempfile, reduced_obsids::get_reduced_1090008640_ms}, -}; +use crate::tests::{deflate_gz_into_tempfile, get_reduced_1090008640_ms_pbs, DataAsPathBufs}; #[test] #[serial] fn test_1090008640_cross_vis() { - let args = get_reduced_1090008640_ms(); - let ms_reader = if let [metafits, ms] = &args.data.unwrap()[..] 
{ - match MsReader::new( - PathBuf::from(ms), - None, - Some(&PathBuf::from(metafits)), - None, - ) { - Ok(m) => m, - Err(e) => panic!("{}", e), - } - } else { - panic!("There weren't 2 elements in args.data"); - }; + let DataAsPathBufs { + metafits, mut vis, .. + } = get_reduced_1090008640_ms_pbs(); + let ms_reader = MsReader::new(vis.swap_remove(0), None, Some(&metafits), None).unwrap(); - let obs_context = &ms_reader.obs_context; + let obs_context = ms_reader.get_obs_context(); let total_num_tiles = obs_context.get_total_num_tiles(); let num_baselines = (total_num_tiles * (total_num_tiles - 1)) / 2; - let num_chans = obs_context.num_fine_chans_per_coarse_chan.unwrap().get(); + let num_chans = usize::from(obs_context.num_fine_chans_per_coarse_chan.unwrap().get()); let tile_baseline_flags = TileBaselineFlags::new(total_num_tiles, HashSet::new()); assert_abs_diff_eq!( @@ -93,24 +79,14 @@ fn test_1090008640_cross_vis() { #[test] #[serial] fn read_1090008640_auto_vis() { - let args = get_reduced_1090008640_ms(); - let ms_reader = if let [metafits, ms] = &args.data.unwrap()[..] { - match MsReader::new( - PathBuf::from(ms), - None, - Some(&PathBuf::from(metafits)), - None, - ) { - Ok(m) => m, - Err(e) => panic!("{}", e), - } - } else { - panic!("There weren't 2 elements in args.data"); - }; + let DataAsPathBufs { + metafits, mut vis, .. + } = get_reduced_1090008640_ms_pbs(); + let ms_reader = MsReader::new(vis.swap_remove(0), None, Some(&metafits), None).unwrap(); - let obs_context = &ms_reader.obs_context; + let obs_context = ms_reader.get_obs_context(); let total_num_tiles = obs_context.get_total_num_tiles(); - let num_chans = obs_context.num_fine_chans_per_coarse_chan.unwrap().get(); + let num_chans = usize::from(obs_context.num_fine_chans_per_coarse_chan.unwrap().get()); let tile_baseline_flags = TileBaselineFlags::new(total_num_tiles, HashSet::new()); assert_abs_diff_eq!( @@ -188,24 +164,14 @@ fn read_1090008640_auto_vis() { #[test] #[serial] fn read_1090008640_auto_vis_with_flags() { - let args = get_reduced_1090008640_ms(); - let ms_reader = if let [metafits, ms] = &args.data.unwrap()[..] { - match MsReader::new( - PathBuf::from(ms), - None, - Some(&PathBuf::from(metafits)), - None, - ) { - Ok(m) => m, - Err(e) => panic!("{}", e), - } - } else { - panic!("There weren't 2 elements in args.data"); - }; + let DataAsPathBufs { + metafits, mut vis, .. + } = get_reduced_1090008640_ms_pbs(); + let ms_reader = MsReader::new(vis.swap_remove(0), None, Some(&metafits), None).unwrap(); - let obs_context = &ms_reader.obs_context; + let obs_context = ms_reader.get_obs_context(); let total_num_tiles = obs_context.get_total_num_tiles(); - let num_chans = obs_context.num_fine_chans_per_coarse_chan.unwrap().get(); + let num_chans = usize::from(obs_context.num_fine_chans_per_coarse_chan.unwrap().get()); let tile_flags = HashSet::from([1, 9]); let num_unflagged_tiles = total_num_tiles - tile_flags.len(); let chan_flags = HashSet::from([1]); @@ -291,25 +257,15 @@ fn read_1090008640_auto_vis_with_flags() { #[test] #[serial] fn read_1090008640_cross_and_auto_vis() { - let args = get_reduced_1090008640_ms(); - let ms_reader = if let [metafits, ms] = &args.data.unwrap()[..] { - match MsReader::new( - PathBuf::from(ms), - None, - Some(&PathBuf::from(metafits)), - None, - ) { - Ok(m) => m, - Err(e) => panic!("{}", e), - } - } else { - panic!("There weren't 2 elements in args.data"); - }; + let DataAsPathBufs { + metafits, mut vis, .. 
+ } = get_reduced_1090008640_ms_pbs(); + let ms_reader = MsReader::new(vis.swap_remove(0), None, Some(&metafits), None).unwrap(); - let obs_context = &ms_reader.obs_context; + let obs_context = ms_reader.get_obs_context(); let total_num_tiles = obs_context.get_total_num_tiles(); let num_baselines = (total_num_tiles * (total_num_tiles - 1)) / 2; - let num_chans = obs_context.num_fine_chans_per_coarse_chan.unwrap().get(); + let num_chans = usize::from(obs_context.num_fine_chans_per_coarse_chan.unwrap().get()); let tile_baseline_flags = TileBaselineFlags::new(total_num_tiles, HashSet::new()); assert_abs_diff_eq!( @@ -422,25 +378,6 @@ fn read_1090008640_cross_and_auto_vis() { ); } -#[test] -#[serial] -fn test_1090008640_calibration_quality() { - let mut args = get_reduced_1090008640_ms(); - let temp_dir = tempdir().expect("Couldn't make temp dir"); - args.outputs = Some(vec![temp_dir.path().join("hyp_sols.fits")]); - // To be consistent with other data quality tests, add these flags. - args.fine_chan_flags = Some(vec![0, 1, 2, 16, 30, 31]); - - let result = args.into_params(); - let params = match result { - Ok(r) => r, - Err(e) => panic!("{}", e), - }; - - let cal_vis = get_cal_vis(¶ms, false).expect("Couldn't read data and generate a model"); - test_1090008640_quality(params, cal_vis); -} - #[test] #[serial] fn test_timestep_reading() { @@ -540,112 +477,17 @@ fn test_timestep_reading() { ); } -#[test] -#[serial] -fn test_trunc_data() { - let expected_num_tiles = 128; - let expected_unavailable_tiles = (2..128).collect::>(); - - let result = MsReader::new( - PathBuf::from("test_files/1090008640/1090008640_cotter_trunc_autos.ms"), - None, - None, - None, - ); - assert!(result.is_ok(), "{:?}", result.err()); - let reader = result.unwrap(); - assert!(reader.obs_context.autocorrelations_present); - assert_eq!(reader.obs_context.get_total_num_tiles(), expected_num_tiles); - assert_eq!(reader.obs_context.get_num_unflagged_tiles(), 2); - assert_eq!( - &reader.obs_context.unavailable_tiles, - &expected_unavailable_tiles - ); - assert_eq!( - &reader.obs_context.flagged_tiles, - &expected_unavailable_tiles - ); - assert_eq!(&reader.obs_context.all_timesteps, &[0, 1, 2]); - assert_eq!(&reader.obs_context.unflagged_timesteps, &[2]); - - let result = MsReader::new( - PathBuf::from("test_files/1090008640/1090008640_cotter_trunc_noautos.ms"), - None, - None, - None, - ); - assert!(result.is_ok(), "{:?}", result.err()); - let reader = result.unwrap(); - assert!(!reader.obs_context.autocorrelations_present); - assert_eq!(reader.obs_context.get_total_num_tiles(), expected_num_tiles); - assert_eq!(reader.obs_context.get_num_unflagged_tiles(), 2); - assert_eq!( - &reader.obs_context.unavailable_tiles, - &expected_unavailable_tiles - ); - assert_eq!( - &reader.obs_context.flagged_tiles, - &expected_unavailable_tiles - ); - assert_eq!(&reader.obs_context.all_timesteps, &[0, 1, 2]); - assert_eq!(&reader.obs_context.unflagged_timesteps, &[2]); - - let result = MsReader::new( - PathBuf::from("test_files/1090008640/1090008640_birli_trunc.ms"), - None, - None, - None, - ); - assert!(result.is_ok(), "{:?}", result.err()); - let reader = result.unwrap(); - assert!(reader.obs_context.autocorrelations_present); - assert_eq!(reader.obs_context.get_total_num_tiles(), expected_num_tiles); - assert_eq!(reader.obs_context.get_num_unflagged_tiles(), 2); - assert_eq!( - &reader.obs_context.unavailable_tiles, - &expected_unavailable_tiles - ); - assert_eq!( - &reader.obs_context.flagged_tiles, - &expected_unavailable_tiles - ); - 
assert_eq!(&reader.obs_context.all_timesteps, &[0, 1, 2]); - assert_eq!(&reader.obs_context.unflagged_timesteps, &[1, 2]); - - // Test that attempting to use all tiles still results in only 2 tiles being available. - let mut args = get_reduced_1090008640_ms(); - let temp_dir = tempdir().expect("Couldn't make temp dir"); - match args.data.as_mut() { - Some(d) => d[1] = "test_files/1090008640/1090008640_birli_trunc.ms".to_string(), - None => unreachable!(), - } - args.outputs = Some(vec![temp_dir.path().join("hyp_sols.fits")]); - args.ignore_input_data_tile_flags = true; - args.uvw_min = Some("0L".to_string()); - let result = args.into_params(); - assert!(result.is_ok(), "{:?}", result.err()); - let params = result.unwrap(); - - assert_eq!( - params.tile_baseline_flags.flagged_tiles.len(), - expected_unavailable_tiles.len() - ); -} - #[test] #[serial] fn test_map_metafits_antenna_order() { // First, check the delays and gains of the existing test data. Because this // MS has its tiles in the same order as the "metafits order", the delays // and gains are already correct without re-ordering. - let metafits_path = "test_files/1090008640/1090008640.metafits"; - let ms = MsReader::new( - PathBuf::from("test_files/1090008640/1090008640.ms"), - None, - Some(&PathBuf::from(metafits_path)), - None, - ) - .unwrap(); + let DataAsPathBufs { + metafits, mut vis, .. + } = get_reduced_1090008640_ms_pbs(); + let ms_pb = vis.swap_remove(0); + let ms = MsReader::new(ms_pb.clone(), None, Some(&metafits), None).unwrap(); let obs_context = ms.get_obs_context(); let delays = match obs_context.dipole_delays.as_ref() { Some(Delays::Full(d)) => d, @@ -662,10 +504,10 @@ fn test_map_metafits_antenna_order() { // Test that the dipole delays/gains get mapped correctly. As the test MS is // already in the same order as the metafits file, the easiest thing to do // is to modify the metafits file. 
- let metafits = tempfile::NamedTempFile::new().expect("couldn't make a temp file"); - std::fs::copy(metafits_path, metafits.path()).unwrap(); + let metafits_tmp = tempfile::NamedTempFile::new().expect("couldn't make a temp file"); + std::fs::copy(&metafits, metafits_tmp.path()).unwrap(); unsafe { - let metafits = CString::new(metafits.path().display().to_string()) + let metafits_c_str = CString::new(metafits_tmp.path().display().to_string()) .unwrap() .into_raw(); let mut fptr = std::ptr::null_mut(); @@ -673,13 +515,13 @@ fn test_map_metafits_antenna_order() { // ffopen = fits_open_file fitsio_sys::ffopen( - &mut fptr, /* O - FITS file pointer */ - metafits, /* I - full name of file to open */ - 1, /* I - 0 = open readonly; 1 = read/write */ - &mut status, /* IO - error status */ + &mut fptr, /* O - FITS file pointer */ + metafits_c_str, /* I - full name of file to open */ + 1, /* I - 0 = open readonly; 1 = read/write */ + &mut status, /* IO - error status */ ); fits_check_status(status).unwrap(); - drop(CString::from_raw(metafits)); + drop(CString::from_raw(metafits_c_str)); // ffmahd = fits_movabs_hdu fitsio_sys::ffmahd( fptr, /* I - FITS file pointer */ @@ -743,13 +585,7 @@ fn test_map_metafits_antenna_order() { fits_check_status(status).unwrap(); } - let ms = MsReader::new( - PathBuf::from("test_files/1090008640/1090008640.ms"), - None, - Some(metafits.path()), - None, - ) - .unwrap(); + let ms = MsReader::new(ms_pb.clone(), None, Some(metafits_tmp.path()), None).unwrap(); let obs_context = ms.get_obs_context(); let delays = match obs_context.dipole_delays.as_ref() { Some(Delays::Full(d)) => d, @@ -790,10 +626,10 @@ fn test_map_metafits_antenna_order() { // Test that the dipole delays/gains aren't mapped when an unknown tile name // is encountered. 
- let metafits = tempfile::NamedTempFile::new().expect("couldn't make a temp file"); - std::fs::copy(metafits_path, metafits.path()).unwrap(); + let metafits_tmp2 = tempfile::NamedTempFile::new().expect("couldn't make a temp file"); + std::fs::copy(metafits, metafits_tmp2.path()).unwrap(); unsafe { - let metafits = CString::new(metafits.path().display().to_string()) + let metafits_c_str = CString::new(metafits_tmp2.path().display().to_string()) .unwrap() .into_raw(); let mut fptr = std::ptr::null_mut(); @@ -801,13 +637,13 @@ fn test_map_metafits_antenna_order() { // ffopen = fits_open_file fitsio_sys::ffopen( - &mut fptr, /* O - FITS file pointer */ - metafits, /* I - full name of file to open */ - 1, /* I - 0 = open readonly; 1 = read/write */ - &mut status, /* IO - error status */ + &mut fptr, /* O - FITS file pointer */ + metafits_c_str, /* I - full name of file to open */ + 1, /* I - 0 = open readonly; 1 = read/write */ + &mut status, /* IO - error status */ ); fits_check_status(status).unwrap(); - drop(CString::from_raw(metafits)); + drop(CString::from_raw(metafits_c_str)); // ffmahd = fits_movabs_hdu fitsio_sys::ffmahd( fptr, /* I - FITS file pointer */ @@ -846,13 +682,7 @@ fn test_map_metafits_antenna_order() { fits_check_status(status).unwrap(); } - let ms = MsReader::new( - PathBuf::from("test_files/1090008640/1090008640.ms"), - None, - Some(metafits.path()), - None, - ) - .unwrap(); + let ms = MsReader::new(ms_pb, None, Some(metafits_tmp2.path()), None).unwrap(); let obs_context = ms.get_obs_context(); let delays = match obs_context.dipole_delays.as_ref() { Some(Delays::Full(d)) => d, @@ -868,6 +698,83 @@ fn test_map_metafits_antenna_order() { assert_abs_diff_eq!(gains, perturbed_gains); } +#[test] +#[serial] +fn test_trunc_data() { + let expected_num_tiles = 128; + let expected_unavailable_tiles = (2..128).collect::>(); + + let result = MsReader::new( + PathBuf::from("test_files/1090008640/1090008640_cotter_trunc_autos.ms"), + None, + None, + None, + ); + assert!(result.is_ok(), "{:?}", result.err()); + let reader = result.unwrap(); + let obs_context = reader.get_obs_context(); + let total_num_tiles = obs_context.get_total_num_tiles(); + let num_unflagged_tiles = total_num_tiles - obs_context.flagged_tiles.len(); + assert!(obs_context.autocorrelations_present); + assert_eq!(total_num_tiles, expected_num_tiles); + assert_eq!(num_unflagged_tiles, 2); + assert_eq!(&obs_context.unavailable_tiles, &expected_unavailable_tiles); + assert_eq!(&obs_context.flagged_tiles, &expected_unavailable_tiles); + assert_eq!(&obs_context.all_timesteps, &[0, 1, 2]); + assert_eq!(&obs_context.unflagged_timesteps, &[2]); + + let result = MsReader::new( + PathBuf::from("test_files/1090008640/1090008640_cotter_trunc_noautos.ms"), + None, + None, + None, + ); + assert!(result.is_ok(), "{:?}", result.err()); + let reader = result.unwrap(); + let obs_context = reader.get_obs_context(); + let total_num_tiles = obs_context.get_total_num_tiles(); + let num_unflagged_tiles = total_num_tiles - obs_context.flagged_tiles.len(); + assert!(!obs_context.autocorrelations_present); + assert_eq!(total_num_tiles, expected_num_tiles); + assert_eq!(num_unflagged_tiles, 2); + assert_eq!(&obs_context.unavailable_tiles, &expected_unavailable_tiles); + assert_eq!(&obs_context.flagged_tiles, &expected_unavailable_tiles); + assert_eq!(&obs_context.all_timesteps, &[0, 1, 2]); + assert_eq!(&obs_context.unflagged_timesteps, &[2]); + + let result = MsReader::new( + PathBuf::from("test_files/1090008640/1090008640_birli_trunc.ms"), + 
None, + None, + None, + ); + assert!(result.is_ok(), "{:?}", result.err()); + let reader = result.unwrap(); + let obs_context = reader.get_obs_context(); + let total_num_tiles = obs_context.get_total_num_tiles(); + let num_unflagged_tiles = total_num_tiles - obs_context.flagged_tiles.len(); + assert!(obs_context.autocorrelations_present); + assert_eq!(total_num_tiles, expected_num_tiles); + assert_eq!(num_unflagged_tiles, 2); + assert_eq!(&obs_context.unavailable_tiles, &expected_unavailable_tiles); + assert_eq!(&obs_context.flagged_tiles, &expected_unavailable_tiles); + assert_eq!(&obs_context.all_timesteps, &[0, 1, 2]); + assert_eq!(&obs_context.unflagged_timesteps, &[1, 2]); + + // Test that attempting to use all tiles still results in only 2 tiles being available. + let ms_reader = MsReader::new( + PathBuf::from("test_files/1090008640/1090008640_birli_trunc.ms"), + None, + None, + None, + ) + .unwrap(); + assert_eq!( + ms_reader.get_obs_context().flagged_tiles.len(), + expected_unavailable_tiles.len() + ); +} + #[test] fn test_sdc3() { let ms = tempdir().unwrap(); @@ -888,7 +795,7 @@ fn test_sdc3() { let obs_context = ms.get_obs_context(); assert_eq!(obs_context.timestamps.len(), 1); assert_eq!(obs_context.fine_chan_freqs.len(), 1); - let supplied_array_position = obs_context._supplied_array_position; + let supplied_array_position = obs_context.supplied_array_position; assert_abs_diff_eq!( supplied_array_position.longitude_rad.to_degrees(), 116.76444819999999, diff --git a/src/io/read/raw/error.rs b/src/io/read/raw/error.rs index 25313f76..11d5183d 100644 --- a/src/io/read/raw/error.rs +++ b/src/io/read/raw/error.rs @@ -4,11 +4,7 @@ //! Error-handling code associated with reading from raw MWA files. -use thiserror::Error; - -use crate::flagging::MwafMergeError; - -#[derive(Error, Debug)] +#[derive(thiserror::Error, Debug)] pub enum RawReadError { #[error("gpubox file {0} does not have a corresponding mwaf file specified")] GpuboxFileMissingMwafFile(usize), @@ -25,12 +21,21 @@ pub enum RawReadError { #[error("Attempted to read in MWA VCS data; this is unsupported")] Vcs, + #[error("The supplied mwaf files don't have flags for timestep {timestep} (GPS time {gps})")] + MwafFlagsMissingForTimestep { timestep: usize, gps: f64 }, + #[error(transparent)] - MwafMerge(#[from] Box), + MwafMerge(#[from] Box), #[error(transparent)] Glob(#[from] crate::io::GlobError), #[error("mwalib error: {0}")] Mwalib(#[from] Box), + + #[error(transparent)] + Selection(#[from] Box), + + #[error(transparent)] + Birli(#[from] Box), } diff --git a/src/io/read/raw/mod.rs b/src/io/read/raw/mod.rs index ea7074e9..c022a046 100644 --- a/src/io/read/raw/mod.rs +++ b/src/io/read/raw/mod.rs @@ -9,11 +9,12 @@ pub(crate) mod pfb_gains; #[cfg(test)] mod tests; -pub(crate) use error::*; +pub(crate) use error::RawReadError; use std::{ collections::HashSet, - num::NonZeroUsize, + fmt::Debug, + num::NonZeroU16, ops::Range, path::{Path, PathBuf}, }; @@ -21,17 +22,19 @@ use std::{ use birli::PreprocessContext; use hifitime::{Duration, Epoch}; use itertools::Itertools; -use log::{debug, trace, warn}; +use log::{debug, trace}; use marlu::{math::baseline_to_tiles, Jones, LatLngHeight, RADec, VisSelection, XyzGeodetic}; use mwalib::{ - CorrelatorContext, GeometricDelaysApplied, GpuboxError, MWAVersion, MwalibError, Pol, + CorrelatorContext, GeometricDelaysApplied, GpuboxError, MWAVersion, MetafitsContext, + MwalibError, Pol, }; use ndarray::prelude::*; use vec1::Vec1; -use super::*; +use super::{AutoData, CrossData, MarluMwaObsContext, 
VisInputType, VisRead, VisReadError}; use crate::{ beam::Delays, + cli::Warn, context::ObsContext, flagging::{MwafFlags, MwafProducer}, math::TileBaselineFlags, @@ -116,9 +119,8 @@ pub struct RawDataReader { /// Observation metadata. obs_context: ObsContext, - // Raw-data-specific things follow. /// The interface to the raw data via mwalib. - pub(crate) mwalib_context: CorrelatorContext, + mwalib_context: CorrelatorContext, /// The poly-phase filter bank gains to be used to correct the bandpass /// shape for each coarse channel. @@ -137,22 +139,10 @@ pub struct RawDataReader { impl RawDataReader { /// Create a new [`RawDataReader`]. - pub fn new>( - metadata: &T, - gpuboxes: &[T], - mwafs: Option<&[T]>, - corrections: RawDataCorrections, - array_position: Option, - ) -> Result { - Self::new_inner(metadata, gpuboxes, mwafs, corrections, array_position) - .map_err(VisReadError::from) - } - - /// Create a new [`RawDataReader`]. - fn new_inner>( - metadata: &T, - gpuboxes: &[T], - mwafs: Option<&[T]>, + pub fn new( + metafits: &Path, + gpuboxes: &[PathBuf], + mwafs: Option<&[PathBuf]>, corrections: RawDataCorrections, array_position: Option, ) -> Result { @@ -160,14 +150,12 @@ impl RawDataReader { // mwalib ensures that vectors aren't empty so when we convert a Vec to // Vec1, for example, we don't need to propagate a new error. - let meta_pb = metadata.as_ref().to_path_buf(); - let gpubox_pbs: Vec = gpuboxes.iter().map(|p| p.as_ref().to_path_buf()).collect(); - trace!("Using metafits: {}", meta_pb.display()); - trace!("Using gpubox files: {:#?}", gpubox_pbs); + trace!("Using metafits: {}", metafits.display()); + trace!("Using gpubox files: {:#?}", gpuboxes); trace!("Creating mwalib context"); let mwalib_context = crate::misc::expensive_op( - || CorrelatorContext::new(meta_pb, &gpubox_pbs).map_err(Box::new), + || CorrelatorContext::new(metafits, gpuboxes).map_err(Box::new), "Still waiting to inspect all gpubox metadata", )?; let metafits_context = &mwalib_context.metafits_context; @@ -195,7 +183,7 @@ impl RawDataReader { let num_unflagged_tiles = total_num_tiles - tile_flags_set.len(); debug!("There are {} unflagged tiles", num_unflagged_tiles); if num_unflagged_tiles == 0 { - warn!("All of this observation's tiles are flagged"); + "All of this observation's tiles are flagged".warn(); } // Check that the tile flags are sensible. 
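`RawDataReader::new` drops its generic `AsRef<Path>` parameters (and the separate fallible `new_inner`) in favour of concrete `&Path` and `&[PathBuf]` arguments. A sketch of the new construction; the file names below are placeholders and `corrections` is a `RawDataCorrections` value obtained elsewhere (e.g. from the CLI):

```rust
use std::path::PathBuf;

// Sketch only: placeholder paths and a hypothetical `corrections` variable.
let metafits = PathBuf::from("1090008640.metafits");
let gpuboxes = vec![PathBuf::from("1090008640_gpubox01_00.fits")];
let reader = RawDataReader::new(
    &metafits,   // &Path
    &gpuboxes,   // &[PathBuf]
    None,        // no mwaf flag files
    corrections, // RawDataCorrections
    None,        // derive the array position from the data
)
.expect("couldn't create the raw data reader");
```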
@@ -212,7 +200,7 @@ impl RawDataReader { let listed_delays = &metafits_context.delays; debug!("Listed observation dipole delays: {listed_delays:?}"); if listed_delays.iter().all(|&d| d == 32) { - warn!("This observation has been flagged as \"do not use\", according to the metafits delays!"); + "This observation has been flagged as \"do not use\", according to the metafits delays!".warn(); true } else { false @@ -273,9 +261,16 @@ impl RawDataReader { let mwa_coarse_chan_nums = Vec1::try_from_vec(coarse_chan_nums).expect("MWA data always has coarse channel info"); + let num_corr_fine_chans_per_coarse = NonZeroU16::new( + metafits_context + .num_corr_fine_chans_per_coarse + .try_into() + .expect("is smaller than u16::MAX"), + ) + .expect("never 0"); let flagged_fine_chans_per_coarse_chan = get_80khz_fine_chan_flags_per_coarse_chan( metafits_context.corr_fine_chan_width_hz, - metafits_context.num_corr_fine_chans_per_coarse, + num_corr_fine_chans_per_coarse, is_mwax, ); // Given the provided "common good" coarse channels, find the missing @@ -299,23 +294,23 @@ impl RawDataReader { ); for i_cc in coarse_chan_span.clone() { if missing_coarse_chans.contains(&i_cc) { - for f in 0..metafits_context.num_corr_fine_chans_per_coarse { + for f in 0..num_corr_fine_chans_per_coarse.get() { // The flagged channels are relative to the start of the // frequency band we're interested in. So if this is the // first coarse channel we're interested in, the flags // should start from 0, not wherever the coarse channel sits // within the whole observation band. flagged_fine_chans.push( - (i_cc - *coarse_chan_span.start()) - * metafits_context.num_corr_fine_chans_per_coarse + (i_cc - *coarse_chan_span.start()) as u16 + * num_corr_fine_chans_per_coarse.get() + f, ); } } else { for &f in &flagged_fine_chans_per_coarse_chan { flagged_fine_chans.push( - (i_cc - *coarse_chan_span.start()) - * metafits_context.num_corr_fine_chans_per_coarse + (i_cc - *coarse_chan_span.start()) as u16 + * num_corr_fine_chans_per_coarse.get() + f, ); } @@ -331,7 +326,7 @@ impl RawDataReader { let phase_centre = RADec::from_degrees( metafits_context.ra_phase_center_degrees.unwrap_or_else(|| { - warn!("No phase centre specified; using the pointing centre as the phase centre"); + "No phase centre specified; using the pointing centre as the phase centre".warn(); metafits_context.ra_tile_pointing_degrees }), metafits_context @@ -407,12 +402,20 @@ impl RawDataReader { debug!("Flag start time (GPS): {}", flags_start.to_gpst_seconds()); debug!("(flags_start - data_start).to_seconds() / time_res.to_seconds(): {diff}"); if diff.fract().abs() > 0.0 { - warn!("These mwaf files do not have times corresponding to the data they were created from."); + let mut block = vec!["These mwaf files do not have times corresponding to the data they were created from.".into()]; match f.software { - MwafProducer::Cotter => warn!(" This is a Cotter bug. You should probably use Birli to make new flags."), - MwafProducer::Birli => warn!(" These mwafs were made by Birli. Please file an issue!"), - MwafProducer::Unknown => warn!(" Unknown software made these mwafs."), + MwafProducer::Cotter => block.push( + "This is a Cotter bug. You should probably use Birli to make new flags." + .into(), + ), + MwafProducer::Birli => { + block.push("These mwafs were made by Birli. 
Please file an issue!".into()) + } + MwafProducer::Unknown => { + block.push("Unknown software made these mwafs.".into()) + } } + block.warn(); f.offset_bug = true; } @@ -429,21 +432,28 @@ impl RawDataReader { end_offset -= 1.0; } if start_offset > 0.0 || end_offset > 0.0 { - warn!("Not all MWA data timesteps have mwaf flags available"); + let mut block = vec!["Not all MWA data timesteps have mwaf flags available".into()]; match (start_offset > 0.0, end_offset > 0.0) { - (true, true) => warn!( - " {} timesteps at the start and {} at the end are not represented", - start_offset, end_offset, + (true, true) => block.push( + format!( + "{} timesteps at the start and {} at the end are not represented", + start_offset, end_offset, + ) + .into(), ), - (true, false) => warn!( - " {} timesteps at the start are not represented", - start_offset, + (true, false) => block.push( + format!( + "{} timesteps at the start are not represented", + start_offset, + ) + .into(), + ), + (false, true) => block.push( + format!("{} timesteps at the end are not represented", end_offset).into(), ), - (false, true) => { - warn!(" {} timesteps at the end are not represented", end_offset) - } (false, false) => unreachable!(), } + block.warn(); } Some(f) @@ -452,6 +462,7 @@ impl RawDataReader { }; let obs_context = ObsContext { + input_data_type: VisInputType::Raw, obsid: Some(metafits_context.obs_id), timestamps, all_timesteps, @@ -459,7 +470,7 @@ impl RawDataReader { phase_centre, pointing_centre, array_position, - _supplied_array_position: supplied_array_position, + supplied_array_position, dut1: metafits_context.dut1.map(Duration::from_seconds), tile_names, tile_xyzs, @@ -470,9 +481,7 @@ impl RawDataReader { dipole_gains: Some(dipole_gains), time_res: Some(time_res), mwa_coarse_chan_nums: Some(mwa_coarse_chan_nums), - num_fine_chans_per_coarse_chan: NonZeroUsize::new( - metafits_context.num_corr_fine_chans_per_coarse, - ), + num_fine_chans_per_coarse_chan: Some(num_corr_fine_chans_per_coarse), freq_res: Some(metafits_context.corr_fine_chan_width_hz as f64), fine_chan_freqs, flagged_fine_chans, @@ -512,11 +521,12 @@ impl RawDataReader { let flags = self.mwaf_flags.as_ref().unwrap().flags[&(gpubox_channel as u8)] .slice(s![timestep, .., ..,]); // Select only the applicable frequencies. - let n = self - .obs_context - .num_fine_chans_per_coarse_chan - .expect("raw MWA data always specifies this") - .get(); + let n = usize::from( + self.obs_context + .num_fine_chans_per_coarse_chan + .expect("raw MWA data always specifies this") + .get(), + ); let selection = s![ (i_gpubox_chan * n)..((i_gpubox_chan + 1) * n), .. // All baselines @@ -556,8 +566,8 @@ impl RawDataReader { crosses: Option, autos: Option, timestep: usize, - flagged_fine_chans: &HashSet, - ) -> Result<(), VisReadError> { + flagged_fine_chans: &HashSet, + ) -> Result<(), RawReadError> { // Check that mwaf flags are available for this timestep. 
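Throughout this module, direct `log::warn!` calls are replaced by the crate-internal `Warn` trait (`crate::cli::Warn`), which accepts either a single message or a collection of lines so that multi-line warnings stay grouped. The usage pattern, mirroring the call sites above (the element type produced by `.into()` is whatever the trait expects):

```rust
// Sketch only: mirrors the call sites in this hunk.
"All of this observation's tiles are flagged".warn();

let mut block = vec!["Not all MWA data timesteps have mwaf flags available".into()];
block.push(format!("{} timesteps at the start are not represented", 2).into());
block.warn();
```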
let mwaf_timestep = match &self.mwaf_flags { Some(mwaf_flags) => { @@ -567,7 +577,7 @@ impl RawDataReader { let flags_end = flags_start + mwaf_flags.num_time_steps as f64 * time_res; let timestamp = self.obs_context.timestamps[timestep]; if !(flags_start..flags_end).contains(×tamp) { - return Err(VisReadError::MwafFlagsMissingForTimestep { + return Err(RawReadError::MwafFlagsMissingForTimestep { timestep, gps: timestamp.to_gpst_seconds(), }); @@ -641,13 +651,13 @@ impl RawDataReader { timestep_index, coarse_chan_index, }) => { - warn!( + format!( "Flagging missing data at timestep {timestep_index}, coarse channel {coarse_chan_index}" - ); + ).warn(); flag_array_fb.fill(true); } - Err(e) => return Err(RawReadError::from(Box::new(MwalibError::from(e))).into()), + Err(e) => return Err(RawReadError::from(Box::new(MwalibError::from(e)))), } } @@ -692,13 +702,15 @@ impl RawDataReader { coarse_chan_range, baseline_idxs: (0..self.all_baseline_tile_pairs.len()).collect(), }; - prep_ctx.preprocess( - &self.mwalib_context, - jones_array_tfb.view_mut(), - weight_array_tfb.view_mut(), - flag_array_tfb.view_mut(), - &vis_sel, - )?; + prep_ctx + .preprocess( + &self.mwalib_context, + jones_array_tfb.view_mut(), + weight_array_tfb.view_mut(), + flag_array_tfb.view_mut(), + &vis_sel, + ) + .map_err(Box::new)?; } // Convert the data array into a vector so we can use `chunks_exact` @@ -729,19 +741,18 @@ impl RawDataReader { tile_baseline_flags, }) = crosses { - data_vis_fb - .chunks_exact(metafits_context.num_baselines) - .zip_eq(data_weights_fb.chunks_exact(metafits_context.num_baselines)) - .enumerate() + (0..) + .zip(data_vis_fb.chunks_exact(metafits_context.num_baselines)) + .zip(data_weights_fb.chunks_exact(metafits_context.num_baselines)) // Let only unflagged channels proceed. - .filter(|(i_chan, _)| !flagged_fine_chans.contains(i_chan)) + .filter(|((i_chan, _), _)| !flagged_fine_chans.contains(i_chan)) // Discard the channel index and then zip with the outgoing // array. - .map(|(_, data)| data) - .zip_eq(vis_fb.outer_iter_mut()) - .zip_eq(weights_fb.outer_iter_mut()) + .map(|((_, data), weights)| (data, weights)) + .zip(vis_fb.outer_iter_mut()) + .zip(weights_fb.outer_iter_mut()) .for_each( - |(((data_vis_b, data_weights_b), mut vis_b), mut weights_b)| { + |(((data_vis_b, data_weights_b), mut vis_b), mut weight_b)| { data_vis_b .iter() .zip_eq(data_weights_b) @@ -757,7 +768,7 @@ impl RawDataReader { // Discard the baseline index and then zip with the outgoing array. .map(|(_, data)| data) .zip_eq(vis_b.iter_mut()) - .zip_eq(weights_b.iter_mut()) + .zip_eq(weight_b.iter_mut()) .for_each(|(((data_vis, data_weight), vis), weight)| { *vis = *data_vis; *weight = *data_weight; @@ -774,22 +785,21 @@ impl RawDataReader { tile_baseline_flags, }) = autos { - data_vis_fb - .chunks_exact(metafits_context.num_baselines) + (0..) + .zip(data_vis_fb.chunks_exact(metafits_context.num_baselines)) .zip(data_weights_fb.chunks_exact(metafits_context.num_baselines)) - .enumerate() // Let only unflagged channels proceed. - .filter(|(i_chan, _)| !flagged_fine_chans.contains(i_chan)) + .filter(|((i_chan, _), _)| !flagged_fine_chans.contains(i_chan)) // Discard the channel index and then zip with the outgoing // array. 
- .map(|(_, data)| data) + .map(|((_, data), weights)| (data, weights)) .zip_eq(vis_fb.outer_iter_mut()) .zip_eq(weights_fb.outer_iter_mut()) .for_each( - |(((data_vis_b, data_weights_b), mut vis_b), mut weights_b)| { + |(((data_vis_b, data_weight_b), mut vis_b), mut weights_b)| { data_vis_b .iter() - .zip(data_weights_b) + .zip(data_weight_b) .enumerate() // Let only unflagged autos proceed. .filter(|(i_baseline, _)| { @@ -831,6 +841,14 @@ impl VisRead for RawDataReader { self.mwaf_flags.as_ref() } + fn get_raw_data_corrections(&self) -> Option { + Some(self.corrections) + } + + fn set_raw_data_corrections(&mut self, corrections: RawDataCorrections) { + self.corrections = corrections; + } + fn read_crosses_and_autos( &self, cross_vis_fb: ArrayViewMut2>, @@ -839,7 +857,7 @@ impl VisRead for RawDataReader { auto_weights_fb: ArrayViewMut2, timestep: usize, tile_baseline_flags: &TileBaselineFlags, - flagged_fine_chans: &HashSet, + flagged_fine_chans: &HashSet, ) -> Result<(), VisReadError> { self.read_inner( Some(CrossData { @@ -854,7 +872,8 @@ impl VisRead for RawDataReader { }), timestep, flagged_fine_chans, - ) + )?; + Ok(()) } fn read_crosses( @@ -863,7 +882,7 @@ impl VisRead for RawDataReader { weights_fb: ArrayViewMut2, timestep: usize, tile_baseline_flags: &TileBaselineFlags, - flagged_fine_chans: &HashSet, + flagged_fine_chans: &HashSet, ) -> Result<(), VisReadError> { self.read_inner( Some(CrossData { @@ -874,7 +893,8 @@ impl VisRead for RawDataReader { None, timestep, flagged_fine_chans, - ) + )?; + Ok(()) } fn read_autos( @@ -883,7 +903,7 @@ impl VisRead for RawDataReader { weights_fb: ArrayViewMut2, timestep: usize, tile_baseline_flags: &TileBaselineFlags, - flagged_fine_chans: &HashSet, + flagged_fine_chans: &HashSet, ) -> Result<(), VisReadError> { self.read_inner( None, @@ -894,7 +914,8 @@ impl VisRead for RawDataReader { }), timestep, flagged_fine_chans, - ) + )?; + Ok(()) } fn get_marlu_mwa_info(&self) -> Option { @@ -906,21 +927,27 @@ impl VisRead for RawDataReader { fn get_80khz_fine_chan_flags_per_coarse_chan( fine_chan_width: u32, - num_fine_chans_per_coarse_chan: usize, + num_fine_chans_per_coarse_chan: NonZeroU16, is_mwax: bool, -) -> Vec { +) -> Vec { let mut flags = vec![]; // Any fractional parts are discarded, meaning e.g. if the resolution was // 79kHz per channel, only 1 edge channel is flagged rather than 2. - let num_flagged_fine_chans_per_edge = (80000 / fine_chan_width) as usize; + let num_flagged_fine_chans_per_edge = (80000 / fine_chan_width) + .try_into() + .expect("smaller than u16::MAX"); for i in 0..num_flagged_fine_chans_per_edge { flags.push(i); - flags.push(num_fine_chans_per_coarse_chan - 1 - i); + flags.push( + (num_fine_chans_per_coarse_chan.get() - 1) + .checked_sub(i) + .expect("algorithm is sound"), + ); } // Also put the centre channel in if this isn't an MWAX obs. if !is_mwax { - flags.push(num_fine_chans_per_coarse_chan / 2); + flags.push(num_fine_chans_per_coarse_chan.get() / 2); } flags.sort_unstable(); flags diff --git a/src/io/read/raw/tests.rs b/src/io/read/raw/tests.rs index bb7817d7..eec0ae93 100644 --- a/src/io/read/raw/tests.rs +++ b/src/io/read/raw/tests.rs @@ -4,17 +4,22 @@ //! Tests for reading from raw MWA files. 
+use std::{collections::HashSet, num::NonZeroU16}; + use approx::{abs_diff_eq, assert_abs_diff_eq, assert_abs_diff_ne}; use itertools::Itertools; -use marlu::c32; +use marlu::{c32, Jones}; use ndarray::prelude::*; -use tempfile::{tempdir, TempDir}; +use tempfile::TempDir; -use super::*; +use super::{get_80khz_fine_chan_flags_per_coarse_chan, RawDataReader}; use crate::{ - cli::di_calibrate::DiCalArgs, - di_calibrate::{get_cal_vis, tests::test_1090008640_quality}, - tests::{deflate_gz_into_file, reduced_obsids::get_reduced_1090008640}, + io::read::{ + pfb_gains::{PfbFlavour, EMPIRICAL_40KHZ, LEVINE_40KHZ}, + RawDataCorrections, VisRead, + }, + math::TileBaselineFlags, + tests::{deflate_gz_into_file, get_reduced_1090008640_raw_pbs, DataAsPathBufs}, }; struct CrossData { @@ -27,28 +32,27 @@ struct AutoData { weights_array: Array2<f32>, } -fn get_cross_vis(args: DiCalArgs) -> CrossData { - let result = args.into_params(); - let params = match result { - Ok(p) => p, - Err(e) => panic!("{}", e), - }; - - let num_unflagged_cross_baselines = params - .tile_baseline_flags +fn get_cross_vis( + raw_reader: &RawDataReader, + tile_baseline_flags: &TileBaselineFlags, + flagged_fine_chans: &HashSet<u16>, +) -> CrossData { + let obs_context = raw_reader.get_obs_context(); + let num_unflagged_cross_baselines = tile_baseline_flags .tile_to_unflagged_cross_baseline_map .len(); - let num_unflagged_fine_chans = params.unflagged_fine_chan_freqs.len(); + let num_unflagged_fine_chans = obs_context.fine_chan_freqs.len() - flagged_fine_chans.len(); let vis_shape = (num_unflagged_fine_chans, num_unflagged_cross_baselines); + let mut data_array = Array2::zeros(vis_shape); let mut weights_array = Array2::zeros(vis_shape); - let result = params.input_data.read_crosses( + let result = raw_reader.read_crosses( data_array.view_mut(), weights_array.view_mut(), - *params.timesteps.first(), - &params.tile_baseline_flags, - &params.flagged_fine_chans, + *obs_context.all_timesteps.first(), + tile_baseline_flags, + flagged_fine_chans, ); assert!(result.is_ok(), "{}", result.unwrap_err()); result.unwrap(); @@ -59,25 +63,25 @@ fn get_cross_vis(args: DiCalArgs) -> CrossData { } } -fn get_auto_vis(args: DiCalArgs) -> AutoData { - let result = args.into_params(); - let params = match result { - Ok(p) => p, - Err(e) => panic!("{}", e), - }; - - let num_unflagged_tiles = params.unflagged_tile_xyzs.len(); - let num_unflagged_fine_chans = params.unflagged_fine_chan_freqs.len(); +fn get_auto_vis( + raw_reader: &RawDataReader, + tile_baseline_flags: &TileBaselineFlags, + flagged_fine_chans: &HashSet<u16>, +) -> AutoData { + let obs_context = raw_reader.get_obs_context(); + let num_unflagged_tiles = obs_context.tile_xyzs.len() - tile_baseline_flags.flagged_tiles.len(); + let num_unflagged_fine_chans = obs_context.fine_chan_freqs.len() - flagged_fine_chans.len(); + let vis_shape = (num_unflagged_fine_chans, num_unflagged_tiles); + let mut data_array = Array2::zeros(vis_shape); let mut weights_array = Array2::zeros(vis_shape); - let result = params.input_data.read_autos( + let result = raw_reader.read_autos( data_array.view_mut(), weights_array.view_mut(), - *params.timesteps.first(), - &params.tile_baseline_flags, - &params.flagged_fine_chans, + *obs_context.all_timesteps.first(), + tile_baseline_flags, + flagged_fine_chans, ); assert!(result.is_ok(), "{}", result.unwrap_err()); result.unwrap(); @@ -88,39 +92,38 @@ fn get_auto_vis(args: DiCalArgs) -> AutoData { } } -fn get_cross_and_auto_vis(args: DiCalArgs) -> (CrossData, AutoData) { - let result = args.into_params(); -
let params = match result { - Ok(p) => p, - Err(e) => panic!("{}", e), - }; - - let num_unflagged_cross_baselines = params - .tile_baseline_flags +fn get_cross_and_auto_vis( + raw_reader: &RawDataReader, + tile_baseline_flags: &TileBaselineFlags, + flagged_fine_chans: &HashSet<u16>, +) -> (CrossData, AutoData) { + let obs_context = raw_reader.get_obs_context(); + let num_unflagged_cross_baselines = tile_baseline_flags .tile_to_unflagged_cross_baseline_map .len(); - let num_unflagged_fine_chans = params.unflagged_fine_chan_freqs.len(); + let num_unflagged_tiles = obs_context.tile_xyzs.len() - tile_baseline_flags.flagged_tiles.len(); + let num_unflagged_fine_chans = obs_context.fine_chan_freqs.len() - flagged_fine_chans.len(); + let vis_shape = (num_unflagged_fine_chans, num_unflagged_cross_baselines); let mut cross_data = CrossData { data_array: Array2::zeros(vis_shape), weights_array: Array2::zeros(vis_shape), }; - let num_unflagged_tiles = params.unflagged_tile_xyzs.len(); let vis_shape = (num_unflagged_fine_chans, num_unflagged_tiles); let mut auto_data = AutoData { data_array: Array2::zeros(vis_shape), weights_array: Array2::zeros(vis_shape), }; - let result = params.input_data.read_crosses_and_autos( + let result = raw_reader.read_crosses_and_autos( cross_data.data_array.view_mut(), cross_data.weights_array.view_mut(), auto_data.data_array.view_mut(), auto_data.weights_array.view_mut(), - *params.timesteps.first(), - &params.tile_baseline_flags, - &params.flagged_fine_chans, + *obs_context.all_timesteps.first(), + tile_baseline_flags, + flagged_fine_chans, ); assert!(result.is_ok(), "{}", result.unwrap_err()); result.unwrap(); @@ -132,16 +135,22 @@ fn get_cross_and_auto_vis(args: DiCalArgs) -> (CrossData, AutoData) { fn read_1090008640_cross_vis() { // Other tests check that PFB gains and digital gains are correctly applied. // These simple _vis tests just check that the values are right. - let mut args = get_reduced_1090008640(false, false); - args.pfb_flavour = Some("none".to_string()); - args.no_cable_length_correction = true; - args.no_geometric_correction = true; - args.no_digital_gains = true; - args.ignore_input_data_fine_channel_flags = true; + let DataAsPathBufs { metafits, vis, .. } = get_reduced_1090008640_raw_pbs(); + let corrections = RawDataCorrections { + pfb_flavour: PfbFlavour::None, + digital_gains: false, + cable_length: false, + geometric: false, + }; + let raw_reader = RawDataReader::new(&metafits, &vis, None, corrections, None).unwrap(); + let obs_context = raw_reader.get_obs_context(); + let tile_baseline_flags = + TileBaselineFlags::new(obs_context.get_total_num_tiles(), HashSet::new()); + let CrossData { data_array: vis, weights_array: weights, - } = get_cross_vis(args); + } = get_cross_vis(&raw_reader, &tile_baseline_flags, &HashSet::new()); assert_abs_diff_eq!( vis[(0, 0)], @@ -168,13 +177,24 @@ fn read_1090008640_cross_vis() { // Test the visibility values with corrections applied (except PFB gains). #[test] fn read_1090008640_cross_vis_with_corrections() { - let mut args = get_reduced_1090008640(false, false); - args.pfb_flavour = Some("none".to_string()); - args.ignore_input_data_fine_channel_flags = true; + let DataAsPathBufs { metafits, vis, ..
} = get_reduced_1090008640_raw_pbs(); + let corrections = RawDataCorrections { + pfb_flavour: PfbFlavour::None, + digital_gains: true, + cable_length: true, + geometric: true, + }; + let raw_reader = RawDataReader::new(&metafits, &vis, None, corrections, None).unwrap(); + let obs_context = raw_reader.get_obs_context(); + let tile_baseline_flags = TileBaselineFlags::new( + obs_context.get_total_num_tiles(), + obs_context.flagged_tiles.iter().copied().collect(), + ); + let CrossData { data_array: vis, weights_array: weights, - } = get_cross_vis(args); + } = get_cross_vis(&raw_reader, &tile_baseline_flags, &HashSet::new()); assert_abs_diff_eq!( vis[(0, 0)], @@ -200,16 +220,24 @@ fn read_1090008640_cross_vis_with_corrections() { #[test] fn read_1090008640_auto_vis() { - let mut args = get_reduced_1090008640(false, false); - args.pfb_flavour = Some("none".to_string()); - args.no_cable_length_correction = true; - args.no_geometric_correction = true; - args.no_digital_gains = true; - args.ignore_input_data_fine_channel_flags = true; + let DataAsPathBufs { metafits, vis, .. } = get_reduced_1090008640_raw_pbs(); + let corrections = RawDataCorrections { + pfb_flavour: PfbFlavour::None, + digital_gains: false, + cable_length: false, + geometric: false, + }; + let raw_reader = RawDataReader::new(&metafits, &vis, None, corrections, None).unwrap(); + let obs_context = raw_reader.get_obs_context(); + let tile_baseline_flags = TileBaselineFlags::new( + obs_context.get_total_num_tiles(), + obs_context.flagged_tiles.iter().copied().collect(), + ); + let AutoData { data_array: vis, weights_array: weights, - } = get_auto_vis(args); + } = get_auto_vis(&raw_reader, &tile_baseline_flags, &HashSet::new()); assert_abs_diff_eq!( vis[(0, 0)], @@ -253,18 +281,29 @@ fn read_1090008640_auto_vis() { #[test] fn read_1090008640_auto_vis_with_flags() { - let mut args = get_reduced_1090008640(false, false); - args.pfb_flavour = Some("none".to_string()); - args.no_cable_length_correction = true; - args.no_geometric_correction = true; - args.no_digital_gains = true; - args.ignore_input_data_fine_channel_flags = true; - args.tile_flags = Some(vec!["1".to_string(), "9".to_string()]); - args.fine_chan_flags = Some(vec![1]); + let DataAsPathBufs { metafits, vis, .. } = get_reduced_1090008640_raw_pbs(); + let corrections = RawDataCorrections { + pfb_flavour: PfbFlavour::None, + digital_gains: false, + cable_length: false, + geometric: false, + }; + let raw_reader = RawDataReader::new(&metafits, &vis, None, corrections, None).unwrap(); + let obs_context = raw_reader.get_obs_context(); + let tile_baseline_flags = TileBaselineFlags::new( + obs_context.get_total_num_tiles(), + obs_context + .flagged_tiles + .iter() + .copied() + .chain([1, 9]) + .collect(), + ); + let AutoData { data_array: vis, weights_array: weights, - } = get_auto_vis(args); + } = get_auto_vis(&raw_reader, &tile_baseline_flags, &HashSet::from([1])); // Use the same values as the test above, adjusting only the indices. assert_abs_diff_eq!( @@ -312,13 +351,22 @@ fn read_1090008640_auto_vis_with_flags() { #[test] fn read_1090008640_cross_and_auto_vis() { - let mut args = get_reduced_1090008640(false, false); - args.pfb_flavour = Some("none".to_string()); - args.no_cable_length_correction = true; - args.no_geometric_correction = true; - args.no_digital_gains = true; - args.ignore_input_data_fine_channel_flags = true; - let (cross_data, auto_data) = get_cross_and_auto_vis(args); + let DataAsPathBufs { metafits, vis, .. 
} = get_reduced_1090008640_raw_pbs(); + let corrections = RawDataCorrections { + pfb_flavour: PfbFlavour::None, + digital_gains: false, + cable_length: false, + geometric: false, + }; + let raw_reader = RawDataReader::new(&metafits, &vis, None, corrections, None).unwrap(); + let obs_context = raw_reader.get_obs_context(); + let tile_baseline_flags = TileBaselineFlags::new( + obs_context.get_total_num_tiles(), + obs_context.flagged_tiles.iter().copied().collect(), + ); + + let (cross_data, auto_data) = + get_cross_and_auto_vis(&raw_reader, &tile_baseline_flags, &HashSet::new()); // Test values should match those used in "cross_vis" and "auto_vis" tests; assert_abs_diff_eq!( @@ -371,27 +419,43 @@ fn read_1090008640_cross_and_auto_vis() { #[test] fn pfb_empirical_gains() { - let mut args = get_reduced_1090008640(false, false); - args.pfb_flavour = Some("empirical".to_string()); - args.ignore_input_data_fine_channel_flags = true; + let DataAsPathBufs { metafits, vis, .. } = get_reduced_1090008640_raw_pbs(); + let corrections = RawDataCorrections { + pfb_flavour: PfbFlavour::Empirical, + digital_gains: true, + cable_length: true, + geometric: true, + }; + let raw_reader_with_pfb = RawDataReader::new(&metafits, &vis, None, corrections, None).unwrap(); + let obs_context = raw_reader_with_pfb.get_obs_context(); + let tile_baseline_flags = TileBaselineFlags::new( + raw_reader_with_pfb.get_obs_context().get_total_num_tiles(), + obs_context.flagged_tiles.iter().copied().collect(), + ); + let CrossData { data_array: vis_pfb, weights_array: weights_pfb, - } = get_cross_vis(args); + } = get_cross_vis(&raw_reader_with_pfb, &tile_baseline_flags, &HashSet::new()); + + let corrections = RawDataCorrections { + pfb_flavour: PfbFlavour::None, + digital_gains: true, + cable_length: true, + geometric: true, + }; + let raw_reader = RawDataReader::new(&metafits, &vis, None, corrections, None).unwrap(); - let mut args = get_reduced_1090008640(false, false); - args.pfb_flavour = Some("none".to_string()); - args.ignore_input_data_fine_channel_flags = true; let CrossData { data_array: vis_no_pfb, weights_array: weights_no_pfb, - } = get_cross_vis(args); + } = get_cross_vis(&raw_reader, &tile_baseline_flags, &HashSet::new()); // Test each visibility individually. vis_pfb .outer_iter() .zip_eq(vis_no_pfb.outer_iter()) - .zip_eq(pfb_gains::EMPIRICAL_40KHZ.iter()) + .zip_eq(EMPIRICAL_40KHZ.iter()) .for_each(|((vis_pfb, vis_no_pfb), &gain)| { vis_pfb .iter() @@ -412,29 +476,43 @@ fn pfb_empirical_gains() { #[test] fn pfb_levine_gains() { - let mut args = get_reduced_1090008640(false, false); - args.pfb_flavour = Some("levine".to_string()); - args.no_digital_gains = true; - args.ignore_input_data_fine_channel_flags = true; + let DataAsPathBufs { metafits, vis, .. 
} = get_reduced_1090008640_raw_pbs(); + let corrections = RawDataCorrections { + pfb_flavour: PfbFlavour::Levine, + digital_gains: false, + cable_length: true, + geometric: true, + }; + let raw_reader_with_pfb = RawDataReader::new(&metafits, &vis, None, corrections, None).unwrap(); + let obs_context = raw_reader_with_pfb.get_obs_context(); + let tile_baseline_flags = TileBaselineFlags::new( + raw_reader_with_pfb.get_obs_context().get_total_num_tiles(), + obs_context.flagged_tiles.iter().copied().collect(), + ); + let CrossData { data_array: vis_pfb, weights_array: weights_pfb, - } = get_cross_vis(args); + } = get_cross_vis(&raw_reader_with_pfb, &tile_baseline_flags, &HashSet::new()); + + let corrections = RawDataCorrections { + pfb_flavour: PfbFlavour::None, + digital_gains: false, + cable_length: true, + geometric: true, + }; + let raw_reader = RawDataReader::new(&metafits, &vis, None, corrections, None).unwrap(); - let mut args = get_reduced_1090008640(false, false); - args.pfb_flavour = Some("none".to_string()); - args.no_digital_gains = true; - args.ignore_input_data_fine_channel_flags = true; let CrossData { data_array: vis_no_pfb, weights_array: weights_no_pfb, - } = get_cross_vis(args); + } = get_cross_vis(&raw_reader, &tile_baseline_flags, &HashSet::new()); // Test each visibility individually. vis_pfb .outer_iter() .zip_eq(vis_no_pfb.outer_iter()) - .zip(pfb_gains::LEVINE_40KHZ.iter()) + .zip(LEVINE_40KHZ.iter()) .for_each(|((vis_pfb, vis_no_pfb), &gain)| { vis_pfb .iter() @@ -455,23 +533,37 @@ fn pfb_levine_gains() { #[test] fn test_digital_gains() { - let mut args = get_reduced_1090008640(false, false); - // Some("none") turns off the PFB correction, whereas None would be the - // default behaviour (apply PFB correction). - args.pfb_flavour = Some("none".to_string()); - args.no_digital_gains = false; + let DataAsPathBufs { metafits, vis, .. } = get_reduced_1090008640_raw_pbs(); + let corrections = RawDataCorrections { + pfb_flavour: PfbFlavour::None, + digital_gains: true, + cable_length: true, + geometric: true, + }; + let raw_reader = RawDataReader::new(&metafits, &vis, None, corrections, None).unwrap(); + let obs_context = raw_reader.get_obs_context(); + let tile_baseline_flags = TileBaselineFlags::new( + obs_context.get_total_num_tiles(), + obs_context.flagged_tiles.iter().copied().collect(), + ); + let CrossData { data_array: vis_dg, weights_array: weights_dg, - } = get_cross_vis(args); + } = get_cross_vis(&raw_reader, &tile_baseline_flags, &HashSet::new()); + + let corrections = RawDataCorrections { + pfb_flavour: PfbFlavour::None, + digital_gains: false, + cable_length: true, + geometric: true, + }; + let raw_reader = RawDataReader::new(&metafits, &vis, None, corrections, None).unwrap(); - let mut args = get_reduced_1090008640(false, false); - args.pfb_flavour = Some("none".to_string()); - args.no_digital_gains = true; let CrossData { data_array: vis_no_dg, weights_array: weights_no_dg, - } = get_cross_vis(args); + } = get_cross_vis(&raw_reader, &tile_baseline_flags, &HashSet::new()); let i_bl = 0; // Promote the Jones matrices for better accuracy. @@ -498,26 +590,31 @@ fn test_digital_gains() { #[test] fn test_mwaf_flags() { // First test without any mwaf flags. 
- let mut args = get_reduced_1090008640(false, false); - args.ignore_input_data_fine_channel_flags = true; - args.ignore_input_data_tile_flags = true; - args.pfb_flavour = Some("none".to_string()); - args.no_digital_gains = false; - - let result = args.into_params(); - let params = match result { - Ok(p) => p, - Err(e) => panic!("{}", e), + let DataAsPathBufs { + metafits, + vis, + mwafs, + .. + } = get_reduced_1090008640_raw_pbs(); + let corrections = RawDataCorrections { + pfb_flavour: PfbFlavour::None, + digital_gains: true, + cable_length: true, + geometric: true, }; - let timesteps = params.timesteps; + let raw_reader = RawDataReader::new(&metafits, &vis, None, corrections, None).unwrap(); + let obs_context = raw_reader.get_obs_context(); + let flagged_fine_chans = HashSet::new(); + let num_tiles = obs_context.get_total_num_tiles(); + let tile_baseline_flags = TileBaselineFlags::new(num_tiles, HashSet::new()); // Set up our arrays for reading. - let num_unflagged_cross_baselines = params - .tile_baseline_flags + let num_unflagged_cross_baselines = tile_baseline_flags .tile_to_unflagged_cross_baseline_map .len(); - let num_unflagged_tiles = params.unflagged_tile_xyzs.len(); - let num_unflagged_fine_chans = params.unflagged_fine_chan_freqs.len(); + let num_unflagged_tiles = obs_context.tile_xyzs.len() - tile_baseline_flags.flagged_tiles.len(); + let num_unflagged_fine_chans = obs_context.fine_chan_freqs.len() - flagged_fine_chans.len(); + let cross_vis_shape = (num_unflagged_fine_chans, num_unflagged_cross_baselines); let mut cross_data_array = Array2::from_elem(cross_vis_shape, Jones::identity()); let mut cross_weights_array = Array2::ones(cross_vis_shape); @@ -525,44 +622,34 @@ fn test_mwaf_flags() { let mut auto_data_array = Array2::from_elem(auto_vis_shape, Jones::identity()); let mut auto_weights_array = Array2::ones(auto_vis_shape); - let result = params.input_data.read_crosses_and_autos( + let result = raw_reader.read_crosses_and_autos( cross_data_array.view_mut(), cross_weights_array.view_mut(), auto_data_array.view_mut(), auto_weights_array.view_mut(), - *timesteps.first(), - ¶ms.tile_baseline_flags, - ¶ms.flagged_fine_chans, + *obs_context.all_timesteps.first(), + &tile_baseline_flags, + &flagged_fine_chans, ); assert!(result.is_ok(), "{}", result.unwrap_err()); result.unwrap(); // Now use the flags from our doctored mwaf file. 
- let mut args = get_reduced_1090008640(false, true); - args.ignore_input_data_fine_channel_flags = true; - args.ignore_input_data_tile_flags = true; - args.pfb_flavour = Some("none".to_string()); - args.no_digital_gains = false; - - let result = args.into_params(); - let params = match result { - Ok(p) => p, - Err(e) => panic!("{}", e), - }; + let raw_reader = RawDataReader::new(&metafits, &vis, Some(&mwafs), corrections, None).unwrap(); let mut flagged_cross_data_array = Array2::from_elem(cross_vis_shape, Jones::identity()); let mut flagged_cross_weights_array = Array2::ones(cross_vis_shape); let mut flagged_auto_data_array = Array2::from_elem(auto_vis_shape, Jones::identity()); let mut flagged_auto_weights_array = Array2::ones(auto_vis_shape); - let result = params.input_data.read_crosses_and_autos( + let result = raw_reader.read_crosses_and_autos( flagged_cross_data_array.view_mut(), flagged_cross_weights_array.view_mut(), flagged_auto_data_array.view_mut(), flagged_auto_weights_array.view_mut(), - *timesteps.first(), - ¶ms.tile_baseline_flags, - ¶ms.flagged_fine_chans, + *obs_context.all_timesteps.first(), + &tile_baseline_flags, + &flagged_fine_chans, ); assert!(result.is_ok(), "{}", result.unwrap_err()); result.unwrap(); @@ -574,17 +661,15 @@ fn test_mwaf_flags() { assert_eq!(auto_weights_array, flagged_auto_weights_array); // Iterate over the weight arrays, checking for flags. - let num_tiles = params.get_total_num_tiles(); let num_bls = (num_tiles * (num_tiles + 1)) / 2; - let num_freqs = params.get_obs_context().fine_chan_freqs.len(); + let num_freqs = obs_context.fine_chan_freqs.len(); // Unfortunately we have to conditionally select either the auto or cross // visibilities. for i_chan in 0..num_freqs { let mut i_auto = 0; let mut i_cross = 0; for i_bl in 0..num_bls { - let (tile1, tile2) = - marlu::math::baseline_to_tiles(params.unflagged_tile_xyzs.len(), i_bl); + let (tile1, tile2) = marlu::math::baseline_to_tiles(num_tiles, i_bl); let weight = if tile1 == tile2 { i_auto += 1; @@ -608,26 +693,27 @@ fn test_mwaf_flags() { #[test] fn test_mwaf_flags_primes() { // First test without any mwaf flags. - let mut args = get_reduced_1090008640(false, false); - args.ignore_input_data_fine_channel_flags = true; - args.ignore_input_data_tile_flags = true; - args.pfb_flavour = Some("none".to_string()); - args.no_digital_gains = false; - - let result = args.into_params(); - let params = match result { - Ok(p) => p, - Err(e) => panic!("{}", e), + let DataAsPathBufs { metafits, vis, .. } = get_reduced_1090008640_raw_pbs(); + let corrections = RawDataCorrections { + pfb_flavour: PfbFlavour::None, + digital_gains: true, + cable_length: false, + geometric: false, }; - let timesteps = params.timesteps; + let raw_reader = RawDataReader::new(&metafits, &vis, None, corrections, None).unwrap(); + let obs_context = raw_reader.get_obs_context(); + let timesteps = &obs_context.all_timesteps; + let total_num_tiles = obs_context.get_total_num_tiles(); + let num_unflagged_tiles = total_num_tiles - obs_context.flagged_tiles.len(); + let num_unflagged_cross_baselines = (num_unflagged_tiles * (num_unflagged_tiles - 1)) / 2; + let tile_baseline_flags = TileBaselineFlags::new( + total_num_tiles, + obs_context.flagged_tiles.iter().copied().collect(), + ); + let flagged_fine_chans = HashSet::new(); + let num_unflagged_fine_chans = obs_context.fine_chan_freqs.len() - flagged_fine_chans.len(); // Set up our arrays for reading. 
- let num_unflagged_cross_baselines = params - .tile_baseline_flags - .tile_to_unflagged_cross_baseline_map - .len(); - let num_unflagged_tiles = params.unflagged_tile_xyzs.len(); - let num_unflagged_fine_chans = params.unflagged_fine_chan_freqs.len(); let cross_vis_shape = (num_unflagged_fine_chans, num_unflagged_cross_baselines); let mut cross_data_array = Array2::from_elem(cross_vis_shape, Jones::identity()); let mut cross_weights_array = Array2::ones(cross_vis_shape); @@ -635,52 +721,39 @@ fn test_mwaf_flags_primes() { let mut auto_data_array = Array2::from_elem(auto_vis_shape, Jones::identity()); let mut auto_weights_array = Array2::ones(auto_vis_shape); - let result = params.input_data.read_crosses_and_autos( + let result = raw_reader.read_crosses_and_autos( cross_data_array.view_mut(), cross_weights_array.view_mut(), auto_data_array.view_mut(), auto_weights_array.view_mut(), *timesteps.first(), - ¶ms.tile_baseline_flags, - ¶ms.flagged_fine_chans, + &tile_baseline_flags, + &flagged_fine_chans, ); assert!(result.is_ok(), "{}", result.unwrap_err()); result.unwrap(); // Now use the flags from our "primes" mwaf file. - let mut args = get_reduced_1090008640(false, false); - args.ignore_input_data_fine_channel_flags = true; - args.ignore_input_data_tile_flags = true; - args.pfb_flavour = Some("none".to_string()); - args.no_digital_gains = false; let temp_dir = TempDir::new().unwrap(); let mwaf_pb = temp_dir.path().join("primes.mwaf"); let mut mwaf_file = std::fs::File::create(&mwaf_pb).unwrap(); deflate_gz_into_file("test_files/1090008640/primes_01.mwaf.gz", &mut mwaf_file); - match &mut args.data { - Some(d) => d.push(mwaf_pb.display().to_string()), - None => unreachable!(), - } - - let result = args.into_params(); - let params = match result { - Ok(p) => p, - Err(e) => panic!("{}", e), - }; + let raw_reader = + RawDataReader::new(&metafits, &vis, Some(&[mwaf_pb]), corrections, None).unwrap(); let mut flagged_cross_data_array = Array2::from_elem(cross_vis_shape, Jones::identity()); let mut flagged_cross_weights_array = Array2::ones(cross_vis_shape); let mut flagged_auto_data_array = Array2::from_elem(auto_vis_shape, Jones::identity()); let mut flagged_auto_weights_array = Array2::ones(auto_vis_shape); - let result = params.input_data.read_crosses_and_autos( + let result = raw_reader.read_crosses_and_autos( flagged_cross_data_array.view_mut(), flagged_cross_weights_array.view_mut(), flagged_auto_data_array.view_mut(), flagged_auto_weights_array.view_mut(), *timesteps.first(), - ¶ms.tile_baseline_flags, - ¶ms.flagged_fine_chans, + &tile_baseline_flags, + &flagged_fine_chans, ); assert!(result.is_ok(), "{}", result.unwrap_err()); result.unwrap(); @@ -691,9 +764,9 @@ fn test_mwaf_flags_primes() { // Iterate over the arrays, where are the differences? They should be // primes. - let num_tiles = params.get_total_num_tiles(); + let num_tiles = total_num_tiles; let num_bls = (num_tiles * (num_tiles + 1)) / 2; - let num_freqs = params.get_obs_context().fine_chan_freqs.len(); + let num_freqs = obs_context.fine_chan_freqs.len(); // Unfortunately we have to conditionally select either the auto or cross // visibilities. for i_chan in 0..num_freqs { @@ -703,8 +776,7 @@ fn test_mwaf_flags_primes() { // This mwaf file was created with baselines moving slower than // frequencies. 
let is_prime = crate::math::is_prime(i_bl * num_freqs + i_chan); - let (tile1, tile2) = - marlu::math::baseline_to_tiles(params.unflagged_tile_xyzs.len(), i_bl); + let (tile1, tile2) = marlu::math::baseline_to_tiles(num_unflagged_tiles, i_bl); let weight = if tile1 == tile2 { i_auto += 1; @@ -724,38 +796,36 @@ fn test_mwaf_flags_primes() { /// Test that cotter flags are correctly (as possible) ingested. #[test] fn test_mwaf_flags_cotter() { - let mut args = get_reduced_1090008640(false, false); - args.ignore_input_data_fine_channel_flags = true; - args.ignore_input_data_tile_flags = true; - args.pfb_flavour = Some("none".to_string()); - args.no_digital_gains = false; + let DataAsPathBufs { metafits, vis, .. } = get_reduced_1090008640_raw_pbs(); + let corrections = RawDataCorrections { + pfb_flavour: PfbFlavour::None, + digital_gains: true, + cable_length: true, + geometric: true, + }; let temp_dir = TempDir::new().unwrap(); - let mwaf_pb = temp_dir.path().join("cotter.mwaf"); - let mut mwaf_file = std::fs::File::create(&mwaf_pb).unwrap(); + let mwafs = [temp_dir.path().join("cotter.mwaf")]; + let mut mwaf_file = std::fs::File::create(&mwafs[0]).unwrap(); deflate_gz_into_file( "test_files/1090008640/1090008640_01_cotter.mwaf.gz", &mut mwaf_file, ); - match &mut args.data { - Some(d) => d.push(mwaf_pb.display().to_string()), - None => unreachable!(), - } - let result = args.clone().into_params(); - let params = match result { - Ok(p) => p, - Err(e) => panic!("{}", e), - }; - let timesteps = ¶ms.timesteps; + let raw_reader = RawDataReader::new(&metafits, &vis, Some(&mwafs), corrections, None).unwrap(); + let obs_context = raw_reader.get_obs_context(); + let timesteps = &obs_context.all_timesteps; + let total_num_tiles = obs_context.get_total_num_tiles(); + let num_unflagged_tiles = total_num_tiles - obs_context.flagged_tiles.len(); + let num_unflagged_cross_baselines = (num_unflagged_tiles * (num_unflagged_tiles - 1)) / 2; + let tile_baseline_flags = TileBaselineFlags::new( + total_num_tiles, + obs_context.flagged_tiles.iter().copied().collect(), + ); + let flagged_fine_chans = HashSet::new(); + let num_unflagged_fine_chans = obs_context.fine_chan_freqs.len() - flagged_fine_chans.len(); // Set up our arrays for reading. - let num_unflagged_cross_baselines = params - .tile_baseline_flags - .tile_to_unflagged_cross_baseline_map - .len(); - let num_unflagged_tiles = params.unflagged_tile_xyzs.len(); - let num_unflagged_fine_chans = params.unflagged_fine_chan_freqs.len(); let cross_vis_shape = (num_unflagged_fine_chans, num_unflagged_cross_baselines); let mut cross_data_array = Array2::from_elem(cross_vis_shape, Jones::identity()); let mut cross_weights_array = Array2::ones(cross_vis_shape); @@ -763,30 +833,29 @@ fn test_mwaf_flags_cotter() { let mut auto_data_array = Array2::from_elem(auto_vis_shape, Jones::identity()); let mut auto_weights_array = Array2::ones(auto_vis_shape); - let result = params.input_data.read_crosses_and_autos( + let result = raw_reader.read_crosses_and_autos( cross_data_array.view_mut(), cross_weights_array.view_mut(), auto_data_array.view_mut(), auto_weights_array.view_mut(), *timesteps.first(), - ¶ms.tile_baseline_flags, - ¶ms.flagged_fine_chans, + &tile_baseline_flags, + &flagged_fine_chans, ); assert!(result.is_ok(), "{}", result.unwrap_err()); result.unwrap(); // Iterate over the weight arrays. 
- let num_tiles = params.get_total_num_tiles(); + let num_tiles = obs_context.get_total_num_tiles(); let num_bls = (num_tiles * (num_tiles + 1)) / 2; - let num_freqs = params.get_obs_context().fine_chan_freqs.len(); + let num_freqs = obs_context.fine_chan_freqs.len(); // Unfortunately we have to conditionally select either the auto or cross // visibilities. for i_chan in 0..num_freqs { let mut i_auto = 0; let mut i_cross = 0; for i_bl in 0..num_bls { - let (tile1, tile2) = - marlu::math::baseline_to_tiles(params.unflagged_tile_xyzs.len(), i_bl); + let (tile1, tile2) = marlu::math::baseline_to_tiles(num_unflagged_tiles, i_bl); let weight = if tile1 == tile2 { i_auto += 1; @@ -807,34 +876,21 @@ fn test_mwaf_flags_cotter() { } // Do it all again, but this time with the forward offset flags. - let mut mwaf_file = std::fs::File::create(&mwaf_pb).unwrap(); + let mut mwaf_file = std::fs::File::create(&mwafs[0]).unwrap(); deflate_gz_into_file( "test_files/1090008640/1090008640_01_cotter_offset_forwards.mwaf.gz", &mut mwaf_file, ); - match &mut args.data { - Some(d) => { - d.pop(); - d.push(mwaf_pb.display().to_string()) - } - None => unreachable!(), - } + let raw_reader = RawDataReader::new(&metafits, &vis, Some(&mwafs), corrections, None).unwrap(); - let result = args.clone().into_params(); - let params = match result { - Ok(p) => p, - Err(e) => panic!("{}", e), - }; - let timesteps = ¶ms.timesteps; - - let result = params.input_data.read_crosses_and_autos( + let result = raw_reader.read_crosses_and_autos( cross_data_array.view_mut(), cross_weights_array.view_mut(), auto_data_array.view_mut(), auto_weights_array.view_mut(), *timesteps.first(), - ¶ms.tile_baseline_flags, - ¶ms.flagged_fine_chans, + &tile_baseline_flags, + &flagged_fine_chans, ); assert!(result.is_ok(), "{}", result.unwrap_err()); result.unwrap(); @@ -843,8 +899,7 @@ fn test_mwaf_flags_cotter() { let mut i_auto = 0; let mut i_cross = 0; for i_bl in 0..num_bls { - let (tile1, tile2) = - marlu::math::baseline_to_tiles(params.unflagged_tile_xyzs.len(), i_bl); + let (tile1, tile2) = marlu::math::baseline_to_tiles(num_unflagged_tiles, i_bl); let weight = if tile1 == tile2 { i_auto += 1; @@ -865,34 +920,21 @@ fn test_mwaf_flags_cotter() { } // Finally the backward offset flags. 
- let mut mwaf_file = std::fs::File::create(&mwaf_pb).unwrap(); + let mut mwaf_file = std::fs::File::create(&mwafs[0]).unwrap(); deflate_gz_into_file( "test_files/1090008640/1090008640_01_cotter_offset_backwards.mwaf.gz", &mut mwaf_file, ); - match &mut args.data { - Some(d) => { - d.pop(); - d.push(mwaf_pb.display().to_string()) - } - None => unreachable!(), - } - - let result = args.clone().into_params(); - let params = match result { - Ok(p) => p, - Err(e) => panic!("{}", e), - }; - let timesteps = ¶ms.timesteps; + let raw_reader = RawDataReader::new(&metafits, &vis, Some(&mwafs), corrections, None).unwrap(); - let result = params.input_data.read_crosses_and_autos( + let result = raw_reader.read_crosses_and_autos( cross_data_array.view_mut(), cross_weights_array.view_mut(), auto_data_array.view_mut(), auto_weights_array.view_mut(), *timesteps.first(), - ¶ms.tile_baseline_flags, - ¶ms.flagged_fine_chans, + &tile_baseline_flags, + &flagged_fine_chans, ); assert!(result.is_ok(), "{}", result.unwrap_err()); result.unwrap(); @@ -901,8 +943,7 @@ fn test_mwaf_flags_cotter() { let mut i_auto = 0; let mut i_cross = 0; for i_bl in 0..num_bls { - let (tile1, tile2) = - marlu::math::baseline_to_tiles(params.unflagged_tile_xyzs.len(), i_bl); + let (tile1, tile2) = marlu::math::baseline_to_tiles(num_unflagged_tiles, i_bl); let weight = if tile1 == tile2 { i_auto += 1; @@ -926,58 +967,39 @@ fn test_mwaf_flags_cotter() { #[test] fn test_default_flags_per_coarse_chan() { assert_eq!( - get_80khz_fine_chan_flags_per_coarse_chan(10000, 128, true), + get_80khz_fine_chan_flags_per_coarse_chan(10000, NonZeroU16::new(128).unwrap(), true), &[0, 1, 2, 3, 4, 5, 6, 7, 120, 121, 122, 123, 124, 125, 126, 127] ); assert_eq!( - get_80khz_fine_chan_flags_per_coarse_chan(10000, 128, false), + get_80khz_fine_chan_flags_per_coarse_chan(10000, NonZeroU16::new(128).unwrap(), false), &[0, 1, 2, 3, 4, 5, 6, 7, 64, 120, 121, 122, 123, 124, 125, 126, 127] ); assert_eq!( - get_80khz_fine_chan_flags_per_coarse_chan(20000, 64, true), + get_80khz_fine_chan_flags_per_coarse_chan(20000, NonZeroU16::new(64).unwrap(), true), &[0, 1, 2, 3, 60, 61, 62, 63] ); assert_eq!( - get_80khz_fine_chan_flags_per_coarse_chan(20000, 64, false), + get_80khz_fine_chan_flags_per_coarse_chan(20000, NonZeroU16::new(64).unwrap(), false), &[0, 1, 2, 3, 32, 60, 61, 62, 63] ); assert_eq!( - get_80khz_fine_chan_flags_per_coarse_chan(40000, 32, true), + get_80khz_fine_chan_flags_per_coarse_chan(40000, NonZeroU16::new(32).unwrap(), true), &[0, 1, 30, 31] ); assert_eq!( - get_80khz_fine_chan_flags_per_coarse_chan(40000, 32, false), + get_80khz_fine_chan_flags_per_coarse_chan(40000, NonZeroU16::new(32).unwrap(), false), &[0, 1, 16, 30, 31] ); // Future proofing? 
assert_eq!( - get_80khz_fine_chan_flags_per_coarse_chan(7200, 100, true), + get_80khz_fine_chan_flags_per_coarse_chan(7200, NonZeroU16::new(100).unwrap(), true), &[0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 89, 90, 91, 92, 93, 94, 95, 96, 97, 98, 99] ); assert_eq!( - get_80khz_fine_chan_flags_per_coarse_chan(7200, 100, false), + get_80khz_fine_chan_flags_per_coarse_chan(7200, NonZeroU16::new(100).unwrap(), false), &[0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 50, 89, 90, 91, 92, 93, 94, 95, 96, 97, 98, 99] ); } - -#[test] -fn test_1090008640_calibration_quality() { - let mut args = get_reduced_1090008640(false, false); - let temp_dir = tempdir().expect("Couldn't make temp dir"); - args.outputs = Some(vec![temp_dir.path().join("hyp_sols.fits")]); - args.pfb_flavour = Some("none".to_string()); - // To be consistent with other data quality tests, add these flags. - args.fine_chan_flags = Some(vec![0, 1, 2, 16, 30, 31]); - - let result = args.into_params(); - let params = match result { - Ok(r) => r, - Err(e) => panic!("{}", e), - }; - - let cal_vis = get_cal_vis(¶ms, false).expect("Couldn't read data and generate a model"); - test_1090008640_quality(params, cal_vis); -} diff --git a/src/io/read/uvfits/mod.rs b/src/io/read/uvfits/mod.rs index ffd147d3..af67a6c9 100644 --- a/src/io/read/uvfits/mod.rs +++ b/src/io/read/uvfits/mod.rs @@ -16,7 +16,7 @@ pub(crate) use error::*; use std::{ borrow::Cow, collections::{HashMap, HashSet}, - num::NonZeroUsize, + num::NonZeroU16, os::raw::c_char, path::{Path, PathBuf}, }; @@ -35,6 +35,7 @@ use num_complex::Complex; use super::*; use crate::{ beam::Delays, + cli::Warn, context::{ObsContext, Polarisations}, io::read::{ fits::{ @@ -47,11 +48,11 @@ use crate::{ pub struct UvfitsReader { /// Observation metadata. - pub(super) obs_context: ObsContext, + obs_context: ObsContext, // uvfits-specific things follow. /// The path to the uvfits on disk. - pub(crate) uvfits: PathBuf, + uvfits: PathBuf, /// The uvfits-specific metadata, like which indices contain which /// parameters. @@ -139,7 +140,7 @@ impl UvfitsReader { let frame: Option = fits_get_optional_key(&mut uvfits_fptr, &antenna_table_hdu, "FRAME")?; if !matches!(frame.as_deref(), Some("ITRF")) { - warn!("Assuming that the uvfits antenna coordinate system is ITRF"); + "Assuming that the uvfits antenna coordinate system is ITRF".warn(); } // Because ARRAY{X,Y,Z} are defined to be the array position, the @@ -173,7 +174,7 @@ impl UvfitsReader { } if wrong_array_xyz { - warn!("It seems this uvfits file's antenna positions has been blessed by casacore. Unblessing."); + "It seems this uvfits file's antenna positions has been blessed by casacore. Unblessing.".warn(); // Get the supplied array position from the average tile // position. average_xyz.x /= tile_xyzs.len() as f64; @@ -310,8 +311,12 @@ impl UvfitsReader { dipole_gains = Some(gains2); } else { // We have no choice but to leave the order as is. - warn!("The uvfits antenna names are different to those supplied in the metafits."); - warn!("Dipole delays/gains may be incorrectly mapped to uvfits antennas."); + [ + "The uvfits antenna names are different to those supplied in the metafits." 
+ .into(), + "Dipole delays/gains may be incorrectly mapped to uvfits antennas.".into(), + ] + .warn(); dipole_delays = Some(Delays::Full(delays)); dipole_gains = Some(gains); } @@ -523,17 +528,10 @@ impl UvfitsReader { } }; - let num_coarse_chans = mwa_coarse_chan_nums.as_ref().map(|ccs| { - NonZeroUsize::new(ccs.len()) - .expect("length is always > 0 because collection cannot be empty") - }); - let num_fine_chans_per_coarse_chan = num_coarse_chans.and_then(|num_coarse_chans| { - let total_bandwidth_hz = - *fine_chan_freqs_f64.last() - *fine_chan_freqs_f64.first() + freq_res; - NonZeroUsize::new( - (total_bandwidth_hz / num_coarse_chans.get() as f64 / freq_res).round() as usize, - ) - }); + let num_fine_chans_per_coarse_chan = { + let n = (1.28e6 / freq_res).round() as u16; + Some(NonZeroU16::new(n).expect("is not 0")) + }; match ( mwa_coarse_chan_nums.as_ref(), @@ -643,7 +641,7 @@ impl UvfitsReader { *timestamps.first(), dut1, ) { - warn!("uvfits UVWs use the other baseline convention; will conjugate incoming visibilities"); + "uvfits UVWs use the other baseline convention; will conjugate incoming visibilities".warn(); true } else { false @@ -651,6 +649,7 @@ impl UvfitsReader { }; let obs_context = ObsContext { + input_data_type: VisInputType::Uvfits, obsid, timestamps, all_timesteps, @@ -658,7 +657,7 @@ impl UvfitsReader { phase_centre, pointing_centre, array_position, - _supplied_array_position: supplied_array_position, + supplied_array_position, dut1, tile_names, tile_xyzs, @@ -698,7 +697,7 @@ impl UvfitsReader { mut crosses: Option, mut autos: Option, timestep: usize, - flagged_fine_chans: &HashSet, + flagged_fine_chans: &HashSet, ) -> Result<(), VisReadError> { let row_range_start = timestep * self.step; let row_range_end = (timestep + 1) * self.step; @@ -709,7 +708,7 @@ impl UvfitsReader { let mut uvfits_vis: Vec = vec![0.0; self.metadata.num_fine_freq_chans * NUM_POLS * NUM_FLOATS_PER_POL]; let flags = (0..self.metadata.num_fine_freq_chans) - .map(|i_chan| flagged_fine_chans.contains(&i_chan)) + .map(|i_chan| flagged_fine_chans.contains(&(i_chan as u16))) .collect::>(); for row in row_range_start..row_range_end { // Read in the row's group parameters. 
@@ -918,7 +917,7 @@ impl UvfitsReader { uvfits_vis .chunks_exact(NUM_POLS * NUM_FLOATS_PER_POL) .enumerate() - .filter(|(i_chan, _)| !flagged_fine_chans.contains(i_chan)) + .filter(|(i_chan, _)| !flagged_fine_chans.contains(&(*i_chan as u16))) .zip(out_vis.iter_mut()) .zip(out_weights.iter_mut()) .for_each(|(((_, in_data), out_vis), out_weight)| { @@ -1039,6 +1038,12 @@ impl VisRead for UvfitsReader { None } + fn get_raw_data_corrections(&self) -> Option { + None + } + + fn set_raw_data_corrections(&mut self, _: RawDataCorrections) {} + fn read_crosses_and_autos( &self, cross_vis_fb: ArrayViewMut2>, @@ -1047,7 +1052,7 @@ impl VisRead for UvfitsReader { auto_weights_fb: ArrayViewMut2, timestep: usize, tile_baseline_flags: &TileBaselineFlags, - flagged_fine_chans: &HashSet, + flagged_fine_chans: &HashSet, ) -> Result<(), VisReadError> { let cross_data = Some(CrossData { vis_fb: cross_vis_fb, @@ -1084,7 +1089,7 @@ impl VisRead for UvfitsReader { weights_fb: ArrayViewMut2, timestep: usize, tile_baseline_flags: &TileBaselineFlags, - flagged_fine_chans: &HashSet, + flagged_fine_chans: &HashSet, ) -> Result<(), VisReadError> { let cross_data = Some(CrossData { vis_fb, @@ -1115,7 +1120,7 @@ impl VisRead for UvfitsReader { weights_fb: ArrayViewMut2, timestep: usize, tile_baseline_flags: &TileBaselineFlags, - flagged_fine_chans: &HashSet, + flagged_fine_chans: &HashSet, ) -> Result<(), VisReadError> { let auto_data = Some(AutoData { vis_fb, @@ -1371,14 +1376,15 @@ impl UvfitsMetadata { Some(key) => match key.parse::() { Ok(n) => { if n.abs() > f32::EPSILON { - warn!("{pzero}, corresponding to the second DATE, was not 0; ignoring it anyway") + format!("uvfits {pzero}, corresponding to the second DATE, was not 0; ignoring it anyway").warn() } } Err(std::num::ParseFloatError { .. }) => { - warn!("Could not parse {pzero} as a float") + format!("Could not parse uvfits {pzero} as a float").warn() } }, - None => warn!("{pzero} does not exist, corresponding to the second DATE"), + None => format!("uvfits {pzero} does not exist, corresponding to the second DATE") + .warn(), } } let jd_zero = jd_zero_str @@ -1399,7 +1405,7 @@ impl UvfitsMetadata { // Don't round if the value is 0. Sigh. 
if jd_zero.abs() < f64::EPSILON { - warn!("PZERO{} is supposed to be non-zero!", indices.date1); + format!("uvfits PZERO{} is supposed to be non-zero!", indices.date1).warn(); e } else { e.round(1.hours()) @@ -1654,49 +1660,50 @@ impl Indices { if u_index.is_none() { u_index = Some(ii) } else { - warn!("Found another UU key -- only using the first"); + "Found another uvfits UU key -- only using the first".warn(); } } "VV" => { if v_index.is_none() { v_index = Some(ii) } else { - warn!("Found another VV key -- only using the first"); + "Found another uvfits VV key -- only using the first".warn(); } } "WW" => { if w_index.is_none() { w_index = Some(ii) } else { - warn!("Found another WW key -- only using the first"); + "Found another uvfits WW key -- only using the first".warn(); } } "BASELINE" => { if baseline_index.is_none() { baseline_index = Some(ii) } else { - warn!("Found another BASELINE key -- only using the first"); + "Found another uvfits BASELINE key -- only using the first".warn(); } } "ANTENNA1" => { if antenna1_index.is_none() { antenna1_index = Some(ii) } else { - warn!("Found another ANTENNA1 key -- only using the first"); + "Found another uvfits ANTENNA1 key -- only using the first".warn(); } } "ANTENNA2" => { if antenna2_index.is_none() { antenna2_index = Some(ii) } else { - warn!("Found another ANTENNA1 key -- only using the first"); + "Found another uvfits ANTENNA1 key -- only using the first".warn(); } } "DATE" | "_DATE" => match (date1_index, date2_index) { (None, None) => date1_index = Some(ii), (Some(_), None) => date2_index = Some(ii), (Some(_), Some(_)) => { - warn!("Found more than 2 DATE/_DATE keys -- only using the first two") + "Found more than 2 uvfits DATE/_DATE keys -- only using the first two" + .warn() } (None, Some(_)) => unreachable!(), }, @@ -1718,7 +1725,7 @@ impl Indices { (Some(index), None, None) => BaselineOrAntennas::Baseline { index }, (None, Some(index1), Some(index2)) => BaselineOrAntennas::Antennas { index1, index2 }, (Some(index), Some(_), _) | (Some(index), _, Some(_)) => { - warn!("Found both BASELINE and ANTENNA keys; only using BASELINE"); + "Found both uvfits BASELINE and ANTENNA keys; only using BASELINE".warn(); BaselineOrAntennas::Baseline { index } } // These are not. diff --git a/src/io/read/uvfits/tests.rs b/src/io/read/uvfits/tests.rs index 97714cee..33877e39 100644 --- a/src/io/read/uvfits/tests.rs +++ b/src/io/read/uvfits/tests.rs @@ -2,21 +2,21 @@ // License, v. 2.0. If a copy of the MPL was not distributed with this // file, You can obtain one at http://mozilla.org/MPL/2.0/. 
-use std::{collections::HashSet, ffi::CString, path::PathBuf}; +use std::ffi::CString; use approx::{assert_abs_diff_eq, assert_abs_diff_ne}; use fitsio::errors::check_status as fits_check_status; use hifitime::Duration; use itertools::Itertools; -use marlu::{c32, LatLngHeight, RADec, UvfitsWriter, VisContext, VisWrite, XyzGeodetic}; -use ndarray::prelude::*; +use marlu::{Jones, UvfitsWriter, VisContext, VisWrite, XyzGeodetic}; +use num_complex::Complex32 as c32; use tempfile::{tempdir, NamedTempFile}; use super::*; use crate::{ - di_calibrate::{get_cal_vis, tests::test_1090008640_quality}, + io::read::{UvfitsReader, VisRead}, math::TileBaselineFlags, - tests::reduced_obsids::get_reduced_1090008640_uvfits, + tests::{get_reduced_1090008640_uvfits_pbs, DataAsPathBufs}, }; // TODO(dev): move these to Marlu @@ -239,20 +239,15 @@ fn uvfits_io_works_for_auto_correlations() { #[test] fn test_1090008640_cross_vis() { - let args = get_reduced_1090008640_uvfits(); - let uvfits_reader = if let [metafits, uvfits] = &args.data.unwrap()[..] { - match UvfitsReader::new(PathBuf::from(uvfits), Some(&PathBuf::from(metafits)), None) { - Ok(u) => u, - Err(e) => panic!("{}", e), - } - } else { - panic!("There weren't 2 elements in args.data"); - }; + let DataAsPathBufs { + metafits, mut vis, .. + } = get_reduced_1090008640_uvfits_pbs(); + let uvfits_reader = UvfitsReader::new(vis.swap_remove(0), Some(&metafits), None).unwrap(); - let obs_context = &uvfits_reader.obs_context; + let obs_context = uvfits_reader.get_obs_context(); let total_num_tiles = obs_context.tile_xyzs.len(); let num_baselines = (total_num_tiles * (total_num_tiles - 1)) / 2; - let num_chans = obs_context.num_fine_chans_per_coarse_chan.unwrap().get(); + let num_chans = usize::from(obs_context.num_fine_chans_per_coarse_chan.unwrap().get()); let tile_baseline_flags = TileBaselineFlags::new(total_num_tiles, HashSet::new()); assert_abs_diff_eq!( @@ -303,19 +298,14 @@ fn test_1090008640_cross_vis() { #[test] fn test_1090008640_auto_vis() { - let args = get_reduced_1090008640_uvfits(); - let uvfits_reader = if let [metafits, uvfits] = &args.data.unwrap()[..] { - match UvfitsReader::new(PathBuf::from(uvfits), Some(&PathBuf::from(metafits)), None) { - Ok(u) => u, - Err(e) => panic!("{}", e), - } - } else { - panic!("There weren't 2 elements in args.data"); - }; + let DataAsPathBufs { + metafits, mut vis, .. + } = get_reduced_1090008640_uvfits_pbs(); + let uvfits_reader = UvfitsReader::new(vis.swap_remove(0), Some(&metafits), None).unwrap(); - let obs_context = &uvfits_reader.obs_context; + let obs_context = uvfits_reader.get_obs_context(); let total_num_tiles = obs_context.get_total_num_tiles(); - let num_chans = obs_context.num_fine_chans_per_coarse_chan.unwrap().get(); + let num_chans = usize::from(obs_context.num_fine_chans_per_coarse_chan.unwrap().get()); let tile_baseline_flags = TileBaselineFlags::new(total_num_tiles, HashSet::new()); assert_abs_diff_eq!( @@ -392,19 +382,14 @@ fn test_1090008640_auto_vis() { #[test] fn test_1090008640_auto_vis_with_flags() { - let args = get_reduced_1090008640_uvfits(); - let uvfits_reader = if let [metafits, uvfits] = &args.data.unwrap()[..] { - match UvfitsReader::new(PathBuf::from(uvfits), Some(&PathBuf::from(metafits)), None) { - Ok(u) => u, - Err(e) => panic!("{}", e), - } - } else { - panic!("There weren't 2 elements in args.data"); - }; + let DataAsPathBufs { + metafits, mut vis, .. 
+ } = get_reduced_1090008640_uvfits_pbs(); + let uvfits_reader = UvfitsReader::new(vis.swap_remove(0), Some(&metafits), None).unwrap(); - let obs_context = &uvfits_reader.obs_context; + let obs_context = uvfits_reader.get_obs_context(); let total_num_tiles = obs_context.get_total_num_tiles(); - let num_chans = obs_context.num_fine_chans_per_coarse_chan.unwrap().get(); + let num_chans = usize::from(obs_context.num_fine_chans_per_coarse_chan.unwrap().get()); let flagged_tiles = HashSet::from([1, 9]); let num_unflagged_tiles = total_num_tiles - flagged_tiles.len(); let chan_flags = HashSet::from([1]); @@ -489,20 +474,15 @@ fn test_1090008640_auto_vis_with_flags() { #[test] fn read_1090008640_cross_and_auto_vis() { - let args = get_reduced_1090008640_uvfits(); - let uvfits_reader = if let [metafits, uvfits] = &args.data.unwrap()[..] { - match UvfitsReader::new(PathBuf::from(uvfits), Some(&PathBuf::from(metafits)), None) { - Ok(u) => u, - Err(e) => panic!("{}", e), - } - } else { - panic!("There weren't 2 elements in args.data"); - }; + let DataAsPathBufs { + metafits, mut vis, .. + } = get_reduced_1090008640_uvfits_pbs(); + let uvfits_reader = UvfitsReader::new(vis.swap_remove(0), Some(&metafits), None).unwrap(); - let obs_context = &uvfits_reader.obs_context; + let obs_context = uvfits_reader.get_obs_context(); let total_num_tiles = obs_context.get_total_num_tiles(); let num_baselines = (total_num_tiles * (total_num_tiles - 1)) / 2; - let num_chans = obs_context.num_fine_chans_per_coarse_chan.unwrap().get(); + let num_chans = usize::from(obs_context.num_fine_chans_per_coarse_chan.unwrap().get()); let tile_baseline_flags = TileBaselineFlags::new(total_num_tiles, HashSet::new()); assert_abs_diff_eq!( @@ -610,24 +590,6 @@ fn read_1090008640_cross_and_auto_vis() { ); } -#[test] -fn test_1090008640_calibration_quality() { - let mut args = get_reduced_1090008640_uvfits(); - let temp_dir = tempdir().expect("Couldn't make temp dir"); - args.outputs = Some(vec![temp_dir.path().join("hyp_sols.fits")]); - // To be consistent with other data quality tests, add these flags. 
- args.fine_chan_flags = Some(vec![0, 1, 2, 16, 30, 31]); - - let result = args.into_params(); - let params = match result { - Ok(r) => r, - Err(e) => panic!("{}", e), - }; - - let cal_vis = get_cal_vis(¶ms, false).expect("Couldn't read data and generate a model"); - test_1090008640_quality(params, cal_vis); -} - #[test] fn test_timestep_reading() { let temp_dir = tempdir().expect("Couldn't make temp dir"); @@ -693,7 +655,7 @@ fn test_timestep_reading() { writer.finalise().unwrap(); - let uvfits_reader = UvfitsReader::new(vis_path, None, Some(array_pos)).unwrap(); + let uvfits_reader = UvfitsReader::new(vis_path, None, None).unwrap(); let uvfits_ctx = uvfits_reader.get_obs_context(); let expected_timestamps = (0..num_timesteps) diff --git a/src/io/write/mod.rs b/src/io/write/mod.rs index 743b9ceb..2f1a8887 100644 --- a/src/io/write/mod.rs +++ b/src/io/write/mod.rs @@ -11,6 +11,7 @@ pub(crate) use error::{FileWriteError, VisWriteError}; use std::{ collections::HashSet, + num::NonZeroUsize, path::{Path, PathBuf}, }; @@ -19,7 +20,7 @@ use crossbeam_utils::atomic::AtomicCell; use hifitime::{Duration, Epoch}; use indicatif::ProgressBar; use itertools::Itertools; -use log::{debug, trace, warn}; +use log::{debug, trace}; use marlu::{ math::num_tiles_from_num_baselines, History, Jones, LatLngHeight, MeasurementSetWriter, MwaObsContext as MarluMwaObsContext, ObsContext as MarluObsContext, RADec, UvfitsWriter, @@ -28,9 +29,12 @@ use marlu::{ use ndarray::{prelude::*, ArcArray2}; use strum::IntoEnumIterator; use strum_macros::{Display, EnumIter, EnumString}; -use vec1::Vec1; +use vec1::{vec1, Vec1}; -use crate::averaging::Timeblock; +use crate::{ + averaging::{Spw, Timeblock}, + cli::Warn, +}; #[derive(Debug, Display, EnumIter, EnumString, Clone, Copy)] /// All write-supported visibility formats. @@ -68,10 +72,7 @@ pub(crate) struct VisTimestep { /// /// * `outputs` - each of the output files to be written, paired with the output /// type. -/// * `unflagged_baseline_tile_pairs` - the tile indices corresponding to -/// unflagged baselines. This includes auto-correlation "baselines" if they -/// are unflagged. -/// * `array_pos` - the position of the array that produced these visibilities. +/// * `array_pos` - the position of the array for the incoming visibilities. /// * `phase_centre` - the phase centre used for the incoming visibilities. /// * `pointing_centre` - the pointing centre used for the incoming /// visibilities. @@ -83,16 +84,14 @@ pub(crate) struct VisTimestep { /// start time of the observation and as an identifier. If not provided, the /// first timestep is used as the scheduled start time and a placeholder will /// be used for the identifier. -/// * `timestamps` - all possible timestamps that could be written out. These -/// represent the centre of the integration bin, i.e. "centroid" and not -/// "leading edge". Must be ascendingly sorted and be regularly spaced in -/// terms of `time_res`, but gaps are allowed. -/// * `timesteps` - the timesteps to be written out. These are indices into -/// `timestamps`. +/// * `timeblocks` - the details on all incoming timestamps and how to combine +/// them. /// * `time_res` - the time resolution of the incoming visibilities. -/// * `fine_chan_freqs` - all of the fine channel frequencies \[Hz\] (flagged -/// and unflagged). -/// * `freq_res` - the frequency resolution of the incoming visibilities \[Hz\]. +/// * `dut1` - the DUT1 to use in the UVWs of the outgoing visibilities. 
+/// * `spw` - the spectral window information of the outgoing visibilities. +/// * `unflagged_baseline_tile_pairs` - the tile indices corresponding to +/// unflagged baselines. This includes auto-correlation "baselines" if they +/// are unflagged. /// * `time_average_factor` - the time average factor (i.e. average this many /// visibilities in time before writing out). /// * `freq_average_factor` - the frequency average factor (i.e. average this @@ -111,67 +110,99 @@ pub(crate) struct VisTimestep { /// /// * A neatly-formatted string reporting all of the files that got written out. #[allow(clippy::too_many_arguments)] -pub(crate) fn write_vis<'a>( - outputs: &'a Vec1<(PathBuf, VisOutputType)>, +pub(crate) fn write_vis( + outputs: &Vec1<(PathBuf, VisOutputType)>, array_pos: LatLngHeight, phase_centre: RADec, pointing_centre: Option, - tile_positions: &'a [XyzGeodetic], - tile_names: &'a [String], + tile_positions: &[XyzGeodetic], + tile_names: &[String], obsid: Option, - timestamps: &'a Vec1, - timesteps: &'a Vec1, - timeblocks: &'a Vec1, + timeblocks: &Vec1, time_res: Duration, dut1: Duration, - freq_res: f64, - fine_chan_freqs: &'a Vec1, - unflagged_baseline_tile_pairs: &'a [(usize, usize)], - flagged_fine_chans: &HashSet, - time_average_factor: usize, - freq_average_factor: usize, + spw: &Spw, + unflagged_baseline_tile_pairs: &[(usize, usize)], + time_average_factor: NonZeroUsize, + freq_average_factor: NonZeroUsize, marlu_mwa_obs_context: Option<&MarluMwaObsContext>, + write_smallest_contiguous_band: bool, rx: Receiver, - error: &'a AtomicCell, + error: &AtomicCell, progress_bar: Option, ) -> Result { - // Ensure our timestamps are sensible. - for &t in timestamps { - let diff = (t - *timestamps.first()).total_nanoseconds(); + // Ensure our timestamps are regularly spaced in terms of `time_res`. + for t in timeblocks { + let diff = (t.median - timeblocks.first().median).total_nanoseconds(); if diff % time_res.total_nanoseconds() > 0 { return Err(VisWriteError::IrregularTimestamps { - first: timestamps.first().to_gpst_seconds(), - bad: t.to_gpst_seconds(), + first: timeblocks.first().median.to_gpst_seconds(), + bad: t.median.to_gpst_seconds(), time_res: time_res.to_seconds(), }); } } - let start_timestamp = timestamps[*timesteps.first()]; + // When writing out visibility data, the frequency axis *must* be + // contiguous. But, the incoming visibility data might not be contiguous due + // to flags. Set up the outgoing frequencies and set a flag so we know if + // the incoming data needs to be padded. + let chanblock_freqs = if write_smallest_contiguous_band { + match spw.chanblocks.as_slice() { + [] => panic!("There weren't any unflagged chanblocks in the SPW"), + [c] => vec1![c.freq], + [c1, .., cn] => { + let first_freq = c1.freq; + let last_freq = cn.freq; + let mut v = Array1::range(first_freq, last_freq, spw.freq_res).into_raw_vec(); + v.push(last_freq); // `Array1::range` is an exclusive range. 
+ Vec1::try_from_vec(v).expect("v is never empty") + } + } + } else { + spw.get_all_freqs() + }; + let missing_chanblocks = { + let mut missing = HashSet::new(); + let incoming_chanblock_freqs = spw + .chanblocks + .iter() + .map(|c| c.freq as u64) + .collect::>(); + for (i_chanblock, chanblock_freq) in (0..).zip(chanblock_freqs.iter()) { + let chanblock_freq = *chanblock_freq as u64; + if !incoming_chanblock_freqs.contains(&chanblock_freq) { + missing.insert(i_chanblock); + } + } + missing + }; + + let start_timestamp = timeblocks.first().median; + let num_baselines = unflagged_baseline_tile_pairs.len(); let vis_ctx = VisContext { - num_sel_timesteps: timeblocks.len() * time_average_factor, + num_sel_timesteps: timeblocks.len() * time_average_factor.get(), start_timestamp, int_time: time_res, - num_sel_chans: fine_chan_freqs.len(), - start_freq_hz: *fine_chan_freqs.first(), - freq_resolution_hz: freq_res, + num_sel_chans: chanblock_freqs.len(), + start_freq_hz: *chanblock_freqs.first(), + freq_resolution_hz: spw.freq_res, sel_baselines: unflagged_baseline_tile_pairs.to_vec(), - avg_time: time_average_factor, - avg_freq: freq_average_factor, + avg_time: time_average_factor.get(), + avg_freq: freq_average_factor.get(), num_vis_pols: 4, }; - let obs_name = obsid.map(|o| format!("{o}")); let sched_start_timestamp = match obsid { Some(gpst) => Epoch::from_gpst_seconds(f64::from(gpst)), None => start_timestamp, }; - let sched_duration = timestamps[*timesteps.last()] + time_res - sched_start_timestamp; + let sched_duration = timeblocks.last().median + time_res - sched_start_timestamp; let (s_lat, c_lat) = array_pos.latitude_rad.sin_cos(); let marlu_obs_ctx = MarluObsContext { sched_start_timestamp, sched_duration, - name: obs_name, + name: obsid.map(|o| format!("{o}")), phase_centre, pointing_centre, array_pos, @@ -249,14 +280,18 @@ pub(crate) fn write_vis<'a>( // These arrays will contain the post-averaged values and are written out by // the writer when all relevant timesteps have been added. // [time][freq][baseline] - let out_shape = vis_ctx.sel_dims(); - let mut out_data = Array3::zeros((time_average_factor, out_shape.1, out_shape.2)); - let mut out_weights = Array3::from_elem((time_average_factor, out_shape.1, out_shape.2), -0.0); + let out_shape = ( + timeblocks.len() * time_average_factor.get(), + chanblock_freqs.len(), + num_baselines, + ); + let mut out_data_tfb = Array3::zeros((time_average_factor.get(), out_shape.1, out_shape.2)); + let mut out_weights_tfb = + Array3::from_elem((time_average_factor.get(), out_shape.1, out_shape.2), -0.0); // Track a reference to the timeblock we're writing. let mut this_timeblock = timeblocks.first(); // Also track the first timestamp of the tracked timeblock. - // let mut this_start_timestamp = None; let mut this_average_timestamp = None; let mut i_timeblock = 0; // And the timestep into the timeblock. @@ -266,8 +301,8 @@ pub(crate) fn write_vis<'a>( for ( i_timestep, VisTimestep { - cross_data_fb: cross_data, - cross_weights_fb: cross_weights, + cross_data_fb, + cross_weights_fb, autos, timestamp, }, @@ -281,7 +316,7 @@ pub(crate) fn write_vis<'a>( this_average_timestamp = Some( timeblocks .iter() - .find(|tb| tb.timestamps.contains(×tamp)) + .find(|tb| tb.timestamps.iter().any(|e| *e == timestamp)) .unwrap() .median, ); @@ -290,8 +325,9 @@ pub(crate) fn write_vis<'a>( if let Some(autos) = autos.as_ref() { // Get the number of tiles from the lengths of the cross and auto // arrays. 
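// --- Illustrative sketch (Rust, std only) of the "smallest contiguous band" logic above.
// When `write_smallest_contiguous_band` is set, the outgoing frequency axis runs from the
// first to the last unflagged chanblock with no gaps, and any output channel that has no
// incoming chanblock is recorded as "missing" so it can be written as flagged padding.
// The function and variable names below are illustrative, not the crate's API.

use std::collections::HashSet;

fn contiguous_grid(unflagged_freqs_hz: &[f64], freq_res_hz: f64) -> (Vec<f64>, HashSet<usize>) {
    let first = *unflagged_freqs_hz.first().expect("at least one unflagged channel");
    let last = *unflagged_freqs_hz.last().expect("at least one unflagged channel");
    // Inclusive, gap-free grid between the first and last unflagged frequency.
    let num_chans = ((last - first) / freq_res_hz).round() as usize + 1;
    let grid: Vec<f64> = (0..num_chans)
        .map(|i| first + i as f64 * freq_res_hz)
        .collect();

    // Compare on integer hertz (as the code above does) to avoid float-equality problems.
    let incoming: HashSet<u64> = unflagged_freqs_hz.iter().map(|&f| f as u64).collect();
    let mut missing = HashSet::new();
    for (i_chan, freq) in grid.iter().enumerate() {
        if !incoming.contains(&(*freq as u64)) {
            missing.insert(i_chan);
        }
    }
    (grid, missing)
}

fn main() {
    // 10 kHz channels; the 40.02 MHz channel is flagged, so it is absent from the
    // incoming data but still needs a (flagged) slot in the contiguous output band.
    let (grid, missing) = contiguous_grid(&[40.00e6, 40.01e6, 40.03e6], 10e3);
    assert_eq!(grid.len(), 4);
    assert_eq!(missing, HashSet::from([2]));
}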
- let num_cross_baselines = cross_data.len_of(Axis(1)); + let num_cross_baselines = cross_data_fb.len_of(Axis(1)); let num_auto_baselines = autos.0.len_of(Axis(1)); + assert_eq!(num_cross_baselines + num_auto_baselines, num_baselines); let num_tiles = num_tiles_from_num_baselines(num_cross_baselines + num_auto_baselines); assert_eq!( (num_tiles * (num_tiles + 1)) / 2, @@ -301,79 +337,78 @@ pub(crate) fn write_vis<'a>( // baseline assert_eq!(num_cross_baselines + num_auto_baselines, out_shape.2); // freq - assert_eq!( - cross_data.len_of(Axis(0)) + flagged_fine_chans.len(), - out_shape.1 - ); - assert_eq!(cross_data.len_of(Axis(0)), autos.0.len_of(Axis(0))); + assert_eq!(cross_data_fb.len_of(Axis(0)), autos.0.len_of(Axis(0))); } else { // baseline - assert_eq!(cross_data.len_of(Axis(1)), out_shape.2); - // freq - assert_eq!( - cross_data.len_of(Axis(0)) + flagged_fine_chans.len(), - out_shape.1 - ); + assert_eq!(cross_data_fb.len_of(Axis(1)), out_shape.2); } + // freq + assert_eq!(cross_data_fb.len_of(Axis(0)), spw.chanblocks.len()); // Pack `out_data` and `out_weights`. Start with cross-correlation data, // skipping any auto-correlation indices; we'll fill them soon. - out_data + out_data_tfb .slice_mut(s![this_timestep, .., ..]) .outer_iter_mut() .zip_eq( - out_weights + out_weights_tfb .slice_mut(s![this_timestep, .., ..]) .outer_iter_mut(), ) .enumerate() - .filter(|(i_chan, _)| !flagged_fine_chans.contains(i_chan)) + .filter(|(i_chan, _)| !missing_chanblocks.contains(&(*i_chan as u16))) // Discard the channel index - .map(|(_, t)| t) - .zip_eq(cross_data.outer_iter()) - .zip_eq(cross_weights.outer_iter()) - .for_each(|(((mut out_data, mut out_weights), in_data), in_weights)| { - out_data - .iter_mut() - .zip_eq(out_weights.iter_mut()) - .zip_eq(unflagged_baseline_tile_pairs.iter()) - .filter(|(_, baseline)| baseline.0 != baseline.1) - .zip_eq(in_data.iter()) - .zip_eq(in_weights.iter()) - .for_each(|((((out_jones, out_weight), _), in_jones), in_weight)| { - *out_jones = *in_jones; - *out_weight = *in_weight; - }); - }); - // Autos. - if let Some((auto_data, auto_weights)) = autos { - out_data - .slice_mut(s![this_timestep, .., ..]) - .outer_iter_mut() - .zip_eq( - out_weights - .slice_mut(s![this_timestep, .., ..]) - .outer_iter_mut(), - ) - .enumerate() - .filter(|(i_chan, _)| !flagged_fine_chans.contains(i_chan)) - // Discard the channel index - .map(|(_, t)| t) - .zip_eq(auto_data.outer_iter()) - .zip_eq(auto_weights.outer_iter()) - .for_each(|(((mut out_data, mut out_weights), in_data), in_weights)| { - out_data + .map(|(_, d)| d) + .zip_eq(cross_data_fb.outer_iter()) + .zip_eq(cross_weights_fb.outer_iter()) + .for_each( + |(((mut out_data_b, mut out_weights_b), in_data_b), in_weights_b)| { + out_data_b .iter_mut() - .zip_eq(out_weights.iter_mut()) + .zip_eq(out_weights_b.iter_mut()) .zip_eq(unflagged_baseline_tile_pairs.iter()) - .filter(|(_, baseline)| baseline.0 == baseline.1) - .zip_eq(in_data.iter()) - .zip_eq(in_weights.iter()) + .filter(|(_, baseline)| baseline.0 != baseline.1) + .zip_eq(in_data_b.iter()) + .zip_eq(in_weights_b.iter()) .for_each(|((((out_jones, out_weight), _), in_jones), in_weight)| { *out_jones = *in_jones; *out_weight = *in_weight; }); - }); + }, + ); + // Autos. + if let Some((auto_data_fb, auto_weights_fb)) = autos { + (0..) 
+ .zip( + out_data_tfb + .slice_mut(s![this_timestep, .., ..]) + .outer_iter_mut(), + ) + .zip( + out_weights_tfb + .slice_mut(s![this_timestep, .., ..]) + .outer_iter_mut(), + ) + .filter(|((i_chan, _), _)| !missing_chanblocks.contains(i_chan)) + // Discard the channel index + .map(|((_, d), w)| (d, w)) + .zip_eq(auto_data_fb.outer_iter()) + .zip_eq(auto_weights_fb.outer_iter()) + .for_each( + |(((mut out_data_b, mut out_weights_b), in_data_b), in_weights_b)| { + out_data_b + .iter_mut() + .zip_eq(out_weights_b.iter_mut()) + .zip_eq(unflagged_baseline_tile_pairs.iter()) + .filter(|(_, baseline)| baseline.0 == baseline.1) + .zip_eq(in_data_b.iter()) + .zip_eq(in_weights_b.iter()) + .for_each(|((((out_jones, out_weight), _), in_jones), in_weight)| { + *out_jones = *in_jones; + *out_weight = *in_weight; + }); + }, + ); } // Should we continue? @@ -384,22 +419,22 @@ pub(crate) fn write_vis<'a>( // If the next timestep doesn't belong to our tracked timeblock, write // out this timeblock and track the next one. if !this_timeblock.range.contains(&(i_timestep + 1)) - || this_timestep + 1 >= time_average_factor + || this_timestep + 1 >= time_average_factor.get() { debug!("Writing timeblock {i_timeblock}"); let chunk_vis_ctx = VisContext { // TODO: Marlu expects "leading edge" timestamps, not centroids. // Fix this in Marlu. start_timestamp: this_average_timestamp.unwrap() - - time_res / 2 * time_average_factor as f64, + - time_res / 2 * time_average_factor.get() as f64, num_sel_timesteps: this_timeblock.range.len(), ..vis_ctx.clone() }; for vis_writer in writers.iter_mut() { vis_writer.write_vis( - out_data.slice(s![0..this_timeblock.range.len(), .., ..]), - out_weights.slice(s![0..this_timeblock.range.len(), .., ..]), + out_data_tfb.slice(s![0..this_timeblock.range.len(), .., ..]), + out_weights_tfb.slice(s![0..this_timeblock.range.len(), .., ..]), &chunk_vis_ctx, )?; // Should we continue? @@ -413,8 +448,8 @@ pub(crate) fn write_vis<'a>( } // Clear the output buffers. 
- out_data.fill(Jones::default()); - out_weights.fill(-0.0); + out_data_tfb.fill(Jones::default()); + out_weights_tfb.fill(-0.0); i_timeblock += 1; this_timeblock = match timeblocks.get(i_timeblock) { @@ -464,12 +499,12 @@ pub(crate) fn can_write_to_file(file: &Path) -> Result<(), FileWriteError> { if file.is_dir() { let exists = can_write_to_dir(file)?; if exists { - warn!("Will overwrite the existing directory '{}'", file.display()); + format!("Will overwrite the existing directory '{}'", file.display()).warn(); } } else { let exists = can_write_to_file_inner(file)?; if exists { - warn!("Will overwrite the existing file '{}'", file.display()); + format!("Will overwrite the existing file '{}'", file.display()).warn(); } } diff --git a/src/io/write/tests.rs b/src/io/write/tests.rs index 4070794e..fba2d371 100644 --- a/src/io/write/tests.rs +++ b/src/io/write/tests.rs @@ -14,7 +14,7 @@ use vec1::{vec1, Vec1}; use super::*; use crate::{ - averaging::timesteps_to_timeblocks, + averaging::{channels_to_chanblocks, timesteps_to_timeblocks}, io::read::{MsReader, UvfitsReader, VisRead}, math::TileBaselineFlags, }; @@ -46,8 +46,8 @@ fn synthesize_test_data( #[test] #[serial] fn test_vis_output_no_time_averaging_no_gaps() { - let vis_time_average_factor = 1; - let vis_freq_average_factor = 1; + let vis_time_average_factor = NonZeroUsize::new(1).unwrap(); + let vis_freq_average_factor = NonZeroUsize::new(1).unwrap(); let num_timesteps = 5; let num_channels = 10; @@ -64,23 +64,35 @@ fn test_vis_output_no_time_averaging_no_gaps() { .collect(), ) .unwrap(); - let timeblocks = timesteps_to_timeblocks(×tamps, vis_time_average_factor, ×teps); - - let freq_res = 10e3; + let timeblocks = timesteps_to_timeblocks( + ×tamps, + time_res, + vis_time_average_factor, + Some(×teps), + ); + + let freq_res = 10e3 as u64; let fine_chan_freqs = Vec1::try_from_vec( - (0..num_channels) - .map(|i| 150e6 + freq_res * i as f64) + (0..num_channels as u64) + .map(|i| 150_000_000 + freq_res * i) .collect(), ) .unwrap(); + let spw = &channels_to_chanblocks( + &fine_chan_freqs, + freq_res, + NonZeroUsize::new(1).unwrap(), + &HashSet::new(), + )[0]; + let vis_ctx = VisContext { num_sel_timesteps: timesteps.len(), start_timestamp, int_time: time_res, num_sel_chans: num_channels, start_freq_hz: 128_000_000., - freq_resolution_hz: freq_res, + freq_resolution_hz: freq_res as f64, sel_baselines: ant_pairs.clone(), avg_time: 1, avg_freq: 1, @@ -144,18 +156,15 @@ fn test_vis_output_no_time_averaging_no_gaps() { &tile_xyzs, &tile_names, Some(obsid), - ×tamps, - ×teps, &timeblocks, time_res, - Duration::from_seconds(0.0), - freq_res, - &fine_chan_freqs, + Duration::default(), + spw, &ant_pairs, - &HashSet::new(), vis_time_average_factor, vis_freq_average_factor, marlu_mwa_obs_context, + false, rx, &error, None, @@ -182,11 +191,9 @@ fn test_vis_output_no_time_averaging_no_gaps() { // Read the visibilities in and check everything is fine. 
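// --- Illustrative sketch (Rust, std only): the averaging factors used by the tests below
// (and by `write_vis`) are now `NonZeroUsize` rather than plain `usize`. The presumed
// benefit is that a zero factor is rejected when the value is constructed, so code that
// divides or chunks by the factor never has to re-validate it. `num_output_timesteps` is
// an illustrative helper, not part of the crate.

use std::num::NonZeroUsize;

fn num_output_timesteps(num_input: usize, time_average_factor: NonZeroUsize) -> usize {
    // Ceiling division; safe because the factor is guaranteed to be at least 1.
    (num_input + time_average_factor.get() - 1) / time_average_factor.get()
}

fn main() {
    // Zero is rejected at construction time, not deep inside the writer.
    assert!(NonZeroUsize::new(0).is_none());

    let factor = NonZeroUsize::new(3).expect("3 is non-zero");
    assert_eq!(num_output_timesteps(10, factor), 4);
    assert_eq!(num_output_timesteps(9, factor), 3);
}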
for (path, vis_type) in out_vis_paths { let reader: Box = match vis_type { - VisOutputType::Uvfits => { - Box::new(UvfitsReader::new(path.to_path_buf(), None, None).unwrap()) - } + VisOutputType::Uvfits => Box::new(UvfitsReader::new(path, None, None).unwrap()), VisOutputType::MeasurementSet => { - Box::new(MsReader::new(path.to_path_buf(), None, None, None).unwrap()) + Box::new(MsReader::new(path, None, None, None).unwrap()) } }; let obs_context = reader.get_obs_context(); @@ -208,7 +215,7 @@ fn test_vis_output_no_time_averaging_no_gaps() { ); assert_eq!(obs_context.time_res, Some(time_res)); - assert_eq!(obs_context.freq_res, Some(freq_res)); + assert_eq!(obs_context.freq_res, Some(freq_res as f64)); let avg_shape = ( obs_context.fine_chan_freqs.len(), @@ -216,8 +223,8 @@ fn test_vis_output_no_time_averaging_no_gaps() { ); let mut avg_data = Array2::zeros(avg_shape); let mut avg_weights = Array2::zeros(avg_shape); - let flagged_fine_chans: HashSet = - obs_context.flagged_fine_chans.iter().cloned().collect(); + let flagged_fine_chans: HashSet = + obs_context.flagged_fine_chans.iter().copied().collect(); for i_timestep in 0..timesteps.len() { reader @@ -242,8 +249,8 @@ fn test_vis_output_no_time_averaging_no_gaps() { #[test] #[serial] fn test_vis_output_no_time_averaging_with_gaps() { - let vis_time_average_factor = 1; - let vis_freq_average_factor = 1; + let vis_time_average_factor = NonZeroUsize::new(1).unwrap(); + let vis_freq_average_factor = NonZeroUsize::new(1).unwrap(); let num_timestamps = 10; let num_channels = 10; @@ -260,23 +267,35 @@ fn test_vis_output_no_time_averaging_with_gaps() { .collect(), ) .unwrap(); - let timeblocks = timesteps_to_timeblocks(×tamps, vis_time_average_factor, ×teps); - - let freq_res = 10e3; + let timeblocks = timesteps_to_timeblocks( + ×tamps, + time_res, + vis_time_average_factor, + Some(×teps), + ); + + let freq_res = 10e3 as u64; let fine_chan_freqs = Vec1::try_from_vec( - (0..num_channels) - .map(|i| 150e6 + freq_res * i as f64) + (0..num_channels as u64) + .map(|i| 150_000_000 + freq_res * i) .collect(), ) .unwrap(); + let spw = &channels_to_chanblocks( + &fine_chan_freqs, + freq_res, + NonZeroUsize::new(1).unwrap(), + &HashSet::new(), + )[0]; + let vis_ctx = VisContext { num_sel_timesteps: timesteps.len(), start_timestamp, int_time: time_res, num_sel_chans: num_channels, start_freq_hz: 128_000_000., - freq_resolution_hz: freq_res, + freq_resolution_hz: freq_res as f64, sel_baselines: ant_pairs.clone(), avg_time: 1, avg_freq: 1, @@ -340,18 +359,15 @@ fn test_vis_output_no_time_averaging_with_gaps() { &tile_xyzs, &tile_names, Some(obsid), - ×tamps, - ×teps, &timeblocks, time_res, - Duration::from_seconds(0.0), - freq_res, - &fine_chan_freqs, + Duration::default(), + spw, &ant_pairs, - &HashSet::new(), vis_time_average_factor, vis_freq_average_factor, marlu_mwa_obs_context, + false, rx, &error, None, @@ -380,11 +396,9 @@ fn test_vis_output_no_time_averaging_with_gaps() { let timesteps = [0, 1, 2]; for (path, vis_type) in out_vis_paths { let reader: Box = match vis_type { - VisOutputType::Uvfits => { - Box::new(UvfitsReader::new(path.to_path_buf(), None, None).unwrap()) - } + VisOutputType::Uvfits => Box::new(UvfitsReader::new(path, None, None).unwrap()), VisOutputType::MeasurementSet => { - Box::new(MsReader::new(path.to_path_buf(), None, None, None).unwrap()) + Box::new(MsReader::new(path, None, None, None).unwrap()) } }; let obs_context = reader.get_obs_context(); @@ -403,7 +417,7 @@ fn test_vis_output_no_time_averaging_with_gaps() { 
expected.mapped_ref(|t| t.to_gpst_seconds()) ); assert_eq!(obs_context.time_res, Some(time_res)); - assert_eq!(obs_context.freq_res, Some(freq_res)); + assert_eq!(obs_context.freq_res, Some(freq_res as f64)); let avg_shape = ( obs_context.fine_chan_freqs.len(), @@ -411,8 +425,8 @@ fn test_vis_output_no_time_averaging_with_gaps() { ); let mut avg_data = Array2::zeros(avg_shape); let mut avg_weights = Array2::zeros(avg_shape); - let flagged_fine_chans: HashSet = - obs_context.flagged_fine_chans.iter().cloned().collect(); + let flagged_fine_chans: HashSet = + obs_context.flagged_fine_chans.iter().copied().collect(); for i_timestep in 0..timesteps.len() { reader @@ -437,8 +451,8 @@ fn test_vis_output_no_time_averaging_with_gaps() { #[test] #[serial] fn test_vis_output_time_averaging() { - let vis_time_average_factor = 3; - let vis_freq_average_factor = 1; + let vis_time_average_factor = NonZeroUsize::new(3).unwrap(); + let vis_freq_average_factor = NonZeroUsize::new(1).unwrap(); let num_timestamps = 10; let num_channels = 10; @@ -457,23 +471,35 @@ fn test_vis_output_time_averaging() { .collect(), ) .unwrap(); - let timeblocks = timesteps_to_timeblocks(×tamps, vis_time_average_factor, ×teps); - - let freq_res = 10e3; + let timeblocks = timesteps_to_timeblocks( + ×tamps, + time_res, + vis_time_average_factor, + Some(×teps), + ); + + let freq_res = 10e3 as u64; let fine_chan_freqs = Vec1::try_from_vec( - (0..num_channels) - .map(|i| 150e6 + freq_res * i as f64) + (0..num_channels as u64) + .map(|i| 150_000_000 + freq_res * i) .collect(), ) .unwrap(); + let spw = &channels_to_chanblocks( + &fine_chan_freqs, + freq_res, + NonZeroUsize::new(1).unwrap(), + &HashSet::new(), + )[0]; + let vis_ctx = VisContext { num_sel_timesteps: timesteps.len(), start_timestamp, int_time: time_res, num_sel_chans: num_channels, start_freq_hz: 128_000_000., - freq_resolution_hz: freq_res, + freq_resolution_hz: freq_res as f64, sel_baselines: ant_pairs.clone(), avg_time: 1, avg_freq: 1, @@ -540,18 +566,15 @@ fn test_vis_output_time_averaging() { &tile_xyzs, &tile_names, Some(obsid), - ×tamps, - ×teps, &timeblocks, time_res, - Duration::from_seconds(0.0), - freq_res, - &fine_chan_freqs, + Duration::default(), + spw, &ant_pairs, - &HashSet::new(), vis_time_average_factor, vis_freq_average_factor, marlu_mwa_obs_context, + false, rx, &error, None, @@ -580,11 +603,9 @@ fn test_vis_output_time_averaging() { let timesteps = [0, 1]; for (path, vis_type) in out_vis_paths { let reader: Box = match vis_type { - VisOutputType::Uvfits => { - Box::new(UvfitsReader::new(path.to_path_buf(), None, None).unwrap()) - } + VisOutputType::Uvfits => Box::new(UvfitsReader::new(path, None, None).unwrap()), VisOutputType::MeasurementSet => { - Box::new(MsReader::new(path.to_path_buf(), None, None, None).unwrap()) + Box::new(MsReader::new(path, None, None, None).unwrap()) } }; let obs_context = reader.get_obs_context(); @@ -602,7 +623,7 @@ fn test_vis_output_time_averaging() { expected.mapped_ref(|t| t.to_gpst_seconds()) ); assert_eq!(obs_context.time_res, Some(Duration::from_seconds(3.0))); - assert_eq!(obs_context.freq_res, Some(freq_res)); + assert_eq!(obs_context.freq_res, Some(freq_res as f64)); let avg_shape = ( obs_context.fine_chan_freqs.len(), @@ -610,8 +631,8 @@ fn test_vis_output_time_averaging() { ); let mut avg_data = Array2::zeros(avg_shape); let mut avg_weights = Array2::zeros(avg_shape); - let flagged_fine_chans: HashSet = - obs_context.flagged_fine_chans.iter().cloned().collect(); + let flagged_fine_chans: HashSet = + 
obs_context.flagged_fine_chans.iter().copied().collect(); for i_timestep in 0..timesteps.len() { reader diff --git a/src/lib.rs b/src/lib.rs index 3688e728..ca110e61 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -7,49 +7,57 @@ //! //! -pub mod averaging; -pub mod beam; +mod averaging; +mod beam; mod cli; -pub(crate) mod constants; -pub(crate) mod context; -pub mod di_calibrate; -pub(crate) mod error; -pub(crate) mod filenames; -pub(crate) mod flagging; -mod help_texts; -pub(crate) mod io; -pub(crate) mod math; -pub(crate) mod messages; -pub(crate) mod metafits; -pub(crate) mod misc; +mod constants; +mod context; +mod di_calibrate; +mod flagging; +mod io; +mod math; +mod metafits; +mod misc; pub mod model; -pub(crate) mod solutions; +mod params; +mod solutions; pub mod srclist; -pub(crate) mod unit_parsing; +mod unit_parsing; #[cfg(feature = "cuda")] -pub(crate) mod cuda; +mod cuda; #[cfg(test)] mod tests; +use crossbeam_utils::atomic::AtomicCell; +lazy_static::lazy_static! { + /// Are progress bars being drawn? This should only ever be enabled by CLI + /// code. + static ref PROGRESS_BARS: AtomicCell = AtomicCell::new(false); + + /// What device (GPU or CPU) are we using for modelling and beam responses? + /// This should only ever be changed from its default by CLI code. + static ref MODEL_DEVICE: AtomicCell = { + cfg_if::cfg_if! { + if #[cfg(feature = "cuda")] { + AtomicCell::new(ModelDevice::Cuda) + } else { + AtomicCell::new(ModelDevice::Cpu) + } + } + }; +} + // Re-exports. -pub use cli::{ - di_calibrate::DiCalArgs, - dipole_gains::DipoleGainsArgs, - solutions::{ - apply::SolutionsApplyArgs, convert::SolutionsConvertArgs, plot::SolutionsPlotArgs, - }, - srclist::{ - by_beam::SrclistByBeamArgs, convert::SrclistConvertArgs, shift::SrclistShiftArgs, - verify::SrclistVerifyArgs, - }, - vis_utils::{simulate::VisSimulateArgs, subtract::VisSubtractArgs}, -}; +pub use averaging::{Chanblock, Timeblock}; +pub use beam::{create_fee_beam_object, Delays}; +#[doc(hidden)] +pub use cli::Hyperdrive; +pub use cli::HyperdriveError; pub use context::Polarisations; -pub use error::HyperdriveError; -pub use io::read::{ - AutoData, CrossData, MsReader, RawDataCorrections, RawDataReader, UvfitsReader, -}; +pub use di_calibrate::calibrate_timeblocks; +pub use io::read::{CrossData, MsReader, RawDataCorrections, RawDataReader, UvfitsReader}; pub use math::TileBaselineFlags; +pub use model::ModelDevice; pub use solutions::CalibrationSolutions; diff --git a/src/math/mod.rs b/src/math/mod.rs index ee4b8389..bc574245 100644 --- a/src/math/mod.rs +++ b/src/math/mod.rs @@ -30,13 +30,15 @@ pub(crate) fn is_prime(n: usize) -> bool { true } -/// Given a collection of [Epoch]s, return one that is the average of their +/// Given a collection of [`Epoch`]s, return one that is the average of their /// times. 
-pub(crate) fn average_epoch(es: &[Epoch]) -> Epoch { - let duration_sum = es.iter().fold(Epoch::from_gpst_seconds(0.0), |acc, t| { - acc + t.to_gpst_seconds() - }); - let average = duration_sum.to_gpst_seconds() / es.len() as f64; +pub(crate) fn average_epoch>(es: I) -> Epoch { + let (duration_sum, num_epochs) = es + .into_iter() + .fold((Epoch::from_gpst_seconds(0.0), 0), |acc, t| { + (acc.0 + t.to_gpst_seconds(), acc.1 + 1) + }); + let average = duration_sum.to_gpst_seconds() / num_epochs as f64; Epoch::from_gpst_seconds(average).round(10.milliseconds()) } diff --git a/src/math/tests.rs b/src/math/tests.rs index 88abdc00..4ae2dc3f 100644 --- a/src/math/tests.rs +++ b/src/math/tests.rs @@ -15,7 +15,7 @@ fn test_average_epoch() { Epoch::from_gpst_seconds(1065880132.0), ]; - let average = average_epoch(&epochs); + let average = average_epoch(epochs); assert_abs_diff_eq!(average.to_gpst_seconds(), 1065880130.0); } @@ -27,7 +27,7 @@ fn test_average_epoch2() { Epoch::from_gpst_seconds(1118529192.0), ]; - let average = average_epoch(&epochs); + let average = average_epoch(epochs); // This epsilon is huge, but the epochs span years. At least the first test // is accurate to precision. assert_abs_diff_eq!(average.to_gpst_seconds(), 1091472653.0, epsilon = 0.4); diff --git a/src/messages.rs b/src/messages.rs deleted file mode 100644 index a7e63a39..00000000 --- a/src/messages.rs +++ /dev/null @@ -1,724 +0,0 @@ -// This Source Code Form is subject to the terms of the Mozilla Public -// License, v. 2.0. If a copy of the MPL was not distributed with this -// file, You can obtain one at http://mozilla.org/MPL/2.0/. - -//! Messages to report to the user. -//! -//! When unpacking input data, beam parameters etc. some things are useful to -//! report to the user. However, the order of these messages can appear -//! unrelated or random because of the way the code is ordered. This module -//! attempts to tidy this issue by categorising message types. 
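// --- Illustrative sketch (Rust, std only) of the `average_epoch` rework above. The
// function now accepts any `IntoIterator<Item = Epoch>` and folds the items into a
// (sum, count) pair, so it no longer needs a slice with a known length. Plain GPS
// seconds stand in for `hifitime::Epoch` here to keep the sketch dependency-free.

fn average_gpst_seconds<I: IntoIterator<Item = f64>>(times: I) -> f64 {
    let (sum, count) = times
        .into_iter()
        .fold((0.0_f64, 0_usize), |(sum, count), t| (sum + t, count + 1));
    sum / count as f64
}

fn main() {
    // Arrays, Vecs and iterators all work; no borrow is needed.
    let epochs = [1090008640.0, 1090008642.0, 1090008644.0];
    assert_eq!(average_gpst_seconds(epochs), 1090008642.0);
}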
- -use std::path::{Path, PathBuf}; - -use hifitime::{Duration, Epoch}; -use itertools::Itertools; -use log::{info, trace, warn}; -use marlu::{LatLngHeight, RADec}; -use ndarray::prelude::*; -use vec1::Vec1; - -use crate::{ - flagging::MwafFlags, - io::{ - read::{pfb_gains::PfbFlavour, RawDataCorrections}, - write::VisOutputType, - }, - model::ModellerInfo, - solutions::CalibrationSolutions, - srclist::{ComponentCounts, SourceList}, - unit_parsing::WavelengthUnit, -}; - -#[must_use = "This struct must be consumed with its print() method"] -pub(super) enum InputFileDetails<'a> { - Raw { - obsid: u32, - gpubox_count: usize, - metafits_file_name: String, - mwaf: Option<&'a MwafFlags>, - raw_data_corrections: RawDataCorrections, - }, - MeasurementSet { - obsid: Option, - file_name: String, - metafits_file_name: Option, - }, - UvfitsFile { - obsid: Option, - file_name: String, - metafits_file_name: Option, - }, -} - -impl InputFileDetails<'_> { - pub(super) fn print(self, operation: &str) { - match self { - InputFileDetails::Raw { - obsid, - gpubox_count, - metafits_file_name, - mwaf, - raw_data_corrections, - } => { - info!("{operation} obsid {obsid}"); - info!(" from {gpubox_count} gpubox files"); - info!(" with metafits {metafits_file_name}"); - match mwaf { - Some(flags) => { - let software_string = match flags.software_version.as_ref() { - Some(v) => format!("{} {}", flags.software, v), - None => flags.software.to_string(), - }; - info!( - " with {} mwaf files ({})", - flags.gpubox_nums.len(), - software_string, - ); - if let Some(s) = flags.aoflagger_version.as_deref() { - info!(" AOFlagger version: {s}"); - } - if let Some(s) = flags.aoflagger_strategy.as_deref() { - info!(" AOFlagger strategy: {s}"); - } - } - None => warn!("No mwaf files supplied"), - } - - let s = "Correcting PFB gains"; - match raw_data_corrections.pfb_flavour { - PfbFlavour::None => info!("Not doing any PFB correction"), - PfbFlavour::Jake => info!("{s} with 'Jake Jones' gains"), - PfbFlavour::Cotter2014 => info!("{s} with 'Cotter 2014' gains"), - PfbFlavour::Empirical => info!("{s} with 'RTS empirical' gains"), - PfbFlavour::Levine => info!("{s} with 'Alan Levine' gains"), - } - if raw_data_corrections.digital_gains { - info!("Correcting digital gains"); - } else { - info!("Not correcting digital gains"); - } - if raw_data_corrections.cable_length { - info!("Correcting cable lengths"); - } else { - info!("Not correcting cable lengths"); - } - if raw_data_corrections.geometric { - info!("Correcting geometric delays (if necessary)"); - } else { - info!("Not correcting geometric delays"); - } - } - - InputFileDetails::MeasurementSet { - obsid, - file_name, - metafits_file_name, - } => { - if let Some(o) = obsid { - info!("{operation} obsid {o}"); - info!(" from measurement set {file_name}"); - } else { - info!("{operation} measurement set {file_name}"); - } - if let Some(f) = metafits_file_name { - info!(" with metafits {f}"); - } - } - - InputFileDetails::UvfitsFile { - obsid, - file_name, - metafits_file_name, - } => { - if let Some(o) = obsid { - info!("{operation} obsid {o}"); - info!(" from uvfits {file_name}"); - } else { - info!("{operation} uvfits {file_name}"); - } - if let Some(f) = metafits_file_name { - info!(" with metafits {f}"); - } - } - } - } -} - -#[must_use = "This struct must be consumed with its print() method"] -pub(super) struct ArrayDetails<'a> { - pub(super) array_position: Option, - /// \[radians\] - pub(super) array_latitude_j2000: Option, - pub(super) total_num_tiles: usize, - pub(super) 
num_unflagged_tiles: usize, - pub(super) flagged_tiles: &'a [(&'a str, usize)], -} - -impl ArrayDetails<'_> { - pub(super) fn print(self) { - if let Some(pos) = self.array_position { - info!( - "Array latitude: {:>8.4}°", - pos.latitude_rad.to_degrees() - ); - } - if let Some(rad) = self.array_latitude_j2000 { - info!("Array latitude (J2000): {:>8.4}°", rad.to_degrees()); - } - if let Some(pos) = self.array_position { - info!( - "Array longitude: {:>9.4}°", - pos.longitude_rad.to_degrees() - ); - info!("Array height: {:>9.4}m", pos.height_metres); - } - - info!("Total number of tiles: {:>3}", self.total_num_tiles); - info!("Number of unflagged tiles: {:>3}", self.num_unflagged_tiles); - info!("Flagged tiles: {:?}", self.flagged_tiles); - } -} - -#[must_use = "This struct must be consumed with its print() method"] -pub(super) struct ObservationDetails<'a> { - /// If this is `None`, no dipole delay or alive/dead status reporting is - /// done. - pub(super) dipole_delays: Option<[u32; 16]>, - pub(super) beam_file: Option<&'a Path>, - /// If this is `None`, then report that we're assuming all dipoles are - /// "alive". - pub(super) num_tiles_with_dead_dipoles: Option, - - pub(super) phase_centre: RADec, - pub(super) pointing_centre: Option, - /// Only printed if it's populated. - pub(super) dut1: Option, - /// The local mean sidereal time of the first timestep \[radians\] - pub(super) lmst: Option, - /// The local mean sidereal time of the first timestep, precessed to the - /// J2000 epoch \[radians\] - pub(super) lmst_j2000: Option, - - pub(super) available_timesteps: Option<&'a [usize]>, - pub(super) unflagged_timesteps: Option<&'a [usize]>, - pub(super) using_timesteps: Option<&'a [usize]>, - pub(super) first_timestamp: Option, - pub(super) last_timestamp: Option, - pub(super) time_res: Option, - - pub(super) total_num_channels: usize, - pub(super) num_unflagged_channels: Option, - pub(super) flagged_chans_per_coarse_chan: Option<&'a [usize]>, - pub(super) first_freq_hz: Option, - pub(super) last_freq_hz: Option, - pub(super) first_unflagged_freq_hz: Option, - pub(super) last_unflagged_freq_hz: Option, - pub(super) freq_res_hz: Option, -} - -impl ObservationDetails<'_> { - pub(super) fn print(self) { - if let Some(d) = self.dipole_delays { - info!( - "Ideal dipole delays: [{:>2} {:>2} {:>2} {:>2}", - d[0], d[1], d[2], d[3] - ); - info!( - " {:>2} {:>2} {:>2} {:>2}", - d[4], d[5], d[6], d[7] - ); - info!( - " {:>2} {:>2} {:>2} {:>2}", - d[8], d[9], d[10], d[11] - ); - info!( - " {:>2} {:>2} {:>2} {:>2}]", - d[12], d[13], d[14], d[15] - ); - // No need to report additional beam information if there are no - // dipole delays; this implies that beam code isn't being used. 
- if let Some(beam_file) = self.beam_file { - info!("Using beam file {}", beam_file.display()); - } - if let Some(num_tiles_with_dead_dipoles) = self.num_tiles_with_dead_dipoles { - info!( - "Using dead dipole information ({num_tiles_with_dead_dipoles} tiles affected)" - ); - } else { - info!("Assuming all dipoles are \"alive\""); - } - } - - info!( - "Phase centre (J2000): {:>9.4}°, {:>8.4}°", - self.phase_centre.ra.to_degrees(), - self.phase_centre.dec.to_degrees(), - ); - if let Some(pc) = self.pointing_centre { - info!( - "Pointing centre: {:>9.4}°, {:>8.4}°", - pc.ra.to_degrees(), - pc.dec.to_degrees() - ); - } - - if let Some(dut1) = self.dut1 { - info!("DUT1: {} seconds", dut1.to_seconds()); - } - match (self.lmst, self.lmst_j2000) { - (Some(l), Some(l2)) => { - info!("LMST of first timestep: {:>9.6}°", l.to_degrees()); - info!("LMST of first timestep (J2000): {:>9.6}°", l2.to_degrees()); - } - (Some(l), None) => info!("LMST of first timestep: {:>9.6}°", l.to_degrees()), - (None, Some(l2)) => info!("LMST of first timestep (J2000): {:>9.6}°", l2.to_degrees()), - (None, None) => (), - } - - if let Some(available_timesteps) = self.available_timesteps { - info!( - "{}", - range_or_comma_separated(available_timesteps, Some("Available timesteps:")) - ); - } - if let Some(unflagged_timesteps) = self.unflagged_timesteps { - info!( - "{}", - range_or_comma_separated(unflagged_timesteps, Some("Unflagged timesteps:")) - ); - } - // We don't require the timesteps to be used in calibration to be - // sequential. But if they are, it looks a bit neater to print them out - // as a range rather than individual indices. - if let Some(using_timesteps) = self.using_timesteps { - info!( - "{}", - range_or_comma_separated(using_timesteps, Some("Using timesteps: ")) - ); - } - match ( - self.first_timestamp, - self.last_timestamp, - self.first_timestamp.or(self.last_timestamp), - ) { - (Some(f), Some(l), _) => { - info!("First timestamp (GPS): {:.2}", f.to_gpst_seconds()); - info!("Last timestamp (GPS): {:.2}", l.to_gpst_seconds()); - } - (_, _, Some(f)) => info!("Only timestamp (GPS): {:.2}", f.to_gpst_seconds()), - _ => (), - } - match self.time_res { - Some(r) => info!("Input data time resolution: {:.2} seconds", r.to_seconds()), - None => info!("Input data time resolution unknown"), - } - - match self.num_unflagged_channels { - Some(num_unflagged_channels) => { - info!( - "Total number of fine channels: {}", - self.total_num_channels - ); - info!( - "Number of unflagged fine channels: {}", - num_unflagged_channels - ); - } - None => { - info!("Total number of fine channels: {}", self.total_num_channels); - } - } - if let Some(flagged_chans_per_coarse_chan) = self.flagged_chans_per_coarse_chan { - info!( - "Input data's fine-channel flags per coarse channel: {:?}", - flagged_chans_per_coarse_chan - ); - } - match ( - self.first_freq_hz, - self.last_freq_hz, - self.first_unflagged_freq_hz, - self.last_unflagged_freq_hz, - ) { - (Some(f), Some(l), Some(fu), Some(lu)) => { - info!("First fine-channel frequency: {:.3} MHz", f / 1e6); - info!( - "First unflagged fine-channel frequency: {:.3} MHz", - fu / 1e6 - ); - info!("Last fine-channel frequency: {:.3} MHz", l / 1e6); - info!( - "Last unflagged fine-channel frequency: {:.3} MHz", - lu / 1e6 - ); - } - (Some(f), Some(l), None, None) => { - info!("First fine-channel frequency: {:.3} MHz", f / 1e6); - info!("Last fine-channel frequency: {:.3} MHz", l / 1e6); - } - (None, None, Some(f), Some(l)) => { - info!("First unflagged fine-channel frequency: {:.3} 
MHz", f / 1e6); - info!("Last unflagged fine-channel frequency: {:.3} MHz", l / 1e6); - } - _ => (), - } - match self.freq_res_hz { - Some(r) => info!("Input data freq. resolution: {:.2} kHz", r / 1e3), - None => info!("Input data freq. resolution unknown"), - } - } -} - -#[must_use = "This struct must be consumed with its print() method"] -pub(super) struct CalibrationDetails { - pub(super) timesteps_per_timeblock: usize, - pub(super) channels_per_chanblock: usize, - pub(super) num_timeblocks: usize, - pub(super) num_chanblocks: usize, - pub(super) uvw_min: (f64, WavelengthUnit), - pub(super) uvw_max: (f64, WavelengthUnit), - /// The number of baselines to use in calibration. - pub(super) num_calibration_baselines: usize, - /// The number of total number of baselines. - pub(super) total_num_baselines: usize, - /// If the user specified UVW cutoffs in terms of wavelength, we need to - /// come up with our own lambda to convert the cutoffs to metres (we use the - /// centroid frequency of the observation). \[metres\] - pub(super) lambda: f64, - /// \[Hz\] - pub(super) freq_centroid: f64, - pub(super) min_threshold: f64, - pub(super) stop_threshold: f64, - pub(super) max_iterations: u32, -} - -impl CalibrationDetails { - pub(super) fn print(self) { - // I'm quite bored right now. - let timeblock_plural = if self.num_timeblocks > 1 { - "timeblocks" - } else { - "timeblock" - }; - let chanblock_plural = if self.num_chanblocks > 1 { - "chanblocks" - } else { - "chanblock" - }; - - info!( - "{} calibration {timeblock_plural}, {} calibration {chanblock_plural}", - self.num_timeblocks, self.num_chanblocks - ); - info!(" {} timesteps per timeblock", self.timesteps_per_timeblock); - info!(" {} channels per chanblock", self.channels_per_chanblock); - - // Report extra info if we need to use our own lambda (the user - // specified wavelengths). - if matches!(self.uvw_min.1, WavelengthUnit::L) - || matches!(self.uvw_max.1, WavelengthUnit::L) - { - info!( - "Using observation centroid frequency {} MHz to convert lambdas to metres", - self.freq_centroid / 1e6 - ); - } - - info!( - "Calibrating with {} of {} baselines", - self.num_calibration_baselines, self.total_num_baselines - ); - match (self.uvw_min, self.uvw_min.0.is_infinite()) { - // Again, bored. - (_, true) => info!(" Minimum UVW cutoff: ∞"), - ((quantity, WavelengthUnit::M), _) => info!(" Minimum UVW cutoff: {quantity}m"), - ((quantity, WavelengthUnit::L), _) => info!( - " Minimum UVW cutoff: {quantity}λ ({:.3}m)", - quantity * self.lambda - ), - } - match (self.uvw_max, self.uvw_max.0.is_infinite()) { - (_, true) => info!(" Maximum UVW cutoff: ∞"), - ((quantity, WavelengthUnit::M), _) => info!(" Maximum UVW cutoff: {quantity}m"), - ((quantity, WavelengthUnit::L), _) => info!( - " Maximum UVW cutoff: {quantity}λ ({:.3}m)", - quantity * self.lambda - ), - } - - info!("Chanblocks will stop iterating"); - info!( - " when the error is less than {:e} (stop threshold)", - self.stop_threshold - ); - info!(" or after {} iterations.", self.max_iterations); - info!( - "Chanblocks with an error less than {:e} are considered converged (min. threshold)", - self.min_threshold - ) - } -} - -#[must_use = "This struct must be consumed with its print() method"] -pub(super) struct SkyModelDetails<'a> { - pub(super) source_list: &'a SourceList, -} - -impl SkyModelDetails<'_> { - pub(super) fn print(self) { - let ComponentCounts { - num_points, - num_gaussians, - num_shapelets, - .. 
- } = self.source_list.get_counts(); - let num_components = num_points + num_gaussians + num_shapelets; - info!( - "Using {} sources with a total of {} components", - self.source_list.len(), - num_components - ); - info!(" {num_points} points, {num_gaussians} Gaussians, {num_shapelets} shapelets"); - if num_components > 10000 { - warn!("Using more than 10,000 components!"); - } - if log::log_enabled!(log::Level::Trace) { - trace!("Using sources:"); - let mut v = Vec::with_capacity(5); - for source in self.source_list.keys() { - if v.len() == 5 { - trace!(" {v:?}"); - v.clear(); - } - v.push(source); - } - if !v.is_empty() { - trace!(" {v:?}"); - } - } - } -} - -#[must_use = "This struct must be consumed with its print() method"] -pub(super) struct OutputFileDetails<'a> { - pub(super) output_solutions: &'a [PathBuf], - pub(super) vis_type: &'a str, - pub(super) output_vis: Option<&'a Vec1<(PathBuf, VisOutputType)>>, - pub(super) input_vis_time_res: Option, - /// \[Hz\] - pub(super) input_vis_freq_res: Option, - pub(super) output_vis_time_average_factor: usize, - pub(super) output_vis_freq_average_factor: usize, -} - -impl OutputFileDetails<'_> { - pub(super) fn print(self) { - if !self.output_solutions.is_empty() { - info!( - "Writing calibration solutions to: {}", - self.output_solutions - .iter() - .map(|pb| pb.display()) - .join(", ") - ); - } - if let Some(output_vis) = self.output_vis { - info!( - "Writing {} visibilities to: {}", - self.vis_type, - output_vis.iter().map(|pb| pb.0.display()).join(", ") - ); - - if self.output_vis_time_average_factor != 1 || self.output_vis_freq_average_factor != 1 - { - info!("Averaging output visibilities"); - if let Some(tr) = self.input_vis_time_res { - info!( - " {}x in time ({}s)", - self.output_vis_time_average_factor, - tr.to_seconds() * self.output_vis_time_average_factor as f64 - ); - } else { - info!( - " {}x (only one timestep)", - self.output_vis_time_average_factor - ); - } - - if let Some(fr) = self.input_vis_freq_res { - info!( - " {}x in freq. ({}kHz)", - self.output_vis_freq_average_factor, - fr * self.output_vis_freq_average_factor as f64 / 1000.0 - ); - } else { - info!( - " {}x (only one fine channel)", - self.output_vis_freq_average_factor - ); - } - } - } - } -} - -#[must_use = "This struct must be consumed with its print() method"] -pub(super) struct CalSolDetails<'a> { - pub(super) filename: &'a Path, - pub(super) sols: &'a CalibrationSolutions, -} - -impl CalSolDetails<'_> { - pub(super) fn print(self) { - let s = self.sols; - let num_timeblocks = s.di_jones.len_of(Axis(0)); - info!( - "Using calibration solutions from {}", - self.filename.display() - ); - info!( - " {num_timeblocks} timeblocks, {} tiles, {} chanblocks", - s.di_jones.len_of(Axis(1)), - s.di_jones.len_of(Axis(2)) - ); - - if let Some(c) = s.raw_data_corrections { - info!(" Raw data corrections:"); - info!(" PFB flavour: {}", c.pfb_flavour); - info!( - " digital gains: {}", - match c.digital_gains { - true => "yes", - false => "no", - } - ); - info!( - " cable lengths: {}", - match c.cable_length { - true => "yes", - false => "no", - } - ); - info!( - " geometric delays: {}", - match c.geometric { - true => "yes", - false => "no", - } - ); - } else { - info!(" No raw data correction information"); - } - - // If there's more than one timeblock, we can report dodgy-looking - // solutions based on the available metadata. 
- if num_timeblocks > 1 - && match ( - &s.start_timestamps, - &s.end_timestamps, - &s.average_timestamps, - ) { - // Are all types of timestamps available? - (Some(s), Some(e), Some(a)) => { - // Are all the lengths the same? - num_timeblocks != s.len() - || num_timeblocks != e.len() - || num_timeblocks != a.len() - } - _ => true, - } - { - warn!(" Time information is inconsistent; solution timeblocks"); - warn!(" may not be applied properly. hyperdrive-formatted"); - warn!(" solutions should be used to prevent this issue."); - } - } -} - -pub(super) fn print_modeller_info(modeller_info: &ModellerInfo) { - #[cfg(feature = "cuda")] - let using_cuda = matches!(modeller_info, crate::model::ModellerInfo::Cuda { .. }); - #[cfg(not(feature = "cuda"))] - let using_cuda = false; - - if using_cuda { - cfg_if::cfg_if! { - if #[cfg(feature = "cuda-single")] { - info!("Generating sky model visibilities on the GPU (single precision)"); - } else { - info!("Generating sky model visibilities on the GPU (double precision)"); - } - } - } else { - info!("Generating sky model visibilities on the CPU (double precision)"); - } - - match modeller_info { - crate::model::ModellerInfo::Cpu => (), - - #[cfg(feature = "cuda")] - crate::model::ModellerInfo::Cuda { - device_info, - driver_info, - } => { - info!( - " CUDA device: {} (capability {}, {} MiB)", - device_info.name, device_info.capability, device_info.total_global_mem - ); - info!( - " CUDA driver: {}, runtime: {}", - driver_info.driver_version, driver_info.runtime_version - ); - } - } -} - -// It looks a bit neater to print out a collection of numbers as a range rather -// than individual indices if they're sequential. This function inspects a -// collection and returns a string to be printed. -fn range_or_comma_separated(collection: &[usize], prefix: Option<&str>) -> String { - if collection.is_empty() { - return "".to_string(); - } - - let mut iter = collection.iter(); - let mut prev = *iter.next().unwrap(); - // Innocent until proven guilty. - let mut is_sequential = true; - for next in iter { - if *next == prev + 1 { - prev = *next; - } else { - is_sequential = false; - break; - } - } - - if is_sequential { - let suffix = if collection.len() == 1 { - format!("[{}]", collection[0]) - } else { - format!( - "[{:?})", - (*collection.first().unwrap()..*collection.last().unwrap() + 1) - ) - }; - if let Some(p) = prefix { - format!("{p} {suffix}") - } else { - suffix - } - } else { - let suffix = collection - .iter() - .map(|t| t.to_string()) - .collect::>() - .join(", "); - if let Some(p) = prefix { - format!("{p} [{suffix}]") - } else { - suffix - } - } -} diff --git a/src/model/cpu.rs b/src/model/cpu.rs index a0dca5cc..ff07d071 100644 --- a/src/model/cpu.rs +++ b/src/model/cpu.rs @@ -34,7 +34,7 @@ use crate::{ const GAUSSIAN_EXP_CONST: f64 = -(FRAC_PI_2 * FRAC_PI_2) / LN_2; const SHAPELET_CONST: f64 = SQRT_FRAC_PI_SQ_2_LN_2 / shapelets::SBF_DX; -pub(crate) struct SkyModellerCpu<'a> { +pub struct SkyModellerCpu<'a> { pub(super) beam: &'a dyn Beam, /// The phase centre used for all modelling. 
@@ -68,7 +68,7 @@ pub(crate) struct SkyModellerCpu<'a> { impl<'a> SkyModellerCpu<'a> { #[allow(clippy::too_many_arguments)] - pub(super) fn new( + pub fn new( beam: &'a dyn Beam, source_list: &SourceList, pols: Polarisations, diff --git a/src/model/cuda.rs b/src/model/cuda.rs index 74929de0..0676103f 100644 --- a/src/model/cuda.rs +++ b/src/model/cuda.rs @@ -28,7 +28,7 @@ use crate::{ /// The first axis of `*_list_fds` is unflagged fine channel frequency, the /// second is the source component. The length of `hadecs`, `lmns`, /// `*_list_fds`'s second axis are the same. -pub(crate) struct SkyModellerCuda<'a> { +pub struct SkyModellerCuda<'a> { /// The trait object to use for beam calculations. cuda_beam: Box, @@ -141,7 +141,7 @@ impl<'a> SkyModellerCuda<'a> { /// function, because using power laws is more efficient and probably more /// accurate. #[allow(clippy::too_many_arguments)] - pub(super) fn new( + pub fn new( beam: &dyn Beam, source_list: &SourceList, pols: Polarisations, diff --git a/src/model/mod.rs b/src/model/mod.rs index 1719811a..f414aed6 100644 --- a/src/model/mod.rs +++ b/src/model/mod.rs @@ -8,15 +8,13 @@ mod cpu; #[cfg(feature = "cuda")] mod cuda; mod error; -#[cfg(test)] -mod integration_tests; pub(crate) mod shapelets; #[cfg(test)] mod tests; -use cpu::SkyModellerCpu; +pub use cpu::SkyModellerCpu; #[cfg(feature = "cuda")] -use cuda::SkyModellerCuda; +pub use cuda::SkyModellerCuda; pub(crate) use error::ModelError; use std::collections::HashSet; @@ -25,10 +23,10 @@ use hifitime::{Duration, Epoch}; use marlu::{c32, Jones, RADec, XyzGeodetic, UVW}; use ndarray::{Array2, ArrayViewMut2}; -use crate::{beam::Beam, srclist::SourceList, Polarisations}; +use crate::{beam::Beam, context::Polarisations, srclist::SourceList, MODEL_DEVICE}; -#[derive(Debug, Clone)] -pub(crate) enum ModellerInfo { +#[derive(Debug, Clone, Copy)] +pub enum ModelDevice { /// The CPU is used for modelling. This always uses double-precision floats /// when modelling. Cpu, @@ -36,10 +34,80 @@ pub(crate) enum ModellerInfo { /// A CUDA-capable device is used for modelling. The precision depends on /// the compile features used. #[cfg(feature = "cuda")] - Cuda { - device_info: crate::cuda::CudaDeviceInfo, - driver_info: crate::cuda::CudaDriverInfo, - }, + Cuda, +} + +impl ModelDevice { + pub(crate) fn get_precision(self) -> &'static str { + match self { + ModelDevice::Cpu => "double", + + #[cfg(feature = "cuda-single")] + ModelDevice::Cuda => "single", + + #[cfg(all(feature = "cuda", not(feature = "cuda-single")))] + ModelDevice::Cuda => "double", + } + } + + /// Get a formatted string with information on the device used for + /// modelling. + pub(crate) fn get_device_info(self) -> Result { + match self { + ModelDevice::Cpu => Ok(get_cpu_info()), + + #[cfg(feature = "cuda")] + ModelDevice::Cuda => { + let (device_info, driver_info) = crate::cuda::get_device_info()?; + Ok(format!( + "{} (capability {}, {} MiB), CUDA driver {}, runtime {}", + device_info.name, + device_info.capability, + device_info.total_global_mem, + driver_info.driver_version, + driver_info.runtime_version + )) + } + } + } +} + +#[derive(thiserror::Error, Debug)] +pub(crate) enum DeviceError { + #[cfg(feature = "cuda")] + #[error(transparent)] + Cuda(#[from] crate::cuda::CudaError), +} + +/// Get a formatted string with information on the device used for modelling. +// TODO: Is there a way to get the name of the CPU without some crazy +// dependencies? 
+pub(crate) fn get_cpu_info() -> String { + #[cfg(any(target_arch = "x86", target_arch = "x86_64"))] + { + // Non-exhaustive but perhaps most-interesting CPU features. + let avx = std::arch::is_x86_feature_detected!("avx"); + let avx2 = std::arch::is_x86_feature_detected!("avx2"); + let avx512 = std::arch::is_x86_feature_detected!("avx512f"); + + match (avx512, avx2, avx) { + (true, _, _) => { + format!("{} CPU (AVX512 available)", std::env::consts::ARCH) + } + (false, true, _) => { + format!("{} CPU (AVX2 available)", std::env::consts::ARCH) + } + (false, false, true) => { + format!("{} CPU (AVX available)", std::env::consts::ARCH) + } + (false, false, false) => { + format!("{} CPU (AVX unavailable!)", std::env::consts::ARCH) + } + } + } + + #[cfg(not(any(target_arch = "x86", target_arch = "x86_64")))] + Ok(format!("{} CPU", std::env::consts::ARCH)); } /// An object that simulates sky-model visibilities. @@ -102,7 +170,6 @@ pub trait SkyModeller<'a> { /// executed, or if there was a problem in setting up a `BeamCUDA`. #[allow(clippy::too_many_arguments)] pub fn new_sky_modeller<'a>( - #[cfg(feature = "cuda")] use_cpu_for_modelling: bool, beam: &'a dyn Beam, source_list: &SourceList, pols: Polarisations, @@ -115,40 +182,24 @@ pub fn new_sky_modeller<'a>( dut1: Duration, apply_precession: bool, ) -> Result + 'a>, ModelError> { - cfg_if::cfg_if! { - if #[cfg(feature = "cuda")] { - if use_cpu_for_modelling { - Ok(Box::new(SkyModellerCpu::new( - beam, - source_list, - pols, - unflagged_tile_xyzs, - unflagged_fine_chan_freqs, - flagged_tiles, - phase_centre, - array_longitude_rad, - array_latitude_rad, - dut1, - apply_precession, - ))) - } else { - let modeller = SkyModellerCuda::new( - beam, - source_list, - pols, - unflagged_tile_xyzs, - unflagged_fine_chan_freqs, - flagged_tiles, - phase_centre, - array_longitude_rad, - array_latitude_rad, - dut1, - apply_precession, - )?; - Ok(Box::new(modeller)) - } - } else { - Ok(Box::new(SkyModellerCpu::new( + match MODEL_DEVICE.load() { + ModelDevice::Cpu => Ok(Box::new(SkyModellerCpu::new( + beam, + source_list, + pols, + unflagged_tile_xyzs, + unflagged_fine_chan_freqs, + flagged_tiles, + phase_centre, + array_longitude_rad, + array_latitude_rad, + dut1, + apply_precession, + ))), + + #[cfg(feature = "cuda")] + ModelDevice::Cuda => { + let modeller = SkyModellerCuda::new( beam, source_list, pols, @@ -160,7 +211,8 @@ pub fn new_sky_modeller<'a>( array_latitude_rad, dut1, apply_precession, - ))) + )?; + Ok(Box::new(modeller)) } } } diff --git a/src/params/di_calibration.rs b/src/params/di_calibration.rs new file mode 100644 index 00000000..6388b2f4 --- /dev/null +++ b/src/params/di_calibration.rs @@ -0,0 +1,605 @@ +// This Source Code Form is subject to the terms of the Mozilla Public +// License, v. 2.0. If a copy of the MPL was not distributed with this +// file, You can obtain one at http://mozilla.org/MPL/2.0/. 
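// --- Illustrative sketch (Rust, std only) of the device-selection pattern that the new
// `new_sky_modeller` above relies on. A process-wide cell (the real code uses
// `lazy_static` plus crossbeam's `AtomicCell<ModelDevice>`, set once by CLI code)
// records whether modelling should run on the CPU or a CUDA device, and
// `new_sky_modeller` reads it to decide which modeller to construct. Plain std atomics
// and the name `Device` are used here only to keep the sketch self-contained.

use std::sync::atomic::{AtomicU8, Ordering};

#[derive(Debug, Clone, Copy, PartialEq)]
enum Device {
    Cpu,
    Cuda,
}

static DEVICE: AtomicU8 = AtomicU8::new(0); // 0 = Cpu, 1 = Cuda

fn set_device(d: Device) {
    DEVICE.store(d as u8, Ordering::Relaxed);
}

fn current_device() -> Device {
    match DEVICE.load(Ordering::Relaxed) {
        0 => Device::Cpu,
        _ => Device::Cuda,
    }
}

fn main() {
    // The default is the CPU; CLI code would flip this once, early on.
    assert_eq!(current_device(), Device::Cpu);
    set_device(Device::Cuda);
    assert_eq!(current_device(), Device::Cuda);
}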
+ +use std::{ + path::PathBuf, + thread::{self, ScopedJoinHandle}, +}; + +use crossbeam_channel::{unbounded, Sender}; +use crossbeam_utils::atomic::AtomicCell; +use indicatif::{MultiProgress, ProgressBar, ProgressDrawTarget, ProgressStyle}; +use itertools::{izip, Itertools}; +use log::{debug, info, log_enabled, Level::Debug}; +use marlu::{ + constants::{FREQ_WEIGHT_FACTOR, TIME_WEIGHT_FACTOR}, + Jones, +}; +use ndarray::{iter::AxisIterMut, prelude::*}; +use rayon::prelude::*; +use scopeguard::defer_on_unwind; +use vec1::Vec1; + +use super::{InputVisParams, ModellingParams, OutputVisParams}; +use crate::{ + averaging::Timeblock, + beam::Beam, + context::Polarisations, + di_calibrate::calibrate_timeblocks, + io::{ + read::VisReadError, + write::{write_vis, VisTimestep, VisWriteError}, + }, + misc::expensive_op, + model::{new_sky_modeller, ModelError}, + solutions::CalSolutionType, + srclist::SourceList, + CalibrationSolutions, PROGRESS_BARS, +}; + +/// Parameters needed to perform calibration. +pub(crate) struct DiCalParams { + /// The interface to the input data, metadata and flags. + pub(crate) input_vis_params: InputVisParams, + + /// Beam object. + pub(crate) beam: Box, + + /// The sky-model source list. + pub(crate) source_list: SourceList, + + /// Blocks of timesteps used for calibration. Each timeblock contains + /// indices of the input data to average together during calibration. Each + /// timeblock may have a different number of timesteps; the number of blocks + /// and their lengths depends on which input data timesteps are being used + /// as well as the `time_average_factor` (i.e. the number of timesteps to + /// average during calibration; by default we average all timesteps). + /// + /// Simple examples: If we are averaging all data over time to form + /// calibration solutions, there will only be one timeblock, and that block + /// will contain all input data timestep indices. On the other hand, if + /// `time_average_factor` is 1, then there are as many timeblocks as there + /// are timesteps, and each block contains 1 timestep index. + /// + /// A more complicated example: If we are using input data timesteps 10, 11, + /// 12 and 15 with a `time_average_factor` of 4, then there will be 2 + /// timeblocks, even though there are only 4 timesteps. This is because + /// timestep 10 and 15 can't occupy the same timeblock with the "length" is + /// 4. So the first timeblock contains 10, 11 and 12, whereas the second + /// contains only 15. + pub(crate) cal_timeblocks: Vec1, + + /// The minimum UVW cutoff used in calibration \[metres\]. + pub(crate) uvw_min: f64, + + /// The maximum UVW cutoff used in calibration \[metres\]. + pub(crate) uvw_max: f64, + + /// The centroid frequency of the observation used to convert UVW cutoffs + /// specified in lambdas to metres \[Hz\]. + pub(crate) freq_centroid: f64, + + /// Multiplicative factors to apply to unflagged baselines. These are mostly + /// all 1.0, but flagged baselines (perhaps due to a UVW cutoff) have values + /// of 0.0. + pub(crate) baseline_weights: Vec1, + + /// The maximum number of times to iterate when performing calibration. + pub(crate) max_iterations: u32, + + /// The threshold at which we stop convergence when performing calibration. + /// This is smaller than `min_threshold`. + pub(crate) stop_threshold: f64, + + /// The minimum threshold to satisfy convergence when performing calibration. + /// Reaching this threshold counts as "converged", but it's not as good as + /// the stop threshold. 
This is bigger than `stop_threshold`. + pub(crate) min_threshold: f64, + + /// The paths to the files where the calibration solutions are written. The + /// same solutions are written to each file here, but the format may be + /// different (indicated by the second part of the tuples). + pub(crate) output_solution_files: Vec1<(PathBuf, CalSolutionType)>, + + /// The parameters for optional sky-model visibilities files. If specified, + /// model visibilities will be written out before calibration. + pub(crate) output_model_vis_params: Option, + + /// Parameters for modelling. + pub(crate) modelling_params: ModellingParams, +} + +impl DiCalParams { + /// Use the [`DiCalParams`] to perform calibration and obtain solutions. + pub(crate) fn run(&self) -> Result { + let input_vis_params = &self.input_vis_params; + + let CalVis { + vis_data, + vis_weights, + vis_model, + pols, + } = self.get_cal_vis()?; + assert_eq!(vis_weights.len_of(Axis(2)), self.baseline_weights.len()); + + // The shape of the array containing output Jones matrices. + let num_timeblocks = input_vis_params.timeblocks.len(); + let num_chanblocks = input_vis_params.spw.chanblocks.len(); + let num_unflagged_tiles = input_vis_params.get_num_unflagged_tiles(); + + if log_enabled!(Debug) { + let shape = (num_timeblocks, num_unflagged_tiles, num_chanblocks); + debug!( + "Shape of DI Jones matrices array: ({} timeblocks, {} tiles, {} chanblocks; {} MiB)", + shape.0, + shape.1, + shape.2, + shape.0 * shape.1 * shape.2 * std::mem::size_of::>() + // 1024 * 1024 == 1 MiB. + / 1024 / 1024 + ); + } + + let (sols, results) = calibrate_timeblocks( + vis_data.view(), + vis_model.view(), + &self.cal_timeblocks, + &input_vis_params.spw.chanblocks, + self.max_iterations, + self.stop_threshold, + self.min_threshold, + pols, + true, + ); + + // "Complete" the solutions. + let sols = sols.into_cal_sols(self, Some(results.map(|r| r.max_precision))); + + Ok(sols) + } + + /// For calibration, read in unflagged visibilities and generate sky-model + /// visibilities. + pub(crate) fn get_cal_vis(&self) -> Result { + let input_vis_params = &self.input_vis_params; + + // Get the time and frequency resolutions once; these functions issue + // warnings if they have to guess, so doing this once means we aren't + // issuing too many warnings. + let obs_context = input_vis_params.get_obs_context(); + let num_unflagged_tiles = input_vis_params.get_num_unflagged_tiles(); + let num_unflagged_cross_baselines = (num_unflagged_tiles * (num_unflagged_tiles - 1)) / 2; + + let vis_shape = ( + input_vis_params + .timeblocks + .iter() + .flat_map(|t| &t.timestamps) + .count(), + input_vis_params.spw.chanblocks.len(), + num_unflagged_cross_baselines, + ); + let num_elems = vis_shape.0 * vis_shape.1 * vis_shape.2; + // We need this many bytes for each of the data and model arrays to do + // calibration. + let size = indicatif::HumanBytes((num_elems * std::mem::size_of::>()) as u64); + debug!("Shape of data and model arrays: ({} timesteps, {} channels, {} baselines; {size} each)", vis_shape.0, vis_shape.1, vis_shape.2); + + macro_rules! fallible_allocator { + ($default:expr) => {{ + let mut v = Vec::new(); + match v.try_reserve_exact(num_elems) { + Ok(()) => { + v.resize(num_elems, $default); + Ok(Array3::from_shape_vec(vis_shape, v).unwrap()) + } + Err(_) => { + // We need this many gibibytes to do calibration (two + // visibility arrays and one weights array). 
+ let need_gib = indicatif::HumanBytes( + (num_elems + * (2 * std::mem::size_of::>() + + std::mem::size_of::())) as u64, + ); + + Err(DiCalibrateError::InsufficientMemory { + // Instead of erroring out with how many bytes we need + // for the array we just tried to allocate, error out + // with how many bytes we need total. + need_gib, + }) + } + } + }}; + } + + debug!("Allocating memory for input data visibilities and model visibilities"); + let cal_vis = expensive_op( + || -> Result<_, DiCalibrateError> { + let vis_data: Array3> = fallible_allocator!(Jones::default())?; + let vis_model: Array3> = fallible_allocator!(Jones::default())?; + let vis_weights: Array3 = fallible_allocator!(0.0)?; + Ok(CalVis { + vis_data, + vis_weights, + vis_model, + pols: Polarisations::default(), + }) + }, + "Still waiting to allocate visibility memory", + )?; + let CalVis { + mut vis_data, + mut vis_model, + mut vis_weights, + pols: _, + } = cal_vis; + + // Sky-modelling communication channel. Used to tell the model writer when + // visibilities have been generated and they're ready to be written. + let (tx_model, rx_model) = unbounded(); + + // Progress bars. Courtesy Dev. + let multi_progress = MultiProgress::with_draw_target(if PROGRESS_BARS.load() { + ProgressDrawTarget::stdout() + } else { + ProgressDrawTarget::hidden() + }); + let pb = ProgressBar::new(input_vis_params.timeblocks.len() as _) + .with_style( + ProgressStyle::default_bar() + .template("{msg:17}: [{wide_bar:.blue}] {pos:2}/{len:2} timesteps ({elapsed_precise}<{eta_precise})").unwrap() + .progress_chars("=> "), + ) + .with_position(0) + .with_message("Reading data"); + let read_progress = multi_progress.add(pb); + let pb = ProgressBar::new(input_vis_params.timeblocks.len() as _) + .with_style( + ProgressStyle::default_bar() + .template("{msg:17}: [{wide_bar:.blue}] {pos:2}/{len:2} timesteps ({elapsed_precise}<{eta_precise})").unwrap() + .progress_chars("=> "), + ) + .with_position(0) + .with_message("Sky modelling"); + let model_progress = multi_progress.add(pb); + // Only add a model writing progress bar if we need it. + let model_write_progress = self.output_model_vis_params.as_ref().map(|o| { + let pb = ProgressBar::new(o.output_timeblocks.len() as _) + .with_style( + ProgressStyle::default_bar() + .template("{msg:17}: [{wide_bar:.blue}] {pos:2}/{len:2} timeblocks ({elapsed_precise}<{eta_precise})").unwrap() + .progress_chars("=> "), + ) + .with_position(0) + .with_message("Model writing"); + multi_progress.add(pb) + }); + + // Use a variable to track whether any threads have an issue. + let error = AtomicCell::new(false); + info!("Reading input data and sky modelling"); + thread::scope(|scope| -> Result<(), DiCalibrateError> { + // Mutable slices of the "global" arrays. These allow threads to mutate + // the global arrays in parallel (using the Arc> pattern would + // kill performance here). + let vis_data_slices = vis_data.outer_iter_mut(); + let vis_model_slices = vis_model.outer_iter_mut(); + let vis_weight_slices = vis_weights.outer_iter_mut(); + + // Input visibility-data reading thread. + let data_handle: ScopedJoinHandle> = thread::Builder::new() + .name("read".to_string()) + .spawn_scoped(scope, || { + // If a panic happens, update our atomic error. + defer_on_unwind! 
{ error.store(true); } + read_progress.tick(); + + for (timeblock, vis_data_fb, vis_weights_fb) in izip!( + &input_vis_params.timeblocks, + vis_data_slices, + vis_weight_slices + ) { + let result = input_vis_params.read_timeblock( + timeblock, + vis_data_fb, + vis_weights_fb, + None, + &error, + ); + + // If the result of reading data was an error, allow the other + // threads to see this so they can abandon their work early. + if result.is_err() { + error.store(true); + } + result?; + + // Should we continue? + if error.load() { + return Ok(()); + } + + read_progress.inc(1); + } + + debug!("Finished reading"); + read_progress.abandon_with_message("Finished reading visibilities"); + Ok(()) + }) + .expect("OS can create threads"); + + // Sky-model generation thread. + let model_handle: ScopedJoinHandle> = thread::Builder::new() + .name("model".to_string()) + .spawn_scoped(scope, || { + defer_on_unwind! { error.store(true); } + model_progress.tick(); + + let result = model_thread( + &*self.beam, + &self.source_list, + input_vis_params, + self.modelling_params.apply_precession, + vis_model_slices, + tx_model, + &error, + model_progress, + ); + if result.is_err() { + error.store(true); + } + result + }) + .expect("OS can create threads"); + + // Model writing thread. If the user hasn't specified to write the model + // to a file, then this thread just consumes messages from the modeller. + let writer_handle: ScopedJoinHandle> = thread::Builder::new() + .name("model writer".to_string()) + .spawn_scoped(scope, || { + defer_on_unwind! { error.store(true); } + + // If the user wants the sky model written out, + // `output_model_vis_params` is populated. + if let Some(OutputVisParams { + output_files, + output_time_average_factor, + output_freq_average_factor, + output_timeblocks, + write_smallest_contiguous_band, + }) = &self.output_model_vis_params + { + if let Some(pb) = model_write_progress.as_ref() { + pb.tick(); + } + + let unflagged_baseline_tile_pairs = input_vis_params + .tile_baseline_flags + .tile_to_unflagged_cross_baseline_map + .keys() + .copied() + .sorted() + .collect::>(); + + let result = write_vis( + output_files, + obs_context.array_position, + obs_context.phase_centre, + obs_context.pointing_centre, + &obs_context.tile_xyzs, + &obs_context.tile_names, + obs_context.obsid, + output_timeblocks, + input_vis_params.time_res, + input_vis_params.dut1, + &input_vis_params.spw, + &unflagged_baseline_tile_pairs, + *output_time_average_factor, + *output_freq_average_factor, + input_vis_params.vis_reader.get_marlu_mwa_info().as_ref(), + *write_smallest_contiguous_band, + rx_model, + &error, + model_write_progress, + ); + if result.is_err() { + error.store(true); + } + // Discard the result string. + result?; + Ok(()) + } else { + // There's no model to write out, but we still need to handle all of the + // incoming messages. + for _ in rx_model.iter() {} + Ok(()) + } + }) + .expect("OS can create threads"); + + // Join all thread handles. This propagates any errors and lets us know + // if any threads panicked, if panics aren't aborting as per the + // Cargo.toml. (It would be nice to capture the panic information, if + // it's possible, but I don't know how, so panics are currently + // aborting.) + data_handle.join().unwrap()?; + model_handle.join().unwrap()?; + writer_handle.join().unwrap()?; + Ok(()) + })?; + + debug!("Multiplying visibilities by weights"); + + // Multiply the visibilities by the weights (and baseline weights based on + // UVW cuts). 
If a weight is negative, it means the corresponding visibility + // should be flagged, so that visibility is set to 0; this means it does not + // affect calibration. Not iterating over weights during calibration makes + // makes calibration run significantly faster. + vis_data + .outer_iter_mut() + .into_par_iter() + .zip(vis_model.outer_iter_mut()) + .zip(vis_weights.outer_iter()) + .for_each(|((mut vis_data, mut vis_model), vis_weights)| { + vis_data + .outer_iter_mut() + .zip(vis_model.outer_iter_mut()) + .zip(vis_weights.outer_iter()) + .for_each(|((mut vis_data, mut vis_model), vis_weights)| { + vis_data + .iter_mut() + .zip(vis_model.iter_mut()) + .zip(vis_weights.iter()) + .zip(self.baseline_weights.iter()) + .for_each(|(((vis_data, vis_model), &vis_weight), baseline_weight)| { + let weight = f64::from(vis_weight) * *baseline_weight; + if weight <= 0.0 { + *vis_data = Jones::default(); + *vis_model = Jones::default(); + } else { + *vis_data = + Jones::::from(Jones::::from(*vis_data) * weight); + *vis_model = + Jones::::from(Jones::::from(*vis_model) * weight); + } + }); + }); + }); + + info!("Finished reading input data and sky modelling"); + + Ok(CalVis { + vis_data, + vis_weights, + vis_model, + pols: obs_context.polarisations, + }) + } +} + +#[allow(clippy::too_many_arguments)] +fn model_thread( + beam: &dyn Beam, + source_list: &SourceList, + input_vis_params: &InputVisParams, + apply_precession: bool, + vis_model_slices: AxisIterMut<'_, Jones, Ix2>, + tx: Sender, + error: &AtomicCell, + progress_bar: ProgressBar, +) -> Result<(), ModelError> { + let obs_context = input_vis_params.get_obs_context(); + let unflagged_tile_xyzs = obs_context + .tile_xyzs + .iter() + .enumerate() + .filter(|(i, _)| { + !input_vis_params + .tile_baseline_flags + .flagged_tiles + .contains(i) + }) + .map(|(_, xyz)| *xyz) + .collect::>(); + let freqs = input_vis_params + .spw + .chanblocks + .iter() + .map(|c| c.freq) + .collect::>(); + let modeller = new_sky_modeller( + beam, + source_list, + obs_context.polarisations, + &unflagged_tile_xyzs, + &freqs, + &input_vis_params.tile_baseline_flags.flagged_tiles, + obs_context.phase_centre, + obs_context.array_position.longitude_rad, + obs_context.array_position.latitude_rad, + input_vis_params.dut1, + apply_precession, + )?; + + let weight_factor = ((input_vis_params.spw.freq_res / FREQ_WEIGHT_FACTOR) + * (input_vis_params.time_res.to_seconds() / TIME_WEIGHT_FACTOR)) + as f32; + + // Iterate over all calibration timesteps and write to the model slices. + for (timestamp, mut vis_model_fb) in input_vis_params + .timeblocks + .iter() + .map(|tb| tb.median) + .zip(vis_model_slices) + { + debug!("Modelling timestamp {}", timestamp.to_gpst_seconds()); + modeller.model_timestep_with(timestamp, vis_model_fb.view_mut())?; + + // Should we continue? + if error.load() { + return Ok(()); + } + + match tx.send(VisTimestep { + cross_data_fb: vis_model_fb.to_shared(), + cross_weights_fb: ArcArray::from_elem(vis_model_fb.dim(), weight_factor), + autos: None, + timestamp, + }) { + Ok(()) => (), + // If we can't send the message, it's because the channel has + // been closed on the other side. That should only happen + // because the writer has exited due to error; in that case, + // just exit this thread. + Err(_) => return Ok(()), + } + progress_bar.inc(1); + } + + debug!("Finished modelling"); + progress_bar.abandon_with_message("Finished generating sky model"); + Ok(()) +} + +pub(crate) struct CalVis { + /// Visibilites read from input data. 
+ pub(crate) vis_data: Array3>, + + /// The weights on the visibilites read from input data. + pub(crate) vis_weights: Array3, + + /// Visibilites generated from the sky-model source list. + pub(crate) vis_model: Array3>, + + /// The available polarisations within the data. + pub(crate) pols: Polarisations, +} + +#[derive(thiserror::Error, Debug)] +pub(crate) enum DiCalibrateError { + #[error("Insufficient memory available to perform calibration; need {need_gib} of memory.\nYou could try using fewer timesteps and channels.")] + InsufficientMemory { need_gib: indicatif::HumanBytes }, + + #[error(transparent)] + SolutionsRead(#[from] crate::solutions::SolutionsReadError), + + #[error(transparent)] + SolutionsWrite(#[from] crate::solutions::SolutionsWriteError), + + #[error(transparent)] + Model(#[from] crate::model::ModelError), + + #[error(transparent)] + VisRead(#[from] crate::io::read::VisReadError), + + #[error(transparent)] + VisWrite(#[from] crate::io::write::VisWriteError), + + #[error(transparent)] + Fitsio(#[from] fitsio::errors::Error), + + #[error(transparent)] + IO(#[from] std::io::Error), +} diff --git a/src/params/input_vis.rs b/src/params/input_vis.rs new file mode 100644 index 00000000..ec823806 --- /dev/null +++ b/src/params/input_vis.rs @@ -0,0 +1,578 @@ +// This Source Code Form is subject to the terms of the Mozilla Public +// License, v. 2.0. If a copy of the MPL was not distributed with this +// file, You can obtain one at http://mozilla.org/MPL/2.0/. + +//! Parameters for input data. The main struct here ([`InputVisParams`]) +//! includes a method (`read_timeblock`) for reading in averaged and/or +//! calibrated visibilities, depending on whether averaging is requested and +//! whether calibration solutions have been supplied. Other info like +//! chanblocks, tile and channel flags etc. also live here. + +use std::collections::HashSet; + +use crossbeam_utils::atomic::AtomicCell; +use hifitime::{Duration, Epoch}; +use itertools::{izip, Itertools}; +use log::debug; +use marlu::Jones; +use ndarray::prelude::*; +use vec1::Vec1; + +use crate::{ + averaging::{vis_average, Spw, Timeblock}, + context::ObsContext, + io::read::{VisRead, VisReadError}, + math::TileBaselineFlags, + CalibrationSolutions, +}; + +pub(crate) struct InputVisParams { + /// The object to read visibility data. + pub(crate) vis_reader: Box, + + /// Calibration solutions. If available, these are automatically applied + /// when `InputVisParams::read_timeblock` is called. + pub(crate) solutions: Option, + + /// The timeblocks to be used from the averaged data. If there is no + /// averaging to be done, then these are the same as the timesteps to be + /// read from the data. + pub(crate) timeblocks: Vec1, + + /// The time resolution of the data *after* averaging (i.e. when using + /// `InputVisParams::read_timeblock`). + pub(crate) time_res: Duration, + + /// Channel and frequency information. Note that this is a single contiguous + /// spectral window, not multiple spectral windows (a.k.a. picket fence). + pub(crate) spw: Spw, + + pub(crate) tile_baseline_flags: TileBaselineFlags, + + /// Are autocorrelations to be read? + pub(crate) using_autos: bool, + + /// Are we ignoring weights? + pub(crate) ignore_weights: bool, + + /// The UT1 - UTC offset. If this is 0, effectively UT1 == UTC, which is a + /// wrong assumption by up to 0.9s. We assume the this value does not change + /// over the timestamps used in this [`InputVisParams`]. 
+ /// + /// Note that this need not be the same DUT1 in the input data's + /// [`ObsContext`]; the user may choose to suppress that DUT1 or supply + /// their own. + pub(crate) dut1: Duration, +} + +impl InputVisParams { + pub(crate) fn get_obs_context(&self) -> &ObsContext { + self.vis_reader.get_obs_context() + } + + pub(crate) fn get_total_num_tiles(&self) -> usize { + self.get_obs_context().get_total_num_tiles() + } + + pub(crate) fn get_num_unflagged_tiles(&self) -> usize { + self.tile_baseline_flags + .unflagged_auto_index_to_tile_map + .len() + } + + /// Read the cross-correlation visibilities out of the input data, averaged + /// to the target resolution. If calibration solutions were supplied, then + /// these are applied before averaging. + pub(crate) fn read_timeblock( + &self, + timeblock: &Timeblock, + mut cross_data_fb: ArrayViewMut2>, + mut cross_weights_fb: ArrayViewMut2, + mut autos_fb: Option<(ArrayViewMut2>, ArrayViewMut2)>, + error: &AtomicCell, + ) -> Result<(), VisReadError> { + let obs_context = self.get_obs_context(); + let num_unflagged_tiles = self.get_num_unflagged_tiles(); + let num_unflagged_cross_baselines = (num_unflagged_tiles * (num_unflagged_tiles - 1)) / 2; + let avg_cross_vis_shape = (self.spw.chanblocks.len(), num_unflagged_cross_baselines); + let avg_auto_vis_shape = (self.spw.chanblocks.len(), num_unflagged_tiles); + assert_eq!(cross_data_fb.dim(), avg_cross_vis_shape); + assert_eq!(cross_weights_fb.dim(), avg_cross_vis_shape); + if let Some((auto_data_fb, auto_weights_fb)) = autos_fb.as_ref() { + assert_eq!(auto_data_fb.dim(), avg_auto_vis_shape); + assert_eq!(auto_weights_fb.dim(), avg_auto_vis_shape); + } + + let averaging = timeblock.timestamps.len() > 1 || self.spw.chans_per_chanblock.get() > 1; + + if averaging { + let cross_vis_shape = ( + timeblock.timestamps.len(), + obs_context.fine_chan_freqs.len(), + num_unflagged_cross_baselines, + ); + let mut unaveraged_cross_data_tfb = Array3::zeros(cross_vis_shape); + let mut unaveraged_cross_weights_tfb = Array3::zeros(cross_vis_shape); + // If the user has supplied arrays for autos and the input data has + // autos, read those out. + let mut unaveraged_autos = + match (autos_fb.as_ref(), obs_context.autocorrelations_present) { + (Some(_), true) => { + let auto_vis_shape = ( + timeblock.timestamps.len(), + obs_context.fine_chan_freqs.len(), + num_unflagged_tiles, + ); + let unaveraged_auto_data_tfb = Array3::zeros(auto_vis_shape); + let unaveraged_auto_weights_tfb = Array3::zeros(auto_vis_shape); + Some((unaveraged_auto_data_tfb, unaveraged_auto_weights_tfb)) + } + + _ => None, + }; + + if let Some((unaveraged_auto_data_tfb, unaveraged_auto_weights_tfb)) = + unaveraged_autos.as_mut() + { + for ( + ×tamp, + ×tep, + unaveraged_cross_data_fb, + unaveraged_cross_weights_fb, + unaveraged_auto_data_fb, + unaveraged_auto_weights_fb, + ) in izip!( + timeblock.timestamps.iter(), + timeblock.timesteps.iter(), + unaveraged_cross_data_tfb.outer_iter_mut(), + unaveraged_cross_weights_tfb.outer_iter_mut(), + unaveraged_auto_data_tfb.outer_iter_mut(), + unaveraged_auto_weights_tfb.outer_iter_mut() + ) { + debug!("Reading timestamp {}", timestamp.to_gpst_seconds()); + + self.read_timestep( + timestep, + unaveraged_cross_data_fb, + unaveraged_cross_weights_fb, + Some((unaveraged_auto_data_fb, unaveraged_auto_weights_fb)), + &HashSet::new(), + )?; + + // Should we continue? 
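`read_timeblock` reads each timestep of a timeblock into full-resolution `unaveraged_*` arrays and later collapses them with `vis_average`. That routine is not shown in this diff, but the reduction it is expected to perform is a weighted mean in which non-positive weights mark flagged samples; a rough sketch of one output cell, with plain floats standing in for Jones matrices:

```rust
/// Weighted mean of the samples that feed one averaged output cell.
/// Non-positive weights are treated as flags and skipped; if every sample
/// is flagged, the output is zero with zero weight (i.e. also flagged).
fn average_cell(samples: &[(f32, f32)]) -> (f32, f32) {
    let mut weighted_sum = 0.0_f64;
    let mut weight_sum = 0.0_f64;
    for &(value, weight) in samples {
        if weight > 0.0 {
            weighted_sum += f64::from(value) * f64::from(weight);
            weight_sum += f64::from(weight);
        }
    }
    if weight_sum > 0.0 {
        ((weighted_sum / weight_sum) as f32, weight_sum as f32)
    } else {
        (0.0, 0.0)
    }
}
```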
+ if error.load() { + return Ok(()); + } + } + } else { + for ( + ×tamp, + ×tep, + unaveraged_cross_data_fb, + unaveraged_cross_weights_fb, + ) in izip!( + timeblock.timestamps.iter(), + timeblock.timesteps.iter(), + unaveraged_cross_data_tfb.outer_iter_mut(), + unaveraged_cross_weights_tfb.outer_iter_mut() + ) { + debug!("Reading timestamp {}", timestamp.to_gpst_seconds()); + + self.read_timestep( + timestep, + unaveraged_cross_data_fb, + unaveraged_cross_weights_fb, + None, + &HashSet::new(), + )?; + + // Should we continue? + if error.load() { + return Ok(()); + } + } + }; + + // Apply flagged channels. + for i_chan in &self.spw.flagged_chan_indices { + let i_chan = usize::from(*i_chan); + unaveraged_cross_weights_tfb + .slice_mut(s![.., i_chan, ..]) + .mapv_inplace(|w| -w.abs()); + unaveraged_cross_weights_tfb + .slice_mut(s![.., i_chan, ..]) + .mapv_inplace(|w| -w.abs()); + if let Some((_, unaveraged_auto_weights_tfb)) = unaveraged_autos.as_mut() { + unaveraged_auto_weights_tfb + .slice_mut(s![.., i_chan, ..]) + .mapv_inplace(|w| -w.abs()); + } + } + + // We've now read in all of the timesteps for this timeblock. If + // there are calibration solutions, these now need to be applied. + if self.solutions.is_some() { + debug!( + "Applying calibration solutions to input data from timeblock {}", + timeblock.index + ); + + let chan_freqs = obs_context.fine_chan_freqs.mapped_ref(|f| *f as f64); + if let Some((unaveraged_auto_data_tfb, unaveraged_auto_weights_tfb)) = + unaveraged_autos.as_mut() + { + for ( + ×tamp, + cross_data_fb, + cross_weights_fb, + auto_data_fb, + auto_weights_fb, + ) in izip!( + timeblock.timestamps.iter(), + unaveraged_cross_data_tfb.outer_iter_mut(), + unaveraged_cross_weights_tfb.outer_iter_mut(), + unaveraged_auto_data_tfb.outer_iter_mut(), + unaveraged_auto_weights_tfb.outer_iter_mut() + ) { + self.apply_solutions( + timestamp, + cross_data_fb, + cross_weights_fb, + Some((auto_data_fb, auto_weights_fb)), + &chan_freqs, + ); + } + } else { + { + for (×tamp, cross_data_fb, cross_weights_fb) in izip!( + timeblock.timestamps.iter(), + unaveraged_cross_data_tfb.outer_iter_mut(), + unaveraged_cross_weights_tfb.outer_iter_mut(), + ) { + self.apply_solutions( + timestamp, + cross_data_fb, + cross_weights_fb, + None, + &chan_freqs, + ); + } + } + } + } + + // Now that solutions have been applied, we can average the data + // into the supplied arrays. + debug!("Averaging input data from timeblock {}", timeblock.index); + vis_average( + unaveraged_cross_data_tfb.view(), + cross_data_fb, + unaveraged_cross_weights_tfb.view(), + cross_weights_fb, + &self.spw.flagged_chanblock_indices, + ); + if let ( + Some((mut auto_data_fb, mut auto_weights_fb)), + Some((unaveraged_auto_data_tfb, unaveraged_auto_weights_tfb)), + ) = (autos_fb, unaveraged_autos) + { + vis_average( + unaveraged_auto_data_tfb.view(), + auto_data_fb.view_mut(), + unaveraged_auto_weights_tfb.view(), + auto_weights_fb.view_mut(), + &self.spw.flagged_chanblock_indices, + ); + }; + } else { + // Not averaging; read the data directly into the supplied arrays. + let timestamp = *timeblock.timestamps.first(); + let timestep = *timeblock.timesteps.first(); + debug!("Reading timestamp {}", timestamp.to_gpst_seconds()); + self.read_timestep( + timestep, + cross_data_fb.view_mut(), + cross_weights_fb.view_mut(), + autos_fb.as_mut().map(|(auto_data_fb, auto_weights_fb)| { + (auto_data_fb.view_mut(), auto_weights_fb.view_mut()) + }), + &self.spw.flagged_chan_indices, + )?; + + // Should we continue? 
+ if error.load() { + return Ok(()); + } + + // Apply calibration solutions, if they're supplied. + if self.solutions.is_some() { + debug!("Applying calibration solutions to input data from timestep {timestep}"); + self.apply_solutions( + timestamp, + cross_data_fb, + cross_weights_fb.view_mut(), + autos_fb, + &self + .spw + .chanblocks + .iter() + .map(|c| c.freq) + .collect::>(), + ); + } + } + + // Should we continue? + if error.load() { + return Ok(()); + } + + Ok(()) + } + + fn read_timestep( + &self, + timestep: usize, + mut cross_data_fb: ArrayViewMut2>, + mut cross_weights_fb: ArrayViewMut2, + autos_fb: Option<(ArrayViewMut2>, ArrayViewMut2)>, + flagged_channels: &HashSet, + ) -> Result<(), VisReadError> { + let obs_context = self.get_obs_context(); + + match (autos_fb, obs_context.autocorrelations_present) { + (Some((mut auto_data_fb, mut auto_weights_fb)), true) => { + debug!("Reading crosses and autos for timestep {timestep}"); + + self.vis_reader.read_crosses_and_autos( + cross_data_fb.view_mut(), + cross_weights_fb.view_mut(), + auto_data_fb.view_mut(), + auto_weights_fb.view_mut(), + timestep, + &self.tile_baseline_flags, + flagged_channels, + )?; + + if self.ignore_weights { + cross_weights_fb.fill(1.0); + auto_weights_fb.fill(1.0); + } + } + + // Otherwise, just read the crosses. + _ => { + debug!("Reading crosses for timestep {timestep}"); + + self.vis_reader.read_crosses( + cross_data_fb.view_mut(), + cross_weights_fb.view_mut(), + timestep, + &self.tile_baseline_flags, + flagged_channels, + )?; + + if self.ignore_weights { + cross_weights_fb.fill(1.0); + } + } + } + + Ok(()) + } + + fn apply_solutions( + &self, + timestamp: Epoch, + mut cross_data_fb: ArrayViewMut2>, + mut cross_weights_fb: ArrayViewMut2, + mut autos_fb: Option<(ArrayViewMut2>, ArrayViewMut2)>, + chan_freqs: &[f64], + ) { + assert_eq!(cross_data_fb.dim(), cross_weights_fb.dim()); + assert_eq!(cross_data_fb.len_of(Axis(0)), chan_freqs.len()); + let solutions = match self.solutions.as_ref() { + Some(s) => s, + None => return, + }; + let obs_context = self.get_obs_context(); + let solution_freqs = solutions.chanblock_freqs.as_ref(); + // If there aren't any solution frequencies, we can only apply solutions + // to equally-sized arrays (i.e. if the incoming data and the solutions + // have 768 channels, then we're OK, otherwise we don't know how to map + // the solutions). Note that in this scenario, this assumes that the + // frequencies corresponding to the solutions are the same as what's in + // the data, but there's no way of checking. + if solution_freqs.is_none() + && cross_data_fb.len_of(Axis(0)) + self.spw.flagged_chanblock_indices.len() + != solutions.di_jones.len_of(Axis(2)) + { + panic!("Cannot apply calibration solutions to unequal sized data"); + } + + let timestamps = &obs_context.timestamps; + let span = *timestamps.last() - *timestamps.first(); + let timestamp_fraction = ((timestamp - *timestamps.first()).to_seconds() + / span.to_seconds()) + // Stop stupid values. + .clamp(0.0, 0.99); + + // Find solutions corresponding to this timestamp. + let sols = solutions.get_timeblock(timestamp, timestamp_fraction); + // Now make a lookup vector for the channels. This is better than + // searching for the right solution channel for each channel below (we + // use more memory but avoid a quadratic-complexity algorithm). + let solution_freq_indices: Option> = + solution_freqs.as_ref().map(|solution_freqs| { + chan_freqs + .iter() + .map(|freq| { + // Find the nearest solution freq to our data freq. 
+ let mut best = f64::INFINITY; + let mut i_sol_freq = 0; + for (i, &sol_freq) in solution_freqs.iter().enumerate() { + let this_diff = (sol_freq - freq).abs(); + if this_diff < best { + best = this_diff; + i_sol_freq = i; + } else { + // Because the frequencies are always + // ascendingly sorted, if the frequency + // difference is getting bigger, we can break + // early. + break; + } + } + i_sol_freq + }) + .collect() + }); + + for (i_baseline, (mut cross_data_f, mut cross_weights_f)) in cross_data_fb + .axis_iter_mut(Axis(1)) + .zip_eq(cross_weights_fb.axis_iter_mut(Axis(1))) + .enumerate() + { + let (tile1, tile2) = self.tile_baseline_flags.unflagged_cross_baseline_to_tile_map + .get(&i_baseline) + .copied() + .unwrap_or_else(|| { + panic!("Couldn't find baseline index {i_baseline} in unflagged_cross_baseline_to_tile_map") + }); + + if let Some(solution_freq_indices) = solution_freq_indices.as_ref() { + cross_data_f + .iter_mut() + .zip_eq(cross_weights_f.iter_mut()) + .zip_eq(solution_freq_indices.iter().copied()) + .for_each(|((vis_data, vis_weight), i_sol_freq)| { + // Get the solutions for both tiles and apply them. + let sol1 = sols[(tile1, i_sol_freq)]; + let sol2 = sols[(tile2, i_sol_freq)]; + + // One of the tiles doesn't have a solution; flag. + if sol1.any_nan() || sol2.any_nan() { + *vis_weight = -vis_weight.abs(); + *vis_data = Jones::default(); + } else { + // Promote the data before demoting it again. + let d: Jones = Jones::from(*vis_data); + *vis_data = Jones::from((sol1 * d) * sol2.h()); + } + }); + } else { + // Get the solutions for both tiles and apply them. + let sols_tile1 = sols.slice(s![tile1, ..]); + let sols_tile2 = sols.slice(s![tile2, ..]); + izip!( + (0..), + cross_data_f.iter_mut(), + cross_weights_f.iter_mut(), + sols_tile1.iter(), + sols_tile2.iter() + ) + .for_each(|(i_chan, vis_data, vis_weight, sol1, sol2)| { + // One of the tiles doesn't have a solution; flag. + if sol1.any_nan() || sol2.any_nan() { + *vis_weight = -vis_weight.abs(); + *vis_data = Jones::default(); + } else { + if self.spw.flagged_chan_indices.contains(&i_chan) { + // The channel is flagged, but we still have a solution for it. + *vis_weight = -vis_weight.abs(); + } + // Promote the data before demoting it again. + let d: Jones = Jones::from(*vis_data); + *vis_data = Jones::from((*sol1 * d) * sol2.h()); + } + }); + } + } + + if let Some((auto_data_fb, auto_weights_fb)) = autos_fb.as_mut() { + for (i_tile, (mut auto_data_f, mut auto_weights_f)) in auto_data_fb + .axis_iter_mut(Axis(1)) + .zip_eq(auto_weights_fb.axis_iter_mut(Axis(1))) + .enumerate() + { + let i_tile = self + .tile_baseline_flags + .unflagged_auto_index_to_tile_map + .get(&i_tile) + .copied() + .unwrap_or_else(|| { + panic!( + "Couldn't find auto index {i_tile} in unflagged_auto_index_to_tile_map" + ) + }); + + if let Some(solution_freq_indices) = solution_freq_indices.as_ref() { + auto_data_f + .iter_mut() + .zip_eq(auto_weights_f.iter_mut()) + .zip_eq(solution_freq_indices.iter().copied()) + .for_each(|((vis_data, vis_weight), i_sol_freq)| { + // Get the solutions for the tile and apply it twice. + let sol = sols[(i_tile, i_sol_freq)]; + + // No solution; flag. + if sol.any_nan() { + *vis_weight = -vis_weight.abs(); + *vis_data = Jones::default(); + } else { + // Promote the data before demoting it again. + let d: Jones = Jones::from(*vis_data); + *vis_data = Jones::from((sol * d) * sol.h()); + } + }); + } else { + // Get the solutions for the tile and apply it twice. 
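The solution-frequency lookup above leans on the chanblock frequencies being sorted in ascending order: once the absolute difference starts growing, no later entry can be closer, so the scan stops early. The same logic as a free function:

```rust
/// Index of the entry in `sorted_freqs` (ascending) nearest to `freq`.
fn nearest_index(sorted_freqs: &[f64], freq: f64) -> usize {
    let mut best_diff = f64::INFINITY;
    let mut best_index = 0;
    for (i, &sol_freq) in sorted_freqs.iter().enumerate() {
        let diff = (sol_freq - freq).abs();
        if diff < best_diff {
            best_diff = diff;
            best_index = i;
        } else {
            // The slice is sorted, so differences only grow from here.
            break;
        }
    }
    best_index
}
```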
+ let sols = sols.slice(s![i_tile, ..]); + izip!( + (0..), + auto_data_f.iter_mut(), + auto_weights_f.iter_mut(), + sols.iter() + ) + .for_each(|(i_chan, vis_data, vis_weight, sol)| { + // No solution; flag. + if sol.any_nan() { + *vis_weight = -vis_weight.abs(); + *vis_data = Jones::default(); + } else { + if self.spw.flagged_chan_indices.contains(&i_chan) { + // The channel is flagged, but we still have a solution for it. + *vis_weight = -vis_weight.abs(); + } + // Promote the data before demoting it again. + let d: Jones = Jones::from(*vis_data); + *vis_data = Jones::from((*sol * d) * sol.h()); + } + }); + } + } + } + + debug!("Finished applying solutions"); + } +} diff --git a/src/params/mod.rs b/src/params/mod.rs new file mode 100644 index 00000000..6ec39908 --- /dev/null +++ b/src/params/mod.rs @@ -0,0 +1,52 @@ +// This Source Code Form is subject to the terms of the Mozilla Public +// License, v. 2.0. If a copy of the MPL was not distributed with this +// file, You can obtain one at http://mozilla.org/MPL/2.0/. + +//! Parameters that are kept modular to be used in multiple aspects of +//! `hyperdrive`. +//! +//! The code here is kind of "mirroring" the code within the `cli` module; the +//! idea is that `cli` is unparsed, user-facing code, whereas parameters have +//! been parsed and are ready to be used directly. The code here should be +//! public to the entire `hyperdrive` crate. + +mod di_calibration; +mod input_vis; +mod solutions_apply; +mod vis_convert; +mod vis_simulate; +mod vis_subtract; + +#[cfg(test)] +pub(crate) use di_calibration::CalVis; +pub(crate) use di_calibration::{DiCalParams, DiCalibrateError}; +pub(crate) use input_vis::InputVisParams; +pub(crate) use solutions_apply::SolutionsApplyParams; +pub(crate) use vis_convert::{VisConvertError, VisConvertParams}; +pub(crate) use vis_simulate::{VisSimulateError, VisSimulateParams}; +pub(crate) use vis_subtract::{VisSubtractError, VisSubtractParams}; + +use std::{num::NonZeroUsize, path::PathBuf}; + +use vec1::Vec1; + +use crate::{averaging::Timeblock, io::write::VisOutputType}; + +pub(crate) struct OutputVisParams { + pub(crate) output_files: Vec1<(PathBuf, VisOutputType)>, + pub(crate) output_time_average_factor: NonZeroUsize, + pub(crate) output_freq_average_factor: NonZeroUsize, + pub(crate) output_timeblocks: Vec1, + + /// Rather than writing out the entire input bandwidth, write out only the + /// smallest contiguous band. e.g. Typical 40 kHz MWA data has 768 channels, + /// but the first 2 and last 2 channels are usually flagged. Turning this + /// option on means that 764 channels would be written out instead of 768. + /// Note that other flagged channels in the band are unaffected, because the + /// data written out must be contiguous. + pub(crate) write_smallest_contiguous_band: bool, +} + +pub(crate) struct ModellingParams { + pub(crate) apply_precession: bool, +} diff --git a/src/params/solutions_apply.rs b/src/params/solutions_apply.rs new file mode 100644 index 00000000..b8ad9e85 --- /dev/null +++ b/src/params/solutions_apply.rs @@ -0,0 +1,26 @@ +// This Source Code Form is subject to the terms of the Mozilla Public +// License, v. 2.0. If a copy of the MPL was not distributed with this +// file, You can obtain one at http://mozilla.org/MPL/2.0/. 
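`write_smallest_contiguous_band` is described above in terms of typical 40 kHz MWA data: 768 channels with the first and last two flagged become 764 written channels. The selection amounts to trimming flagged channels from both edges while keeping interior flags; roughly:

```rust
use std::collections::HashSet;
use std::ops::Range;

/// The smallest contiguous channel range containing every unflagged channel.
/// Flagged channels inside the range are retained, because the written band
/// must stay contiguous. Returns `None` if everything is flagged.
fn smallest_contiguous_band(num_chans: usize, flagged: &HashSet<usize>) -> Option<Range<usize>> {
    let first = (0..num_chans).find(|i| !flagged.contains(i))?;
    let last = (0..num_chans).rev().find(|i| !flagged.contains(i))?;
    Some(first..last + 1)
}
```

With `num_chans = 768` and flags `{0, 1, 766, 767}`, this yields `2..766`, i.e. 764 channels.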
+ +use super::{InputVisParams, OutputVisParams, VisConvertError, VisConvertParams}; + +pub(crate) struct SolutionsApplyParams { + pub(crate) input_vis_params: InputVisParams, + pub(crate) output_vis_params: OutputVisParams, +} + +impl SolutionsApplyParams { + pub(crate) fn run(&self) -> Result<(), VisConvertError> { + let Self { + input_vis_params, + output_vis_params, + } = self; + + assert!( + input_vis_params.solutions.is_some(), + "No calibration solutions are in the input vis params; this shouldn't be possible" + ); + + VisConvertParams::run_inner(input_vis_params, output_vis_params) + } +} diff --git a/src/params/vis_convert.rs b/src/params/vis_convert.rs new file mode 100644 index 00000000..947c35de --- /dev/null +++ b/src/params/vis_convert.rs @@ -0,0 +1,238 @@ +// This Source Code Form is subject to the terms of the Mozilla Public +// License, v. 2.0. If a copy of the MPL was not distributed with this +// file, You can obtain one at http://mozilla.org/MPL/2.0/. + +use std::thread::{self, ScopedJoinHandle}; + +use crossbeam_channel::bounded; +use crossbeam_utils::atomic::AtomicCell; +use indicatif::{MultiProgress, ProgressBar, ProgressDrawTarget, ProgressStyle}; +use itertools::Itertools; +use log::{debug, info}; +use ndarray::prelude::*; +use scopeguard::defer_on_unwind; + +use super::{InputVisParams, OutputVisParams}; +use crate::{ + io::{ + read::VisReadError, + write::{write_vis, VisTimestep}, + }, + PROGRESS_BARS, +}; + +pub(crate) struct VisConvertParams { + pub(crate) input_vis_params: InputVisParams, + pub(crate) output_vis_params: OutputVisParams, +} + +impl VisConvertParams { + pub(crate) fn run(&self) -> Result<(), VisConvertError> { + let Self { + input_vis_params, + output_vis_params, + } = self; + + Self::run_inner(input_vis_params, output_vis_params) + } + + // This function does the actual work, and only exists because + // `SolutionsApplyParams` is doing the exact same thing, but I can't work + // out how to make a `&VisConvertParams` from `&InputVisParams` and + // `&OutputVisParams` (if it's possible). + pub(super) fn run_inner( + input_vis_params: &InputVisParams, + output_vis_params: &OutputVisParams, + ) -> Result<(), VisConvertError> { + let obs_context = input_vis_params.get_obs_context(); + + // Channel for transferring visibilities from the reader to the writer. + let (tx_data, rx_data) = bounded(3); + + // Progress bars. + let multi_progress = MultiProgress::with_draw_target(if PROGRESS_BARS.load() { + ProgressDrawTarget::stdout() + } else { + ProgressDrawTarget::hidden() + }); + let pb = ProgressBar::new(input_vis_params.timeblocks.len() as _) + .with_style( + ProgressStyle::default_bar() + .template("{msg:18}: [{wide_bar:.blue}] {pos:2}/{len:2} timeblocks ({elapsed_precise}<{eta_precise})").unwrap() + .progress_chars("=> "), + ) + .with_position(0) + .with_message("Reading data"); + let read_progress = multi_progress.add(pb); + let pb = ProgressBar::new(output_vis_params.output_timeblocks.len() as _) + .with_style( + ProgressStyle::default_bar() + .template("{msg:18}: [{wide_bar:.blue}] {pos:2}/{len:2} timeblocks ({elapsed_precise}<{eta_precise})").unwrap() + .progress_chars("=> "), + ) + .with_position(0) + .with_message("Writing data"); + let write_progress = multi_progress.add(pb); + + // Use a variable to track whether any threads have an issue. + let error = AtomicCell::new(false); + + info!("Reading input data and writing"); + let scoped_threads_result: Result = thread::scope(|scope| { + // Input visibility-data reading thread. 
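`run_inner` wires the reader and writer together with a `crossbeam_channel::bounded(3)` queue inside `thread::scope`, so at most a few timeblocks of visibilities are ever in flight and a slow writer naturally applies backpressure to the reader. A pared-down sketch of that shape, assuming the crossbeam-channel crate and with plain `u32` payloads standing in for `VisTimestep`s:

```rust
use std::thread;

use crossbeam_channel::bounded;

fn pipeline() {
    // Capacity 3: the reader blocks on `send` once the writer falls
    // three items behind, bounding peak memory use.
    let (tx, rx) = bounded::<u32>(3);

    thread::scope(|scope| {
        let reader = scope.spawn(move || {
            for timeblock in 0..10 {
                // In hyperdrive this would be a freshly read timeblock of
                // visibilities; a send error means the writer has hung up.
                if tx.send(timeblock).is_err() {
                    return;
                }
            }
            // Dropping `tx` here closes the channel and ends the writer loop.
        });
        let writer = scope.spawn(move || {
            for timeblock in rx {
                println!("writing timeblock {timeblock}");
            }
        });
        reader.join().unwrap();
        writer.join().unwrap();
    });
}
```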
+ let data_handle: ScopedJoinHandle> = thread::Builder::new() + .name("read".to_string()) + .spawn_scoped(scope, || { + // If a panic happens, update our atomic error. + defer_on_unwind! { error.store(true); } + read_progress.tick(); + + let num_unflagged_tiles = input_vis_params.get_num_unflagged_tiles(); + let num_unflagged_cross_baselines = + (num_unflagged_tiles * (num_unflagged_tiles - 1)) / 2; + let cross_vis_shape = ( + input_vis_params.spw.chanblocks.len(), + num_unflagged_cross_baselines, + ); + let auto_vis_shape = + (input_vis_params.spw.chanblocks.len(), num_unflagged_tiles); + + for timeblock in &input_vis_params.timeblocks { + let mut cross_data_fb = Array2::zeros(cross_vis_shape); + let mut cross_weights_fb = Array2::zeros(cross_vis_shape); + let mut autos_fb = if input_vis_params.using_autos { + Some((Array2::zeros(auto_vis_shape), Array2::zeros(auto_vis_shape))) + } else { + None + }; + + let result = input_vis_params.read_timeblock( + timeblock, + cross_data_fb.view_mut(), + cross_weights_fb.view_mut(), + autos_fb.as_mut().map(|(d, w)| (d.view_mut(), w.view_mut())), + &error, + ); + // If the result of reading data was an error, allow the + // other threads to see this so they can abandon their work + // early. + if result.is_err() { + error.store(true); + } + result?; + + // Send the data as timesteps. + match tx_data.send(VisTimestep { + cross_data_fb: cross_data_fb.into_shared(), + cross_weights_fb: cross_weights_fb.into_shared(), + autos: autos_fb.map(|(d, w)| (d.into_shared(), w.into_shared())), + timestamp: timeblock.median, + }) { + Ok(()) => (), + // If we can't send the message, it's because the + // channel has been closed on the other side. That + // should only happen because the writer has exited due + // to error; in that case, just exit this thread. + Err(_) => return Ok(()), + } + + read_progress.inc(1); + } + + drop(tx_data); + debug!("Finished reading"); + read_progress.abandon_with_message("Finished reading visibilities"); + Ok(()) + }) + .expect("OS can create threads"); + + // Calibrated vis writing thread. + let write_handle = thread::Builder::new() + .name("write".to_string()) + .spawn_scoped(scope, || { + defer_on_unwind! { error.store(true); } + write_progress.tick(); + + // If we're not using autos, "disable" the + // `unflagged_tiles_iter` by making it not iterate over + // anything. + let total_num_tiles = if input_vis_params.using_autos { + obs_context.get_total_num_tiles() + } else { + 0 + }; + let unflagged_tiles_iter = (0..total_num_tiles) + .filter(|i_tile| { + !input_vis_params + .tile_baseline_flags + .flagged_tiles + .contains(i_tile) + }) + .map(|i_tile| (i_tile, i_tile)); + // Form (sorted) unflagged baselines from our cross- and + // auto-correlation baselines. 
+ let unflagged_cross_and_auto_baseline_tile_pairs = input_vis_params + .tile_baseline_flags + .tile_to_unflagged_cross_baseline_map + .keys() + .copied() + .chain(unflagged_tiles_iter) + .sorted() + .collect::>(); + + let result = write_vis( + &output_vis_params.output_files, + obs_context.array_position, + obs_context.phase_centre, + obs_context.pointing_centre, + &obs_context.tile_xyzs, + &obs_context.tile_names, + obs_context.obsid, + &output_vis_params.output_timeblocks, + input_vis_params.time_res, + input_vis_params.dut1, + &input_vis_params.spw, + &unflagged_cross_and_auto_baseline_tile_pairs, + output_vis_params.output_time_average_factor, + output_vis_params.output_freq_average_factor, + input_vis_params.vis_reader.get_marlu_mwa_info().as_ref(), + output_vis_params.write_smallest_contiguous_band, + rx_data, + &error, + Some(write_progress), + ); + if result.is_err() { + error.store(true); + } + result + }) + .expect("OS can create threads"); + + // Join all thread handles. This propagates any errors and lets us + // know if any threads panicked, if panics aren't aborting as per + // the Cargo.toml. (It would be nice to capture the panic + // information, if it's possible, but I don't know how, so panics + // are currently aborting.) + data_handle.join().unwrap()?; + let write_message = write_handle.join().unwrap()?; + Ok(write_message) + }); + + // Propagate errors and print out the write message. + info!("{}", scoped_threads_result?); + + Ok(()) + } +} + +#[derive(thiserror::Error, Debug)] +pub(crate) enum VisConvertError { + #[error(transparent)] + VisRead(#[from] crate::io::read::VisReadError), + + #[error(transparent)] + VisWrite(#[from] crate::io::write::VisWriteError), + + #[error(transparent)] + IO(#[from] std::io::Error), +} diff --git a/src/params/vis_simulate.rs b/src/params/vis_simulate.rs new file mode 100644 index 00000000..c5077c03 --- /dev/null +++ b/src/params/vis_simulate.rs @@ -0,0 +1,328 @@ +// This Source Code Form is subject to the terms of the Mozilla Public +// License, v. 2.0. If a copy of the MPL was not distributed with this +// file, You can obtain one at http://mozilla.org/MPL/2.0/. + +//! Generate sky-model visibilities from a sky-model source list. + +use std::{ + collections::HashSet, + num::NonZeroUsize, + thread::{self, ScopedJoinHandle}, +}; + +use crossbeam_channel::{bounded, Sender}; +use crossbeam_utils::atomic::AtomicCell; +use hifitime::{Duration, Epoch}; +use indicatif::{MultiProgress, ProgressBar, ProgressDrawTarget, ProgressStyle}; +use itertools::Itertools; +use log::info; +use marlu::{ + constants::{FREQ_WEIGHT_FACTOR, TIME_WEIGHT_FACTOR}, + Jones, LatLngHeight, MwaObsContext, RADec, XyzGeodetic, +}; +use mwalib::MetafitsContext; +use ndarray::ArcArray2; +use scopeguard::defer_on_unwind; +use thiserror::Error; +use vec1::Vec1; + +use crate::{ + averaging::channels_to_chanblocks, + beam::Beam, + context::Polarisations, + io::write::{write_vis, VisTimestep, VisWriteError}, + math::TileBaselineFlags, + model::{self, ModelError}, + params::{ModellingParams, OutputVisParams}, + srclist::SourceList, + PROGRESS_BARS, +}; + +/// Parameters needed to do sky-model visibility simulation. +pub(crate) struct VisSimulateParams { + /// Sky-model source list. + pub(crate) source_list: SourceList, + + /// mwalib metafits context + pub(crate) metafits: MetafitsContext, + + /// The output visibility files. + pub(crate) output_vis_params: OutputVisParams, + + /// The phase centre. 
+ pub(crate) phase_centre: RADec, + + /// The fine channel frequencies \[Hz\]. + pub(crate) fine_chan_freqs: Vec1, + + /// The frequency resolution of the fine channels. + pub(crate) freq_res_hz: f64, + + /// The [`XyzGeodetic`] positions of the tiles. + pub(crate) tile_xyzs: Vec, + + /// The names of the tiles. + pub(crate) tile_names: Vec, + + /// Information on flagged tiles, baselines and mapping between indices. + pub(crate) tile_baseline_flags: TileBaselineFlags, + + /// Timestamps to be simulated. + pub(crate) timestamps: Vec1, + + pub(crate) time_res: Duration, + + /// Interface to beam code. + pub(crate) beam: Box, + + /// The Earth position of the interferometer. + pub(crate) array_position: LatLngHeight, + + /// UT1 - UTC. + pub(crate) dut1: Duration, + + /// Should we be precessing? + pub(crate) modelling_params: ModellingParams, +} + +impl VisSimulateParams { + pub(crate) fn run(&self) -> Result<(), VisSimulateError> { + let VisSimulateParams { + source_list, + metafits, + output_vis_params: + OutputVisParams { + output_files, + output_time_average_factor, + output_freq_average_factor, + output_timeblocks, + write_smallest_contiguous_band, + }, + phase_centre, + fine_chan_freqs, + freq_res_hz, + tile_xyzs, + tile_names, + tile_baseline_flags, + timestamps, + time_res, + beam, + array_position, + dut1, + modelling_params: ModellingParams { apply_precession }, + } = self; + + // Channel for writing simulated visibilities. + let (tx_model, rx_model) = bounded(5); + + // Progress bar. + let multi_progress = MultiProgress::with_draw_target(if PROGRESS_BARS.load() { + ProgressDrawTarget::stdout() + } else { + ProgressDrawTarget::hidden() + }); + let model_progress = multi_progress.add( + ProgressBar::new(timestamps.len() as _) + .with_style( + ProgressStyle::default_bar() + .template("{msg:17}: [{wide_bar:.blue}] {pos:2}/{len:2} timesteps ({elapsed_precise}<{eta_precise})").unwrap() + .progress_chars("=> "), + ) + .with_position(0) + .with_message("Sky modelling"), + ); + let write_progress = multi_progress.add( + ProgressBar::new(output_timeblocks.len() as _) + .with_style( + ProgressStyle::default_bar() + .template("{msg:17}: [{wide_bar:.blue}] {pos:2}/{len:2} timeblocks ({elapsed_precise}<{eta_precise})").unwrap() + .progress_chars("=> "), + ) + .with_position(0) + .with_message("Model writing"), + ); + + // Generate the visibilities and write them out asynchronously. + let error = AtomicCell::new(false); + let scoped_threads_result: Result = thread::scope(|scope| { + // Modelling thread. + let model_handle: ScopedJoinHandle> = thread::Builder::new() + .name("model".to_string()) + .spawn_scoped(scope, || { + defer_on_unwind! { error.store(true); } + model_progress.tick(); + + let cross_vis_shape = ( + fine_chan_freqs.len(), + tile_baseline_flags + .unflagged_cross_baseline_to_tile_map + .len(), + ); + let weight_factor = (freq_res_hz / FREQ_WEIGHT_FACTOR) + * (time_res.to_seconds() / TIME_WEIGHT_FACTOR); + let result = model_thread( + &**beam, + source_list, + tile_xyzs, + tile_baseline_flags, + timestamps, + fine_chan_freqs, + *phase_centre, + *array_position, + *dut1, + *apply_precession, + cross_vis_shape, + weight_factor, + tx_model, + &error, + model_progress, + ); + if result.is_err() { + error.store(true); + } + result + }) + .expect("OS can create threads"); + + // Writing thread. + let write_handle: ScopedJoinHandle> = + thread::Builder::new() + .name("write".to_string()) + .spawn_scoped(scope, || { + defer_on_unwind! 
{ error.store(true); } + write_progress.tick(); + + // Form (sorted) unflagged baselines from our cross- and + // auto-correlation baselines. + let unflagged_baseline_tile_pairs = tile_baseline_flags + .unflagged_cross_baseline_to_tile_map + .values() + .copied() + .sorted() + .collect::>(); + + let spw = &channels_to_chanblocks( + &fine_chan_freqs.mapped_ref(|f| *f as u64), + freq_res_hz.round() as u64, + NonZeroUsize::new(1).unwrap(), + &HashSet::new(), + )[0]; + let result = write_vis( + output_files, + *array_position, + *phase_centre, + None, + tile_xyzs, + tile_names, + Some(metafits.obs_id), + output_timeblocks, + *time_res, + *dut1, + spw, + &unflagged_baseline_tile_pairs, + *output_time_average_factor, + *output_freq_average_factor, + Some(&MwaObsContext::from_mwalib(metafits)), + *write_smallest_contiguous_band, + rx_model, + &error, + Some(write_progress), + ); + if result.is_err() { + error.store(true); + } + result + }) + .expect("OS can create threads"); + + // Join all thread handles. This propagates any errors and lets us + // know if any threads panicked, if panics aren't aborting as per + // the Cargo.toml. (It would be nice to capture the panic + // information, if it's possible, but I don't know how, so panics + // are currently aborting.) + model_handle.join().unwrap()?; + let write_message = write_handle.join().unwrap()?; + Ok(write_message) + }); + + // Propagate errors and print out the write message. + info!("{}", scoped_threads_result?); + + Ok(()) + } +} + +#[allow(clippy::too_many_arguments)] +fn model_thread( + beam: &dyn Beam, + source_list: &SourceList, + unflagged_tile_xyzs: &[XyzGeodetic], + tile_baseline_flags: &TileBaselineFlags, + timestamps: &[Epoch], + fine_chan_freqs: &[f64], + phase_centre: RADec, + array_position: LatLngHeight, + dut1: Duration, + apply_precession: bool, + vis_shape: (usize, usize), + weight_factor: f64, + tx: Sender, + error: &AtomicCell, + progress_bar: ProgressBar, +) -> Result<(), ModelError> { + let modeller = model::new_sky_modeller( + beam, + source_list, + Polarisations::XX_XY_YX_YY, + unflagged_tile_xyzs, + fine_chan_freqs, + &tile_baseline_flags.flagged_tiles, + phase_centre, + array_position.longitude_rad, + array_position.latitude_rad, + dut1, + apply_precession, + )?; + + for ×tamp in timestamps { + let mut cross_data_fb: ArcArray2> = ArcArray2::zeros(vis_shape); + + modeller.model_timestep_with(timestamp, cross_data_fb.view_mut())?; + + // Should we continue? + if error.load() { + return Ok(()); + } + + match tx.send(VisTimestep { + cross_data_fb, + cross_weights_fb: ArcArray2::from_elem(vis_shape, weight_factor as f32), + autos: None, + timestamp, + }) { + Ok(()) => (), + // If we can't send the message, it's because the channel + // has been closed on the other side. That should only + // happen because the writer has exited due to error; in + // that case, just exit this thread. 
+ Err(_) => return Ok(()), + } + + progress_bar.inc(1); + } + + progress_bar.abandon_with_message("Finished generating sky model"); + Ok(()) +} + +#[derive(Error, Debug)] +pub(crate) enum VisSimulateError { + #[error(transparent)] + VisWrite(#[from] crate::io::write::VisWriteError), + + #[error(transparent)] + Model(#[from] crate::model::ModelError), + + #[error(transparent)] + IO(#[from] std::io::Error), +} diff --git a/src/params/vis_subtract.rs b/src/params/vis_subtract.rs new file mode 100644 index 00000000..56d4be13 --- /dev/null +++ b/src/params/vis_subtract.rs @@ -0,0 +1,351 @@ +// This Source Code Form is subject to the terms of the Mozilla Public +// License, v. 2.0. If a copy of the MPL was not distributed with this +// file, You can obtain one at http://mozilla.org/MPL/2.0/. + +//! Given input data, a sky model and specific sources, subtract those specific +//! sources from the input data and write them out. + +use std::thread::{self, ScopedJoinHandle}; + +use crossbeam_channel::{bounded, Receiver, Sender}; +use crossbeam_utils::atomic::AtomicCell; +use indicatif::{MultiProgress, ProgressBar, ProgressDrawTarget, ProgressStyle}; +use itertools::Itertools; +use log::{debug, info}; +use marlu::Jones; +use ndarray::prelude::*; +use scopeguard::defer_on_unwind; + +use super::{InputVisParams, ModellingParams, OutputVisParams}; +use crate::{ + beam::Beam, + io::{ + read::VisReadError, + write::{write_vis, VisTimestep}, + }, + model::{new_sky_modeller, ModelError}, + srclist::SourceList, + PROGRESS_BARS, +}; + +pub(crate) struct VisSubtractParams { + pub(crate) input_vis_params: InputVisParams, + pub(crate) output_vis_params: OutputVisParams, + pub(crate) beam: Box, + pub(crate) source_list: SourceList, + pub(crate) modelling_params: ModellingParams, +} + +impl VisSubtractParams { + pub(crate) fn run(&self) -> Result<(), VisSubtractError> { + // Expose all the struct fields to ensure they're all used. + let VisSubtractParams { + input_vis_params, + output_vis_params, + beam, + source_list, + modelling_params: ModellingParams { apply_precession }, + } = self; + + let obs_context = input_vis_params.get_obs_context(); + let num_unflagged_tiles = input_vis_params.get_num_unflagged_tiles(); + let num_unflagged_cross_baselines = (num_unflagged_tiles * (num_unflagged_tiles - 1)) / 2; + let vis_shape = ( + input_vis_params.spw.chanblocks.len(), + num_unflagged_cross_baselines, + ); + + // Channel for modelling and subtracting. + let (tx_model, rx_model) = bounded(5); + // Channel for writing subtracted visibilities. + let (tx_write, rx_write) = bounded(5); + + // Progress bars. 
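As in the other subcommands, `VisSubtractParams::run` shares an `AtomicCell<bool>` between its read, model and write threads: each thread sets it on failure (or on unwind, via `defer_on_unwind!`) and polls it between units of work so the whole pipeline winds down early. The same pattern, sketched with a plain `AtomicBool` in place of crossbeam's `AtomicCell` and a hypothetical `do_work` step:

```rust
use std::sync::atomic::{AtomicBool, Ordering};
use std::thread;

fn cooperative_pipeline() {
    let error = AtomicBool::new(false);

    thread::scope(|scope| {
        scope.spawn(|| {
            for step in 0..100 {
                if do_work(step).is_err() {
                    // Tell every other thread to wind down.
                    error.store(true, Ordering::Relaxed);
                    return;
                }
                // Stop early if some other thread hit a problem.
                if error.load(Ordering::Relaxed) {
                    return;
                }
            }
        });
        // ... further worker threads poll `error` the same way ...
    });
}

/// Placeholder for one unit of work (reading, modelling or writing).
fn do_work(_step: usize) -> Result<(), ()> {
    Ok(())
}
```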
+ let multi_progress = MultiProgress::with_draw_target(if PROGRESS_BARS.load() { + ProgressDrawTarget::stdout() + } else { + ProgressDrawTarget::hidden() + }); + let pb = ProgressBar::new(input_vis_params.timeblocks.len() as _) + .with_style( + ProgressStyle::default_bar() + .template("{msg:17}: [{wide_bar:.blue}] {pos:2}/{len:2} timesteps ({elapsed_precise}<{eta_precise})").unwrap() + .progress_chars("=> "), + ) + .with_position(0) + .with_message("Reading data"); + let read_progress = multi_progress.add(pb); + let pb = ProgressBar::new(input_vis_params.timeblocks.len() as _) + .with_style( + ProgressStyle::default_bar() + .template("{msg:17}: [{wide_bar:.blue}] {pos:2}/{len:2} timesteps ({elapsed_precise}<{eta_precise})").unwrap() + .progress_chars("=> "), + ) + .with_position(0) + .with_message("Sky modelling"); + let model_progress = multi_progress.add(pb); + let pb = ProgressBar::new(output_vis_params.output_timeblocks.len() as _) + .with_style( + ProgressStyle::default_bar() + .template("{msg:17}: [{wide_bar:.blue}] {pos:2}/{len:2} timeblocks ({elapsed_precise}<{eta_precise})").unwrap() + .progress_chars("=> "), + ) + .with_position(0) + .with_message("Subtracted writing"); + let write_progress = multi_progress.add(pb); + + // Use a variable to track whether any threads have an issue. + let error = AtomicCell::new(false); + + info!("Reading input data, sky modelling, and writing"); + let scoped_threads_result: Result = thread::scope(|scope| { + // Input visibility-data reading thread. + let data_handle: thread::ScopedJoinHandle> = + thread::Builder::new() + .name("read".to_string()) + .spawn_scoped(scope, || { + // If a panic happens, update our atomic error. + defer_on_unwind! { error.store(true); } + read_progress.tick(); + + for timeblock in &input_vis_params.timeblocks { + // Read data to fill the buffer, pausing when the buffer is + // full to write it all out. + let mut cross_data_fb = Array2::zeros(vis_shape); + let mut cross_weights_fb = Array2::zeros(vis_shape); + + let result = self.input_vis_params.read_timeblock( + timeblock, + cross_data_fb.view_mut(), + cross_weights_fb.view_mut(), + None, + &error, + ); + + // If the result of reading data was an error, allow the other + // threads to see this so they can abandon their work early. + if result.is_err() { + error.store(true); + } + result?; + + // Should we continue? + if error.load() { + return Ok(()); + } + + match tx_model.send(VisTimestep { + cross_data_fb: cross_data_fb.into_shared(), + cross_weights_fb: cross_weights_fb.into_shared(), + autos: None, + timestamp: timeblock.median, + }) { + Ok(()) => (), + // If we can't send the message, it's because the channel + // has been closed on the other side. That should only + // happen because the writer has exited due to error; in + // that case, just exit this thread. + Err(_) => return Ok(()), + } + + read_progress.inc(1); + } + + debug!("Finished reading"); + read_progress.abandon_with_message("Finished reading visibilities"); + drop(tx_model); + Ok(()) + }) + .expect("OS can create threads"); + + // Sky-model generation and subtraction thread. + let model_handle: ScopedJoinHandle> = thread::Builder::new() + .name("model".to_string()) + .spawn_scoped(scope, || { + defer_on_unwind! 
{ error.store(true); } + model_progress.tick(); + + let result = model_thread( + &**beam, + source_list, + input_vis_params, + *apply_precession, + vis_shape, + rx_model, + tx_write, + &error, + model_progress, + ); + if result.is_err() { + error.store(true); + } + result + }) + .expect("OS can create threads"); + + // Subtracted vis writing thread. + let write_handle = thread::Builder::new() + .name("write".to_string()) + .spawn_scoped(scope, || { + defer_on_unwind! { error.store(true); } + write_progress.tick(); + + let result = write_vis( + &output_vis_params.output_files, + obs_context.array_position, + obs_context.phase_centre, + obs_context.pointing_centre, + &obs_context.tile_xyzs, + &obs_context.tile_names, + obs_context.obsid, + &output_vis_params.output_timeblocks, + input_vis_params.time_res, + input_vis_params.dut1, + &input_vis_params.spw, + &input_vis_params + .tile_baseline_flags + .unflagged_cross_baseline_to_tile_map + .values() + .copied() + .sorted() + .collect::>(), + output_vis_params.output_time_average_factor, + output_vis_params.output_freq_average_factor, + input_vis_params.vis_reader.get_marlu_mwa_info().as_ref(), + output_vis_params.write_smallest_contiguous_band, + rx_write, + &error, + Some(write_progress), + ); + if result.is_err() { + error.store(true); + } + result + }) + .expect("OS can create threads"); + + // Join all thread handles. This propagates any errors and lets us know + // if any threads panicked, if panics aren't aborting as per the + // Cargo.toml. (It would be nice to capture the panic information, if + // it's possible, but I don't know how, so panics are currently + // aborting.) + data_handle.join().unwrap()?; + model_handle.join().unwrap()?; + let write_message = write_handle.join().unwrap()?; + Ok(write_message) + }); + + // Propagate errors and print out the write message. + info!("{}", scoped_threads_result?); + + Ok(()) + } +} + +#[allow(clippy::too_many_arguments)] +fn model_thread( + beam: &dyn Beam, + source_list: &SourceList, + input_vis_params: &InputVisParams, + apply_precession: bool, + vis_shape: (usize, usize), + rx: Receiver, + tx: Sender, + error: &AtomicCell, + progress_bar: ProgressBar, +) -> Result<(), ModelError> { + let obs_context = input_vis_params.get_obs_context(); + let unflagged_tile_xyzs = obs_context + .tile_xyzs + .iter() + .enumerate() + .filter(|(i, _)| { + !input_vis_params + .tile_baseline_flags + .flagged_tiles + .contains(i) + }) + .map(|(_, xyz)| *xyz) + .collect::>(); + let freqs = input_vis_params + .spw + .chanblocks + .iter() + .map(|c| c.freq) + .collect::>(); + let modeller = new_sky_modeller( + beam, + source_list, + obs_context.polarisations, + &unflagged_tile_xyzs, + &freqs, + &input_vis_params.tile_baseline_flags.flagged_tiles, + obs_context.phase_centre, + obs_context.array_position.longitude_rad, + obs_context.array_position.latitude_rad, + input_vis_params.dut1, + apply_precession, + )?; + + // Recycle an array for model visibilities. + let mut vis_model_fb = Array2::zeros(vis_shape); + + // Iterate over the incoming data. 
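The subtraction in the loop that follows promotes each visibility from `Jones<f32>` to `Jones<f64>`, subtracts the model, and demotes the result, so the arithmetic runs in double precision while storage stays single precision. Element-wise, with an eight-float array standing in for the four complex components of a Jones matrix, that is simply:

```rust
/// Subtract `model` from `data` in place, doing the arithmetic in f64 and
/// storing the result back as f32. The `[f32; 8]` is a stand-in for the
/// four complex elements of a Jones matrix.
fn subtract_in_f64(data: &mut [f32; 8], model: &[f32; 8]) {
    for (d, &m) in data.iter_mut().zip(model.iter()) {
        *d = (f64::from(*d) - f64::from(m)) as f32;
    }
}
```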
+ for VisTimestep { + mut cross_data_fb, + cross_weights_fb, + autos, + timestamp, + } in rx.iter() + { + debug!("Modelling timestamp {}", timestamp.to_gpst_seconds()); + modeller.model_timestep_with(timestamp, vis_model_fb.view_mut())?; + cross_data_fb + .iter_mut() + .zip_eq(vis_model_fb.iter()) + .for_each(|(vis_data, vis_model)| { + *vis_data = + Jones::from(Jones::::from(*vis_data) - Jones::::from(*vis_model)); + }); + vis_model_fb.fill(Jones::default()); + + // Should we continue? + if error.load() { + return Ok(()); + } + + match tx.send(VisTimestep { + cross_data_fb, + cross_weights_fb, + autos, + timestamp, + }) { + Ok(()) => (), + Err(_) => return Ok(()), + } + progress_bar.inc(1); + } + + debug!("Finished modelling"); + progress_bar.abandon_with_message("Finished generating sky model"); + Ok(()) +} + +#[derive(thiserror::Error, Debug)] +pub(crate) enum VisSubtractError { + #[error(transparent)] + VisRead(#[from] crate::io::read::VisReadError), + + #[error(transparent)] + VisWrite(#[from] crate::io::write::VisWriteError), + + #[error(transparent)] + Model(#[from] crate::model::ModelError), + + #[error(transparent)] + IO(#[from] std::io::Error), + + #[cfg(feature = "cuda")] + #[error(transparent)] + Cuda(#[from] crate::cuda::CudaError), +} diff --git a/src/solutions/ao.rs b/src/solutions/ao.rs index 84ee78cb..bb17d12a 100644 --- a/src/solutions/ao.rs +++ b/src/solutions/ao.rs @@ -111,7 +111,7 @@ pub(crate) fn read(file: &Path) -> Result ( Some(vec1![s]), Some(vec1![e]), - Some(vec1![average_epoch(&[s, e])]), + Some(vec1![average_epoch([s, e])]), ), (Some(s), None) => (Some(vec1![s]), None, None), (None, Some(e)) => (None, Some(vec1![e]), None), diff --git a/src/solutions/rts/mod.rs b/src/solutions/rts/mod.rs index 32707673..d203e488 100644 --- a/src/solutions/rts/mod.rs +++ b/src/solutions/rts/mod.rs @@ -27,7 +27,7 @@ use std::{ }; use itertools::Itertools; -use log::{debug, trace, warn}; +use log::{debug, trace}; use marlu::Jones; use mwalib::MetafitsContext; use ndarray::prelude::*; @@ -36,7 +36,7 @@ use thiserror::Error; use vec1::Vec1; use super::CalibrationSolutions; -use crate::io::get_all_matches_from_glob; +use crate::{cli::Warn, io::get_all_matches_from_glob}; lazy_static::lazy_static! { static ref NODE_NUM: Regex = Regex::new(r"node(\d{3})\.dat$").unwrap(); @@ -174,9 +174,16 @@ fn read_no_files( let available_num_coarse_chans = receiver_channel_to_data.len(); let total_num_coarse_chans = context.num_metafits_coarse_chans; if available_num_coarse_chans != total_num_coarse_chans { - warn!("The number of coarse channels expected by the metafits ({total_num_coarse_chans})"); - warn!(" wasn't equal to the number of node files ({available_num_coarse_chans})."); - warn!(" We will use NaNs for the missing coarse channels."); + [ + format!( + "The number of coarse channels expected by the metafits ({total_num_coarse_chans})" + ) + .into(), + format!("wasn't equal to the number of node files ({available_num_coarse_chans}).") + .into(), + "We will use NaNs for the missing coarse channels.".into(), + ] + .warn(); }; // Check that the number of tiles is the same everywhere. diff --git a/src/srclist/ao/read.rs b/src/srclist/ao/read.rs index ec146ae4..83687656 100644 --- a/src/srclist/ao/read.rs +++ b/src/srclist/ao/read.rs @@ -7,13 +7,15 @@ //! The code here is probably incomplete, but it should work for the majority of //! source lists. 
-use log::warn; use marlu::{sexagesimal::*, RADec}; use vec1::vec1; -use crate::srclist::{ - error::{ReadSourceListAOError, ReadSourceListCommonError, ReadSourceListError}, - ComponentType, FluxDensity, FluxDensityType, Source, SourceComponent, SourceList, +use crate::{ + cli::Warn, + srclist::{ + error::{ReadSourceListAOError, ReadSourceListCommonError, ReadSourceListError}, + ComponentType, FluxDensity, FluxDensityType, Source, SourceComponent, SourceList, + }, }; /// Parse a buffer containing an AO-style source list into a [SourceList]. @@ -57,8 +59,8 @@ pub(crate) fn parse_source_list( Some("skymodel") => { match items.next() { Some("fileformat") => (), - Some(s) => warn!("Malformed AO source list 'skymodel' line; expected 'fileformat', got '{s}'"), - None => warn!("Malformed AO source list 'skymodel' line; expected 'fileformat', got nothing"), + Some(s) => format!("Malformed AO source list 'skymodel' line; expected 'fileformat', got '{s}'").warn(), + None => "Malformed AO source list 'skymodel' line; expected 'fileformat', got nothing".warn(), } match items.next() { @@ -66,11 +68,13 @@ pub(crate) fn parse_source_list( // Stokes. Some("1.0") => one_point_oh = true, Some("1.1") => (), - Some(v) => { - warn!("Unrecognised AO source list fileformat '{v}'; pretending it is 1.1") - } + Some(v) => format!( + "Unrecognised AO source list fileformat '{v}'; pretending it is 1.1" + ) + .warn(), None => { - warn!("This AO source list does not specify a fileformat; pretending it is 1.1") + "This AO source list does not specify a fileformat; pretending it is 1.1" + .warn() } } } diff --git a/src/srclist/ao/write.rs b/src/srclist/ao/write.rs index 4590edaf..a1b28e5f 100644 --- a/src/srclist/ao/write.rs +++ b/src/srclist/ao/write.rs @@ -4,10 +4,13 @@ //! Writing "André Offringa"-style text source lists. 
-use log::{debug, warn}; +use log::debug; use marlu::sexagesimal::*; -use crate::srclist::{error::WriteSourceListError, ComponentType, FluxDensityType, SourceList}; +use crate::{ + cli::Warn, + srclist::{error::WriteSourceListError, ComponentType, FluxDensityType, SourceList}, +}; pub(crate) fn write_source_list( buf: &mut T, @@ -34,8 +37,11 @@ pub(crate) fn write_source_list( }); if any_curved_power_laws { if !warned_curved_power_laws { - warn!("AO source lists don't support curved-power-law flux densities."); - warn!("Any sources containing them won't be written."); + [ + "AO source lists don't support curved-power-law flux densities.".into(), + "Any sources containing them won't be written.".into(), + ] + .warn(); warned_curved_power_laws = true; } debug!("Ignoring source {name} as it contains a curved power law"); @@ -43,8 +49,11 @@ pub(crate) fn write_source_list( } if any_shapelets { if !warned_shapelets { - warn!("AO source lists don't support shapelet components."); - warn!("Any sources containing them won't be written."); + [ + "AO source lists don't support shapelet components.".into(), + "Any sources containing them won't be written.".into(), + ] + .warn(); warned_shapelets = true; } debug!("Ignoring source {name} as it contains a shapelet component"); @@ -137,7 +146,7 @@ pub(crate) fn write_source_list( if let Some(num_sources) = num_sources { if num_sources > num_written_sources { - warn!("Couldn't write the requested number of sources ({num_sources}): wrote {num_written_sources}") + format!("Couldn't write the requested number of sources ({num_sources}): wrote {num_written_sources}").warn() } } diff --git a/src/srclist/error.rs b/src/srclist/error.rs index f0aa7b17..5f54b24f 100644 --- a/src/srclist/error.rs +++ b/src/srclist/error.rs @@ -4,8 +4,8 @@ use thiserror::Error; -use crate::srclist::{ - HYPERDRIVE_SOURCE_LIST_FILE_TYPES_COMMA_SEPARATED, SOURCE_LIST_TYPES_COMMA_SEPARATED, +use crate::{ + beam::BeamError, io::GlobError, srclist::HYPERDRIVE_SOURCE_LIST_FILE_TYPES_COMMA_SEPARATED, }; /// Errors associated with reading in any kind of source list. @@ -40,6 +40,24 @@ pub(crate) enum ReadSourceListError { #[error("Could not deserialise the contents as yaml or json.\n\nyaml error: {yaml_err}\n\njson error: {json_err}")] FailedToDeserialise { yaml_err: String, json_err: String }, + #[error("No sky-model source list file supplied")] + NoSourceList, + + #[error(transparent)] + Glob(#[from] GlobError), + + #[error("The number of specified sources was 0, or the size of the source list was 0")] + NoSources, + + #[error("After vetoing sources, none were left. Decrease the veto threshold, or supply more sources")] + NoSourcesAfterVeto, + + #[error("Tried to use {requested} sources, but only {available} sources were available after vetoing")] + VetoTooFewSources { requested: usize, available: usize }, + + #[error("Beam error when trying to veto the source list: {0}")] + Beam(#[from] BeamError), + #[error(transparent)] Common(#[from] ReadSourceListCommonError), @@ -285,12 +303,6 @@ pub(crate) enum WriteSourceListError { fd_type: &'static str, }, - #[error("Not enough information was provided to write the output source list. Please specify an output type.")] - NotEnoughInfo, - - #[error("Unrecognised source list type. 
Supported types are: {}", *SOURCE_LIST_TYPES_COMMA_SEPARATED)] - InvalidFormat, - #[error("'{0}' is an invalid file type for a hyperdrive-style source list; must have one of the following extensions: {}", *HYPERDRIVE_SOURCE_LIST_FILE_TYPES_COMMA_SEPARATED)] InvalidHyperdriveFormat(String), @@ -310,9 +322,6 @@ pub(crate) enum WriteSourceListError { #[derive(Error, Debug)] pub(crate) enum SrclistError { - #[error("No sources were left after vetoing; nothing left to do")] - NoSourcesAfterVeto, - #[error("Source list error: Need a metafits file to perform work, but none was supplied")] MissingMetafits, @@ -322,9 +331,6 @@ pub(crate) enum SrclistError { #[error(transparent)] WriteSourceList(#[from] WriteSourceListError), - #[error(transparent)] - Veto(#[from] super::VetoError), - #[error(transparent)] Beam(#[from] crate::beam::BeamError), diff --git a/src/srclist/hyperdrive/write.rs b/src/srclist/hyperdrive/write.rs index 3790afb1..ff8d218e 100644 --- a/src/srclist/hyperdrive/write.rs +++ b/src/srclist/hyperdrive/write.rs @@ -7,9 +7,11 @@ use std::collections::HashMap; use indexmap::IndexMap; -use log::warn; -use crate::srclist::{error::WriteSourceListError, SourceList}; +use crate::{ + cli::Warn, + srclist::{error::WriteSourceListError, SourceList}, +}; /// Write a [`SourceList`] to a yaml file. pub(crate) fn source_list_to_yaml( @@ -26,10 +28,11 @@ pub(crate) fn source_list_to_yaml( } if num_sources > sl.len() { - warn!( + format!( "Couldn't write the requested number of sources ({num_sources}): wrote {}", sl.len() ) + .warn() }; } else { serde_yaml::to_writer(buf, &sl)?; @@ -51,10 +54,11 @@ pub(crate) fn source_list_to_json( serde_json::to_writer_pretty(buf, &map)?; if num_sources > sl.len() { - warn!( + format!( "Couldn't write the requested number of sources ({num_sources}): wrote {}", sl.len() ) + .warn() }; } else { serde_json::to_writer_pretty(buf, &sl)?; diff --git a/src/srclist/mod.rs b/src/srclist/mod.rs index a69711bf..1cb95503 100644 --- a/src/srclist/mod.rs +++ b/src/srclist/mod.rs @@ -11,6 +11,7 @@ pub(crate) mod read; pub(crate) mod rts; pub(crate) mod types; pub(crate) mod woden; +mod write; mod error; #[cfg(test)] @@ -20,6 +21,7 @@ mod veto; pub(crate) use error::*; pub use types::*; pub(crate) use veto::*; +pub(crate) use write::write_source_list; use itertools::Itertools; use strum::IntoEnumIterator; diff --git a/src/srclist/rts/read.rs b/src/srclist/rts/read.rs index e4701535..0979d8d9 100644 --- a/src/srclist/rts/read.rs +++ b/src/srclist/rts/read.rs @@ -7,14 +7,16 @@ //! See for more info: //! -use log::warn; use marlu::RADec; use vec1::vec1; -use crate::srclist::{ - error::{ReadSourceListCommonError, ReadSourceListError, ReadSourceListRtsError}, - ComponentType, FluxDensity, FluxDensityType, ShapeletCoeff, Source, SourceComponent, - SourceList, +use crate::{ + cli::Warn, + srclist::{ + error::{ReadSourceListCommonError, ReadSourceListError, ReadSourceListRtsError}, + ComponentType, FluxDensity, FluxDensityType, ShapeletCoeff, Source, SourceComponent, + SourceList, + }, }; /// Parse a buffer containing an RTS-style source list into a `SourceList`. @@ -62,10 +64,11 @@ pub(crate) fn parse_source_list( } // We ignore any lines starting with whitespace, but emit a warning. 
else if line.starts_with(' ') | line.starts_with('\t') { - warn!( + format!( "Source list line {} starts with whitespace; ignoring it", line_num - ); + ) + .warn(); line.clear(); continue; } @@ -104,10 +107,11 @@ pub(crate) fn parse_source_list( } }; if items.next().is_some() { - warn!( + format!( "Source list line {}: Ignoring trailing contents after declination", line_num - ); + ) + .warn(); } // Validation and conversion. @@ -178,10 +182,11 @@ pub(crate) fn parse_source_list( } }; if items.next().is_some() { - warn!( + format!( "Source list line {}: Ignoring trailing contents after Stokes V", line_num - ); + ) + .warn(); } if stokes_i.is_nan() || stokes_q.is_nan() || stokes_u.is_nan() || stokes_v.is_nan() @@ -237,10 +242,11 @@ pub(crate) fn parse_source_list( } }; if items.next().is_some() { - warn!( + format!( "Source list line {}: Ignoring trailing contents after minor axis", line_num - ); + ) + .warn(); } // Ensure the position angle is positive. @@ -293,10 +299,11 @@ pub(crate) fn parse_source_list( } }; if items.next().is_some() { - warn!( + format!( "Source list line {}: Ignoring trailing contents after minor axis", line_num - ); + ) + .warn(); } // Ensure the position angle is positive. @@ -340,7 +347,7 @@ pub(crate) fn parse_source_list( }; components.iter_mut().last().unwrap().comp_type = comp_type; - warn!("Source list line {}: Ignoring SHAPELET component", line_num); + format!("Source list line {}: Ignoring SHAPELET component", line_num).warn(); component_type_set = true; in_shapelet = true; } @@ -380,10 +387,11 @@ pub(crate) fn parse_source_list( } }; if items.next().is_some() { - warn!( + format!( "Source list line {}: Ignoring trailing contents after minor axis", line_num - ); + ) + .warn(); } // Because we ignore SHAPELET components, only add this COEFF @@ -468,10 +476,11 @@ pub(crate) fn parse_source_list( } }; if items.next().is_some() { - warn!( + format!( "Source list line {}: Ignoring trailing contents after declination", line_num - ); + ) + .warn(); } // Validation and conversion. diff --git a/src/srclist/rts/write.rs b/src/srclist/rts/write.rs index 6e30c424..283f5ee2 100644 --- a/src/srclist/rts/write.rs +++ b/src/srclist/rts/write.rs @@ -4,9 +4,12 @@ //! Writing RTS-style text source lists. -use log::{debug, warn}; +use log::debug; -use crate::srclist::{ComponentType, FluxDensityType, SourceList, WriteSourceListError}; +use crate::{ + cli::Warn, + srclist::{ComponentType, FluxDensityType, SourceList, WriteSourceListError}, +}; fn write_comp_type( buf: &mut T, @@ -101,8 +104,11 @@ pub(crate) fn write_source_list( .any(|comp| matches!(comp.flux_type, FluxDensityType::CurvedPowerLaw { .. 
})) { if !warned_curved_power_laws { - warn!("RTS source lists don't support curved-power-law flux densities."); - warn!("Any sources containing them won't be written."); + [ + "RTS source lists don't support curved-power-law flux densities.".into(), + "Any sources containing them won't be written.".into(), + ] + .warn(); warned_curved_power_laws = true; } debug!("Ignoring source {name} as it contains a curved power law"); @@ -155,7 +161,7 @@ pub(crate) fn write_source_list( if let Some(num_sources) = num_sources { if num_sources > num_written_sources { - warn!("Couldn't write the requested number of sources ({num_sources}): wrote {num_written_sources}") + format!("Couldn't write the requested number of sources ({num_sources}): wrote {num_written_sources}").warn() } } @@ -181,8 +187,11 @@ pub(crate) fn write_source_list_with_order( .any(|comp| matches!(comp.flux_type, FluxDensityType::CurvedPowerLaw { .. })) { if !warned_curved_power_laws { - warn!("RTS source lists don't support curved-power-law flux densities."); - warn!("Any sources containing them won't be written."); + [ + "RTS source lists don't support curved-power-law flux densities.".into(), + "Any sources containing them won't be written.".into(), + ] + .warn(); warned_curved_power_laws = true; } debug!("Ignoring source {name} as it contains a curved power law"); diff --git a/src/srclist/veto.rs b/src/srclist/veto.rs index 4776cc41..cabcd51c 100644 --- a/src/srclist/veto.rs +++ b/src/srclist/veto.rs @@ -13,12 +13,11 @@ use std::collections::BTreeMap; use log::{debug, log_enabled, trace, Level::Trace}; use marlu::{Jones, RADec}; use rayon::{iter::Either, prelude::*}; -use thiserror::Error; use crate::{ - beam::{Beam, BeamError}, + beam::Beam, constants::*, - srclist::{FluxDensity, SourceList}, + srclist::{FluxDensity, ReadSourceListError, SourceList}, }; /// This function mutates the input source list, removing any sources that have @@ -52,11 +51,11 @@ pub(crate) fn veto_sources( num_sources: Option, source_dist_cutoff_deg: f64, veto_threshold: f64, -) -> Result<(), VetoError> { +) -> Result<(), ReadSourceListError> { let dist_cutoff = source_dist_cutoff_deg.to_radians(); // TODO: This step is relatively expensive! - let (vetoed_sources, not_vetoed_sources): (Vec>, BTreeMap) = source_list + let (vetoed_sources, not_vetoed_sources): (Vec>, BTreeMap) = source_list .par_iter() .partition_map(|(source_name, source)| { let source_name = source_name.to_owned(); @@ -180,7 +179,7 @@ pub(crate) fn veto_sources( // out. 
if let Some(n) = num_sources { if n > source_list.len() { - return Err(VetoError::TooFewSources { + return Err(ReadSourceListError::VetoTooFewSources { requested: n, available: source_list.len(), }); @@ -208,15 +207,6 @@ fn get_beam_attenuated_flux_density(fd: &FluxDensity, j: Jones) -> f64 { // (jijh[0].norm() * jijh[3].norm()) - (jijh[1].norm() * jijh[2].norm()) } -#[derive(Error, Debug)] -pub(crate) enum VetoError { - #[error("Tried to use {requested} sources, but only {available} sources were available after vetoing")] - TooFewSources { requested: usize, available: usize }, - - #[error("Error when trying to veto the source list: {0}")] - Beam(#[from] BeamError), -} - #[cfg(test)] mod tests { use approx::assert_abs_diff_eq; diff --git a/src/srclist/woden/read.rs b/src/srclist/woden/read.rs index 478ac9ee..f1a1298f 100644 --- a/src/srclist/woden/read.rs +++ b/src/srclist/woden/read.rs @@ -26,10 +26,10 @@ use std::convert::TryInto; -use log::warn; use marlu::{constants::DH2R, RADec}; use crate::{ + cli::Warn, constants::DEFAULT_SPEC_INDEX, srclist::{ error::{ReadSourceListCommonError, ReadSourceListError, ReadSourceListWodenError}, @@ -83,10 +83,11 @@ pub(crate) fn parse_source_list( } // We ignore any lines starting with whitespace, but emit a warning. else if line.starts_with(' ') | line.starts_with('\t') { - warn!( + format!( "Source list line {} starts with whitespace; ignoring it", line_num - ); + ) + .warn(); line.clear(); continue; } @@ -192,10 +193,11 @@ pub(crate) fn parse_source_list( } }; if items.next().is_some() { - warn!( + format!( "Source list line {}: Ignoring trailing contents after declination", line_num - ); + ) + .warn(); } in_source = true; @@ -253,10 +255,11 @@ pub(crate) fn parse_source_list( } }; if items.next().is_some() { - warn!( + format!( "Source list line {}: Ignoring trailing contents after declination", line_num - ); + ) + .warn(); } // Validation and conversion. @@ -330,10 +333,11 @@ pub(crate) fn parse_source_list( } }; if items.next().is_some() { - warn!( + format!( "Source list line {}: Ignoring trailing contents after Stokes V", line_num - ); + ) + .warn(); } if stokes_i.is_nan() || stokes_q.is_nan() || stokes_u.is_nan() || stokes_v.is_nan() @@ -346,7 +350,7 @@ pub(crate) fn parse_source_list( // If the frequency is set (i.e. not 0), the ignore // additional flux density lines for this component. if fd.freq > f64::EPSILON { - warn!("Ignoring FREQ line {}", line_num); + format!("Ignoring FREQ line {}", line_num).warn(); } else { *fd = FluxDensity { freq, @@ -416,10 +420,11 @@ pub(crate) fn parse_source_list( } }; if items.next().is_some() { - warn!( + format!( "Source list line {}: Ignoring trailing contents after spectral index", line_num - ); + ) + .warn(); } if stokes_i.is_nan() @@ -436,7 +441,7 @@ pub(crate) fn parse_source_list( // If the frequency is set (i.e. not 0), the ignore // additional flux density lines for this component. if fd.freq > f64::EPSILON { - warn!("Ignoring LINEAR line {}", line_num); + format!("Ignoring LINEAR line {}", line_num).warn(); } else { *fd = FluxDensity { freq, @@ -489,10 +494,11 @@ pub(crate) fn parse_source_list( } }; if items.next().is_some() { - warn!( + format!( "Source list line {}: Ignoring trailing contents after minor axis", line_num - ); + ) + .warn(); } // Ensure the position angle is positive. 
@@ -556,10 +562,11 @@ pub(crate) fn parse_source_list( } }; if items.next().is_some() { - warn!( + format!( "Source list line {}: Ignoring trailing contents after minor axis", line_num - ); + ) + .warn(); } // Ensure the position angle is positive. @@ -624,10 +631,11 @@ pub(crate) fn parse_source_list( } }; if items.next().is_some() { - warn!( + format!( "Source list line {}: Ignoring trailing contents after minor axis", line_num - ); + ) + .warn(); } let shapelet_coeff = ShapeletCoeff { diff --git a/src/srclist/woden/write.rs b/src/srclist/woden/write.rs index c376d5f8..19c4c28a 100644 --- a/src/srclist/woden/write.rs +++ b/src/srclist/woden/write.rs @@ -8,9 +8,12 @@ //! only the first flux density in a list of flux densities will be written //! here. -use log::{debug, warn}; +use log::debug; -use crate::srclist::{error::WriteSourceListError, ComponentType, FluxDensityType, SourceList}; +use crate::{ + cli::Warn, + srclist::{error::WriteSourceListError, ComponentType, FluxDensityType, SourceList}, +}; fn write_comp_type( buf: &mut T, @@ -101,8 +104,11 @@ pub(crate) fn write_source_list( .any(|comp| matches!(comp.flux_type, FluxDensityType::CurvedPowerLaw { .. })) { if !warned_curved_power_laws { - warn!("WODEN source lists don't support curved-power-law flux densities."); - warn!("Any sources containing them won't be written."); + [ + "WODEN source lists don't support curved-power-law flux densities.".into(), + "Any sources containing them won't be written.".into(), + ] + .warn(); warned_curved_power_laws = true; } debug!("Ignoring source {name} as it contains a curved power law"); @@ -172,7 +178,7 @@ pub(crate) fn write_source_list( if let Some(num_sources) = num_sources { if num_sources > num_written_sources { - warn!("Couldn't write the requested number of sources ({num_sources}): wrote {num_written_sources}") + format!("Couldn't write the requested number of sources ({num_sources}): wrote {num_written_sources}").warn() } } diff --git a/src/srclist/write.rs b/src/srclist/write.rs new file mode 100644 index 00000000..fb0cca63 --- /dev/null +++ b/src/srclist/write.rs @@ -0,0 +1,70 @@ +// This Source Code Form is subject to the terms of the Mozilla Public +// License, v. 2.0. If a copy of the MPL was not distributed with this +// file, You can obtain one at http://mozilla.org/MPL/2.0/. + +use std::{ + fs::File, + io::{BufWriter, Write}, + path::Path, + str::FromStr, +}; + +use log::{info, trace}; + +use super::{ + ao, hyperdrive, rts, woden, HyperdriveFileType, SourceList, SourceListType, + WriteSourceListError, +}; + +pub(crate) fn write_source_list( + sl: &SourceList, + path: &Path, + input_srclist_type: SourceListType, + output_srclist_type: Option, + num_sources: Option, +) -> Result<(), WriteSourceListError> { + trace!("Attempting to write output source list"); + let mut f = BufWriter::new(File::create(path)?); + let output_ext = path.extension().and_then(|e| e.to_str()); + let hyp_file_type = output_ext.and_then(|e| HyperdriveFileType::from_str(e).ok()); + + let output_srclist_type = match (output_srclist_type, hyp_file_type) { + (Some(t), _) => t, + + (None, Some(_)) => SourceListType::Hyperdrive, + + // Use the input source list type as the output type. 
+ (None, None) => input_srclist_type, + }; + + match (output_srclist_type, hyp_file_type) { + (SourceListType::Hyperdrive, None) => { + return Err(WriteSourceListError::InvalidHyperdriveFormat( + output_ext.unwrap_or("").to_string(), + )) + } + (SourceListType::Rts, _) => { + rts::write_source_list(&mut f, sl, num_sources)?; + info!("Wrote rts-style source list to {}", path.display()); + } + (SourceListType::AO, _) => { + ao::write_source_list(&mut f, sl, num_sources)?; + info!("Wrote ao-style source list to {}", path.display()); + } + (SourceListType::Woden, _) => { + woden::write_source_list(&mut f, sl, num_sources)?; + info!("Wrote woden-style source list to {}", path.display()); + } + (_, Some(HyperdriveFileType::Yaml)) => { + hyperdrive::source_list_to_yaml(&mut f, sl, num_sources)?; + info!("Wrote hyperdrive-style source list to {}", path.display()); + } + (_, Some(HyperdriveFileType::Json)) => { + hyperdrive::source_list_to_json(&mut f, sl, num_sources)?; + info!("Wrote hyperdrive-style source list to {}", path.display()); + } + } + f.flush()?; + + Ok(()) +} diff --git a/src/tests/mod.rs b/src/tests/mod.rs index e4c97d77..6082c10a 100644 --- a/src/tests/mod.rs +++ b/src/tests/mod.rs @@ -2,9 +2,7 @@ // License, v. 2.0. If a copy of the MPL was not distributed with this // file, You can obtain one at http://mozilla.org/MPL/2.0/. -//! Integration tests and helpful functions. - -pub mod reduced_obsids; +//! Helpful functions for tests. use std::{ fs::File, @@ -23,3 +21,99 @@ pub(crate) fn deflate_gz_into_tempfile>(file: T) -> TempPath { deflate_gz_into_file(file, &mut temp_file); temp_path } + +const DATA_DIR_1090008640: &str = "test_files/1090008640"; + +pub(crate) struct DataAsStrings { + pub(crate) metafits: String, + pub(crate) vis: Vec, + pub(crate) mwafs: Vec, + pub(crate) srclist: String, +} + +pub(crate) struct DataAsPathBufs { + pub(crate) metafits: PathBuf, + pub(crate) vis: Vec, + pub(crate) mwafs: Vec, + pub(crate) srclist: PathBuf, +} + +pub(crate) fn get_reduced_1090008640_raw() -> DataAsStrings { + DataAsStrings { + metafits: format!("{DATA_DIR_1090008640}/1090008640.metafits"), + vis: vec![format!( + "{DATA_DIR_1090008640}/1090008640_20140721201027_gpubox01_00.fits" + )], + mwafs: vec![format!("{DATA_DIR_1090008640}/1090008640_01.mwaf")], + srclist: format!( + "{DATA_DIR_1090008640}/srclist_pumav3_EoR0aegean_EoR1pietro+ForA_1090008640_100.yaml" + ), + } +} + +pub(crate) fn get_reduced_1090008640_ms() -> DataAsStrings { + let mut data = get_reduced_1090008640_raw(); + data.vis[0] = format!("{DATA_DIR_1090008640}/1090008640.ms"); + data +} + +pub(crate) fn get_reduced_1090008640_uvfits() -> DataAsStrings { + let mut data = get_reduced_1090008640_raw(); + data.vis[0] = format!("{DATA_DIR_1090008640}/1090008640.uvfits"); + data +} + +pub(crate) fn get_reduced_1090008640_raw_pbs() -> DataAsPathBufs { + let DataAsStrings { + metafits, + vis, + mwafs, + srclist, + } = get_reduced_1090008640_raw(); + let pbs = DataAsPathBufs { + metafits: PathBuf::from(metafits).canonicalize().unwrap(), + vis: vis + .into_iter() + .map(|s| PathBuf::from(s).canonicalize().unwrap()) + .collect(), + mwafs: mwafs + .into_iter() + .map(|s| PathBuf::from(s).canonicalize().unwrap()) + .collect(), + srclist: PathBuf::from(srclist).canonicalize().unwrap(), + }; + + // Ensure that the required files are there. 
+ for file in [&pbs.metafits] + .into_iter() + .chain(pbs.vis.iter()) + .chain(pbs.mwafs.iter()) + .chain([&pbs.srclist].into_iter()) + { + assert!( + file.exists(), + "Could not find '{}', which is required for this test", + file.display() + ); + } + + pbs +} + +pub(crate) fn get_reduced_1090008640_ms_pbs() -> DataAsPathBufs { + let mut data = get_reduced_1090008640_raw_pbs(); + data.vis[0] = PathBuf::from(DATA_DIR_1090008640) + .canonicalize() + .unwrap() + .join("1090008640.ms"); + data +} + +pub(crate) fn get_reduced_1090008640_uvfits_pbs() -> DataAsPathBufs { + let mut data = get_reduced_1090008640_raw_pbs(); + data.vis[0] = PathBuf::from(DATA_DIR_1090008640) + .canonicalize() + .unwrap() + .join("1090008640.uvfits"); + data +} diff --git a/src/tests/reduced_obsids.rs b/src/tests/reduced_obsids.rs deleted file mode 100644 index 5b3e1986..00000000 --- a/src/tests/reduced_obsids.rs +++ /dev/null @@ -1,124 +0,0 @@ -// This Source Code Form is subject to the terms of the Mozilla Public -// License, v. 2.0. If a copy of the MPL was not distributed with this -// file, You can obtain one at http://mozilla.org/MPL/2.0/. - -//! This module provides functions for tests against observations that occupy a -//! small amount of data. These tests are useful as it allows hyperdrive to -//! perform units tests without requiring MWA data for a full observation. - -use super::*; -use crate::cli::di_calibrate::DiCalArgs; - -/// Get the calibration arguments associated with the obsid 1090008640 (raw MWA -/// data). This observational data is inside the hyperdrive git repo, but has -/// been reduced; there is only 1 coarse channel and 1 timestep. -pub(crate) fn get_reduced_1090008640(use_fee_beam: bool, include_mwaf: bool) -> DiCalArgs { - // Use absolute paths. - let test_files = PathBuf::from("test_files/1090008640") - .canonicalize() - .unwrap(); - - // Ensure that the required files are there. - let mut data = vec![ - format!("{}/1090008640.metafits", test_files.display()), - format!( - "{}/1090008640_20140721201027_gpubox01_00.fits", - test_files.display() - ), - ]; - if include_mwaf { - data.push(format!("{}/1090008640_01.mwaf", test_files.display())); - } - for file in &data { - let pb = PathBuf::from(file); - assert!( - pb.exists(), - "Could not find {}, which is required for this test", - pb.display() - ); - } - - let srclist = format!( - "{}/srclist_pumav3_EoR0aegean_EoR1pietro+ForA_1090008640_100.yaml", - test_files.display() - ); - assert!( - PathBuf::from(&srclist).exists(), - "Could not find {srclist}, which is required for this test" - ); - - DiCalArgs { - data: Some(data), - source_list: Some(srclist), - no_beam: !use_fee_beam, - ..Default::default() - } -} - -/// Get the calibration arguments associated with the obsid 1090008640 -/// (measurement set). This observational data is inside the hyperdrive git -/// repo, but has been reduced; there is only 1 coarse channel and 1 timestep. -pub(crate) fn get_reduced_1090008640_ms() -> DiCalArgs { - // Ensure that the required files are there. 
- let data = vec![ - "test_files/1090008640/1090008640.metafits".to_string(), - "test_files/1090008640/1090008640.ms".to_string(), - ]; - for file in &data { - let pb = PathBuf::from(file); - assert!( - pb.exists(), - "Could not find {}, which is required for this test", - pb.display() - ); - } - - let srclist = - "test_files/1090008640/srclist_pumav3_EoR0aegean_EoR1pietro+ForA_1090008640_100.yaml" - .to_string(); - assert!( - PathBuf::from(&srclist).exists(), - "Could not find {srclist}, which is required for this test" - ); - - DiCalArgs { - data: Some(data), - source_list: Some(srclist), - no_beam: true, - ..Default::default() - } -} - -/// Get the calibration arguments associated with the obsid 1090008640 (uvfits). -/// This observational data is inside the hyperdrive git repo, but has been -/// reduced; there is only 1 coarse channel and 1 timestep. -pub(crate) fn get_reduced_1090008640_uvfits() -> DiCalArgs { - // Ensure that the required files are there. - let data = vec![ - "test_files/1090008640/1090008640.metafits".to_string(), - "test_files/1090008640/1090008640.uvfits".to_string(), - ]; - for file in &data { - let pb = PathBuf::from(file); - assert!( - pb.exists(), - "Could not find {}, which is required for this test", - pb.display() - ); - } - - let srclist = - "test_files/1090008640/srclist_pumav3_EoR0aegean_EoR1pietro+ForA_1090008640_100.yaml" - .to_string(); - assert!( - PathBuf::from(&srclist).exists(), - "Could not find {srclist}, which is required for this test" - ); - - DiCalArgs { - data: Some(data), - source_list: Some(srclist), - no_beam: true, - ..Default::default() - } -} diff --git a/tests/integration_tests/di_calibrate/arg_files.rs b/tests/integration_tests/di_calibrate/arg_files.rs deleted file mode 100644 index a4483f4d..00000000 --- a/tests/integration_tests/di_calibrate/arg_files.rs +++ /dev/null @@ -1,93 +0,0 @@ -// This Source Code Form is subject to the terms of the Mozilla Public -// License, v. 2.0. If a copy of the MPL was not distributed with this -// file, You can obtain one at http://mozilla.org/MPL/2.0/. - -//! This module tests the "calibrate" command-line interface in hyperdrive with -//! toml and json argument files. 
- -use tempfile::tempdir; - -use crate::{ - get_cmd_output, get_reduced_1090008640, hyperdrive, make_file_in_dir, serialise_cal_args_json, - serialise_cal_args_toml, -}; - -#[test] -fn arg_file_absolute_paths() { - let args = get_reduced_1090008640(false, true); - let temp_dir = tempdir().expect("Couldn't make tempdir"); - - let (toml, mut toml_file) = make_file_in_dir("calibrate.toml", temp_dir.path()); - serialise_cal_args_toml(&args, &mut toml_file); - let cmd = hyperdrive() - .arg("di-calibrate") - .arg(toml.display().to_string()) - .arg("--dry-run") - .ok(); - assert!(cmd.is_ok(), "{}", get_cmd_output(cmd).1); - - let (json, mut json_file) = make_file_in_dir("calibrate.json", temp_dir.path()); - serialise_cal_args_json(&args, &mut json_file); - let cmd = hyperdrive() - .arg("di-calibrate") - .arg(json.display().to_string()) - .arg("--dry-run") - .ok(); - assert!(cmd.is_ok(), "{}", get_cmd_output(cmd).1); -} - -#[test] -fn arg_file_absolute_globs() { - let args = get_reduced_1090008640(false, true); - let temp_dir = tempdir().expect("Couldn't make tempdir"); - - let (toml_pb, mut toml) = make_file_in_dir("calibrate.toml", temp_dir.path()); - serialise_cal_args_toml(&args, &mut toml); - let cmd = hyperdrive() - .arg("di-calibrate") - .arg(toml_pb.display().to_string()) - .arg("--dry-run") - .ok(); - assert!(cmd.is_ok(), "{}", get_cmd_output(cmd).1); - - let (json_pb, mut json) = make_file_in_dir("calibrate.json", temp_dir.path()); - serialise_cal_args_json(&args, &mut json); - let cmd = hyperdrive() - .arg("di-calibrate") - .arg(json_pb.display().to_string()) - .arg("--dry-run") - .arg("--verb") - .ok(); - assert!(cmd.is_ok(), "{}", get_cmd_output(cmd).1); -} - -#[test] -fn arg_file_relative_globs() { - let mut args = get_reduced_1090008640(false, true); - args.data = Some(vec![ - "test_files/1090008640/*.metafits".to_string(), - "test_files/1090008640/*gpubox*".to_string(), - "test_files/1090008640/*.mwaf".to_string(), - ]); - args.source_list = Some("test_files/1090008640/*srclist*_100.yaml".to_string()); - - let temp_dir = tempdir().expect("Couldn't make tempdir"); - - let (toml_pb, mut toml) = make_file_in_dir("calibrate.toml", temp_dir.path()); - serialise_cal_args_toml(&args, &mut toml); - let cmd = hyperdrive() - .arg("di-calibrate") - .arg(toml_pb.display().to_string()) - .arg("--dry-run") - .ok(); - assert!(cmd.is_ok(), "{}", get_cmd_output(cmd).1); - - let (json_pb, mut json) = make_file_in_dir("calibrate.json", temp_dir.path()); - serialise_cal_args_json(&args, &mut json); - let cmd = hyperdrive() - .arg("di-calibrate") - .arg(json_pb.display().to_string()) - .arg("--dry-run") - .ok(); - assert!(cmd.is_ok(), "{}", get_cmd_output(cmd).1); -} diff --git a/tests/integration_tests/di_calibrate/missing_files.rs b/tests/integration_tests/di_calibrate/missing_files.rs index 42997b1b..95febf42 100644 --- a/tests/integration_tests/di_calibrate/missing_files.rs +++ b/tests/integration_tests/di_calibrate/missing_files.rs @@ -9,14 +9,12 @@ use std::io::Write; use tempfile::Builder; -use crate::{get_cmd_output, get_reduced_1090008640, hyperdrive}; +use crate::{get_cmd_output, get_reduced_1090008640, hyperdrive, Files}; /// Try to calibrate raw MWA data without a metafits file. 
#[test] fn arg_file_missing_metafits() { - let args = get_reduced_1090008640(false, true); - let source_list = args.source_list.unwrap(); - let data = args.data.unwrap(); + let Files { data, srclist } = get_reduced_1090008640(true); let metafits = data .iter() .find(|d| d.contains("metafits")) @@ -30,7 +28,7 @@ fn arg_file_missing_metafits() { let cmd = hyperdrive() .arg("di-calibrate") .arg("--source-list") - .arg(&source_list) + .arg(&srclist) .arg("--data") .args(&data) .arg("--dry-run") @@ -47,7 +45,7 @@ fn arg_file_missing_metafits() { let cmd = hyperdrive() .arg("di-calibrate") .arg("--source-list") - .arg(&source_list) + .arg(&srclist) .arg("--data") .arg(metafits) .args(&data) @@ -59,9 +57,7 @@ fn arg_file_missing_metafits() { /// Try to calibrate raw MWA data without gpubox files. #[test] fn arg_file_missing_gpuboxes() { - let args = get_reduced_1090008640(false, true); - let source_list = args.source_list.unwrap(); - let data = args.data.unwrap(); + let Files { data, srclist } = get_reduced_1090008640(true); let gpuboxes: Vec = data .iter() .filter(|d| d.contains("gpubox")) @@ -72,7 +68,7 @@ fn arg_file_missing_gpuboxes() { let cmd = hyperdrive() .arg("di-calibrate") .arg("--source-list") - .arg(&source_list) + .arg(&srclist) .arg("--data") .args(&data) .arg("--dry-run") @@ -89,7 +85,7 @@ fn arg_file_missing_gpuboxes() { let cmd = hyperdrive() .arg("di-calibrate") .arg("--source-list") - .arg(&source_list) + .arg(&srclist) .arg("--data") .args(&gpuboxes) .args(&data) @@ -101,9 +97,7 @@ fn arg_file_missing_gpuboxes() { /// Ensure that di-calibrate issues a warning when no mwaf files are supplied. #[test] fn missing_mwafs() { - let args = get_reduced_1090008640(false, true); - let source_list = args.source_list.unwrap(); - let data = args.data.unwrap(); + let Files { data, srclist } = get_reduced_1090008640(true); let metafits = &data[0]; let gpubox = &data[1]; let mwaf = &data[2]; @@ -113,7 +107,7 @@ fn missing_mwafs() { let cmd = hyperdrive() .arg("di-calibrate") .arg("--source-list") - .arg(&source_list) + .arg(&srclist) .arg("--data") .arg(metafits) .arg(gpubox) @@ -127,7 +121,7 @@ fn missing_mwafs() { let cmd = hyperdrive() .arg("di-calibrate") .arg("--source-list") - .arg(&source_list) + .arg(&srclist) .arg("--data") .arg(metafits) .arg(gpubox) @@ -143,10 +137,24 @@ fn missing_mwafs() { /// via argument files. #[test] fn arg_file_missing_mwafs() { - let args = get_reduced_1090008640(false, false); - let args_str = toml::to_string(&args).unwrap(); + let Files { data, srclist } = get_reduced_1090008640(true); + let metafits = &data[0]; + let gpubox = &data[1]; + let mwaf = &data[2]; let mut args_file = Builder::new().suffix(".toml").tempfile().unwrap(); - args_file.write_all(args_str.as_bytes()).unwrap(); + args_file + .write_all( + format!( + r#"[data] +files = ["{metafits}", "{gpubox}"] + +[sky-model] +source_list = "{srclist}" +"# + ) + .as_bytes(), + ) + .unwrap(); // Don't include an mwaf file; we expect a warning to be logged for this // reason. @@ -160,10 +168,20 @@ fn arg_file_missing_mwafs() { assert!(stdout.contains("No mwaf files supplied"), "{}", stdout); // Include an mwaf file; we don't expect a warning this time. 
- let args = get_reduced_1090008640(false, true); - let args_str = toml::to_string(&args).unwrap(); let mut args_file = Builder::new().suffix(".toml").tempfile().unwrap(); - args_file.write_all(args_str.as_bytes()).unwrap(); + args_file + .write_all( + format!( + r#"[data] +files = ["{metafits}", "{gpubox}", "{mwaf}"] + +[sky-model] +source_list = "{srclist}" +"# + ) + .as_bytes(), + ) + .unwrap(); let cmd = hyperdrive() .arg("di-calibrate") diff --git a/tests/integration_tests/di_calibrate/mod.rs b/tests/integration_tests/di_calibrate/mod.rs index 0760de19..b46fe626 100644 --- a/tests/integration_tests/di_calibrate/mod.rs +++ b/tests/integration_tests/di_calibrate/mod.rs @@ -4,7 +4,6 @@ //! Code for calibration testing. -mod arg_files; mod cli_args; mod missing_files; diff --git a/tests/integration_tests/main.rs b/tests/integration_tests/main.rs index aa296009..3e590f43 100644 --- a/tests/integration_tests/main.rs +++ b/tests/integration_tests/main.rs @@ -12,8 +12,6 @@ mod no_stderr; mod solutions_apply; use std::{ - fs::File, - io::Write, path::{Path, PathBuf}, process::Output, str::from_utf8, @@ -23,7 +21,7 @@ use assert_cmd::{output::OutputError, Command}; use marlu::Jones; use ndarray::prelude::*; -use mwa_hyperdrive::{CalibrationSolutions, DiCalArgs}; +use mwa_hyperdrive::CalibrationSolutions; fn hyperdrive() -> Command { Command::cargo_bin("hyperdrive").unwrap() @@ -40,7 +38,7 @@ fn get_cmd_output(result: Result) -> (String, String) { ) } -fn get_1090008640_identity_solutions_file(tmp_dir: &Path) -> PathBuf { +fn get_identity_solutions_file(tmp_dir: &Path) -> PathBuf { let sols = CalibrationSolutions { di_jones: Array3::from_elem((1, 128, 32), Jones::identity()), ..Default::default() @@ -50,26 +48,15 @@ fn get_1090008640_identity_solutions_file(tmp_dir: &Path) -> PathBuf { file } -fn make_file_in_dir, U: AsRef>(filename: T, dir: U) -> (PathBuf, File) { - let path = dir.as_ref().join(filename); - let f = File::create(&path).expect("couldn't make file"); - (path, f) -} - -fn serialise_cal_args_toml(args: &DiCalArgs, file: &mut File) { - let ser = toml::to_string_pretty(&args).expect("couldn't serialise DiCalArgs as toml"); - write!(file, "{ser}").unwrap(); -} - -fn serialise_cal_args_json(args: &DiCalArgs, file: &mut File) { - let ser = serde_json::to_string_pretty(&args).expect("couldn't serialise DiCalArgs as json"); - write!(file, "{ser}").unwrap(); +struct Files { + data: Vec, + srclist: String, } /// Get the calibration arguments associated with the obsid 1090008640 (raw MWA /// data). This observational data is inside the hyperdrive git repo, but has /// been reduced; there is only 1 coarse channel and 1 timestep. -fn get_reduced_1090008640(use_fee_beam: bool, include_mwaf: bool) -> DiCalArgs { +fn get_reduced_1090008640(include_mwaf: bool) -> Files { // Use absolute paths. 
let test_files = PathBuf::from("test_files/1090008640") .canonicalize() @@ -104,10 +91,5 @@ fn get_reduced_1090008640(use_fee_beam: bool, include_mwaf: bool) -> DiCalArgs { "Could not find {srclist}, which is required for this test" ); - DiCalArgs { - data: Some(data), - source_list: Some(srclist), - no_beam: !use_fee_beam, - ..Default::default() - } + Files { data, srclist } } diff --git a/tests/integration_tests/no_stderr.rs b/tests/integration_tests/no_stderr.rs index e2df0ea2..188f83c8 100644 --- a/tests/integration_tests/no_stderr.rs +++ b/tests/integration_tests/no_stderr.rs @@ -9,22 +9,21 @@ use std::{collections::HashMap, io::Write}; use tempfile::TempDir; use crate::{ - get_1090008640_identity_solutions_file, get_cmd_output, get_reduced_1090008640, hyperdrive, + get_cmd_output, get_identity_solutions_file, get_reduced_1090008640, hyperdrive, Files, }; #[test] fn test_di_no_stderr() { let tmp_dir = TempDir::new().expect("couldn't make tmp dir"); let sols = tmp_dir.path().join("sols.fits"); - let args = get_reduced_1090008640(true, false); - let data = args.data.unwrap(); + let Files { data, srclist } = get_reduced_1090008640(false); #[rustfmt::skip] let cmd = hyperdrive() .args([ "di-calibrate", "--data", &data[0], &data[1], - "--source-list", &args.source_list.unwrap(), + "--source-list", &srclist, "--outputs", &format!("{}", sols.display()), ]) .ok(); @@ -41,9 +40,8 @@ fn test_di_no_stderr() { fn test_solutions_apply_no_stderr() { let tmp_dir = TempDir::new().expect("couldn't make tmp dir"); let output = tmp_dir.path().join("out.uvfits"); - let sols = get_1090008640_identity_solutions_file(tmp_dir.path()); - let args = get_reduced_1090008640(true, false); - let data = args.data.unwrap(); + let sols = get_identity_solutions_file(tmp_dir.path()); + let Files { data, .. } = get_reduced_1090008640(false); #[rustfmt::skip] let cmd = hyperdrive() @@ -63,6 +61,29 @@ fn test_solutions_apply_no_stderr() { assert!(stderr.is_empty(), "stderr wasn't empty: {stderr}"); } +#[test] +fn test_vis_convert_no_stderr() { + let temp_dir = TempDir::new().expect("couldn't make tmp dir"); + let output = temp_dir.path().join("converted.uvfits"); + let Files { data, .. } = get_reduced_1090008640(false); + + #[rustfmt::skip] + let cmd = hyperdrive() + .args([ + "vis-convert", + "--data", &data[0], &data[1], + "--outputs", &format!("{}", output.display()), + ]) + .ok(); + assert!( + cmd.is_ok(), + "vis-convert failed on simple test data: {}", + cmd.err().unwrap() + ); + let (_, stderr) = get_cmd_output(cmd); + assert!(stderr.is_empty(), "stderr wasn't empty: {stderr}"); +} + #[test] fn test_vis_simulate_and_vis_subtract_no_stderr() { // First test vis-simulate. 
@@ -71,15 +92,15 @@ fn test_vis_simulate_and_vis_subtract_no_stderr() { let temp_dir = TempDir::new().expect("couldn't make tmp dir"); let model_path = temp_dir.path().join("model.uvfits"); - let args = get_reduced_1090008640(true, false); - let metafits = args.data.as_ref().unwrap()[0].clone(); + let Files { data, srclist } = get_reduced_1090008640(false); + let metafits = data[0].clone(); #[rustfmt::skip] let cmd = hyperdrive() .args([ "vis-simulate", "--metafits", &metafits, - "--source-list", args.source_list.as_ref().unwrap(), + "--source-list", &srclist, "--output-model-files", &format!("{}", model_path.display()), "--num-timesteps", &format!("{num_timesteps}"), "--num-fine-channels", &format!("{num_chans}"), @@ -102,7 +123,7 @@ fn test_vis_simulate_and_vis_subtract_no_stderr() { .args([ "vis-subtract", "--data", &metafits, &format!("{}", model_path.display()), - "--source-list", &args.source_list.unwrap(), + "--source-list", &srclist, "--invert", "--output", &format!("{}", sub_path.display()), "--no-progress-bars" @@ -121,7 +142,7 @@ fn test_vis_simulate_and_vis_subtract_no_stderr() { fn test_solutions_convert_no_stderr() { let tmp_dir = TempDir::new().expect("couldn't make tmp dir"); let output = tmp_dir.path().join("sols.bin"); - let sols = get_1090008640_identity_solutions_file(tmp_dir.path()); + let sols = get_identity_solutions_file(tmp_dir.path()); #[rustfmt::skip] let cmd = hyperdrive() @@ -144,9 +165,9 @@ fn test_solutions_convert_no_stderr() { #[cfg(feature = "plotting")] fn test_solutions_plot_no_stderr() { let tmp_dir = TempDir::new().expect("couldn't make tmp dir"); - let sols = get_1090008640_identity_solutions_file(tmp_dir.path()); - let args = get_reduced_1090008640(true, false); - let metafits = args.data.as_ref().unwrap()[0].clone(); + let sols = get_identity_solutions_file(tmp_dir.path()); + let Files { data, .. } = get_reduced_1090008640(false); + let metafits = data[0].clone(); #[rustfmt::skip] let cmd = hyperdrive() @@ -169,9 +190,8 @@ fn test_solutions_plot_no_stderr() { #[test] fn test_srclist_by_beam_no_stderr() { let tmp_dir = TempDir::new().expect("couldn't make tmp dir"); - let args = get_reduced_1090008640(true, false); - let metafits = args.data.as_ref().unwrap()[0].clone(); - let srclist = args.source_list.unwrap(); + let Files { data, srclist } = get_reduced_1090008640(false); + let metafits = data[0].clone(); let output = tmp_dir.path().join("srclist.txt"); #[rustfmt::skip] @@ -197,8 +217,7 @@ fn test_srclist_by_beam_no_stderr() { #[test] fn test_srclist_convert_no_stderr() { let tmp_dir = TempDir::new().expect("couldn't make tmp dir"); - let args = get_reduced_1090008640(true, false); - let srclist = args.source_list.unwrap(); + let Files { srclist, .. } = get_reduced_1090008640(false); let output = tmp_dir.path().join("srclist.txt"); #[rustfmt::skip] @@ -222,8 +241,7 @@ fn test_srclist_convert_no_stderr() { #[test] fn test_srclist_shift_no_stderr() { let tmp_dir = TempDir::new().expect("couldn't make tmp dir"); - let args = get_reduced_1090008640(true, false); - let srclist = args.source_list.unwrap(); + let Files { srclist, .. } = get_reduced_1090008640(false); let output = tmp_dir.path().join("shifted.txt"); let shifts = tmp_dir.path().join("shifts.json"); @@ -259,8 +277,7 @@ fn test_srclist_shift_no_stderr() { #[test] fn test_srclist_verify_no_stderr() { - let args = get_reduced_1090008640(true, false); - let srclist = args.source_list.unwrap(); + let Files { srclist, .. 
} = get_reduced_1090008640(false); #[rustfmt::skip] let cmd = hyperdrive() @@ -280,14 +297,14 @@ fn test_srclist_verify_no_stderr() { #[test] fn test_dipole_gains_no_stderr() { - let args = get_reduced_1090008640(true, false); - let metafits = args.data.as_ref().unwrap()[0].clone(); + let Files { data, .. } = get_reduced_1090008640(false); + let metafits = &data[0]; #[rustfmt::skip] let cmd = hyperdrive() .args([ "dipole-gains", - &metafits, + metafits, ]) .ok(); assert!(