diff --git a/src/calibrate/di/mod.rs b/src/calibrate/di/mod.rs index 8d440adb..aae07f8b 100644 --- a/src/calibrate/di/mod.rs +++ b/src/calibrate/di/mod.rs @@ -13,17 +13,23 @@ pub use code::calibrate_timeblocks; use code::*; use hifitime::{Duration, Unit}; -use itertools::{izip, Itertools}; +use itertools::izip; use log::{debug, info, log_enabled, trace, Level::Debug}; use marlu::{ - math::cross_correlation_baseline_to_tiles, Jones, UvfitsWriter, VisContext, VisWritable, + math::cross_correlation_baseline_to_tiles, Jones, MeasurementSetWriter, + ObsContext as MarluObsContext, UvfitsWriter, VisContext, VisWritable, }; use ndarray::prelude::*; use rayon::prelude::*; use super::{params::CalibrateParams, solutions::CalibrationSolutions, CalibrateError}; use crate::data_formats::VisOutputType; -use mwa_hyperdrive_common::{hifitime, itertools, log, marlu, ndarray, rayon}; +use mwa_hyperdrive_common::{ + hifitime::{self, Epoch}, + itertools, log, marlu, ndarray, + num_traits::Zero, + rayon, +}; /// Do all the steps required for direction-independent calibration; read the /// input data, generate a model against it, and write the solutions out. @@ -144,21 +150,23 @@ pub(crate) fn di_calibrate( // TODO(dev): support and test autos if params.using_autos { - panic!("not supperted yet... 
or are they?"); + panic!("writing auto outputs for calibrated vis not supported"); + } + + // TODO(dev): support and test time averaging for calibrated vis + if params.output_vis_time_average_factor > 1 { + panic!("time averaging for calibrated vis not supported"); } let ant_pairs: Vec<(usize, usize)> = params.get_ant_pairs(); let int_time: Duration = Duration::from_f64(obs_context.time_res.unwrap(), Unit::Second); - // TODO(dev): support sparse timesteps by chunking over time - for (&past, &future) in params.timesteps.iter().tuple_windows() { - assert!(future > past); - assert!(future - past == 1, "assuming contiguous timesteps"); - } + let start_timestamp = obs_context.timestamps[params.timesteps[0]]; + // XXX(Dev): VisContext does not support sparse timesteps, but in this case it doesn't matter let vis_ctx = VisContext { num_sel_timesteps: params.timesteps.len(), - start_timestamp: obs_context.timestamps[params.timesteps[0]], + start_timestamp, int_time, num_sel_chans: obs_context.fine_chan_freqs.len(), start_freq_hz: obs_context.fine_chan_freqs[0] as f64, @@ -169,13 +177,10 @@ pub(crate) fn di_calibrate( num_vis_pols: 4, }; - // pad and transpose the data - // TODO(dev): unify unpacking + let obs_name = obs_context.obsid.map(|o| format!("MWA obsid {}", o)); - // out data is [time, freq, baseline], in data is [time, baseline, freq] + // shape of entire output [time, freq, baseline]. 
in data is [time, baseline, freq] let out_shape = vis_ctx.sel_dims(); - let mut out_data = Array3::zeros(out_shape); - let mut out_weights = Array3::from_elem(out_shape, -0.0); assert_eq!(vis_weights.dim(), vis_data.dim()); // time @@ -188,23 +193,90 @@ pub(crate) fn di_calibrate( out_shape.1 ); + // re-use output arrays each timestep chunk + let out_shape_timestep = (1, out_shape.1, out_shape.2); + let mut tmp_out_data = Array3::from_elem(out_shape_timestep, Jones::zero()); + let mut tmp_out_weights = Array3::from_elem(out_shape_timestep, -0.0); + + // create a VisWritable for each output vis filename + let mut out_writers: Vec<(VisOutputType, Box<dyn VisWritable>)> = vec![]; + for (vis_type, file) in params.output_vis_filenames.iter() { + match vis_type { + VisOutputType::Uvfits => { + trace!(" - to uvfits {}", file.display()); + + let writer = UvfitsWriter::from_marlu( + &file, + &vis_ctx, + Some(params.array_position), + obs_context.phase_centre, + obs_name.clone(), + )?; + + out_writers.push((VisOutputType::Uvfits, Box::new(writer))); + } + VisOutputType::MeasurementSet => { + trace!(" - to measurement set {}", file.display()); + let writer = MeasurementSetWriter::new( + &file, + obs_context.phase_centre, + Some(params.array_position), + ); + + let sched_start_timestamp = match obs_context.obsid { + Some(gpst) => Epoch::from_gpst_seconds(gpst as f64), + None => start_timestamp, + }; + let sched_duration = *obs_context.timestamps.last() - sched_start_timestamp; + + let marlu_obs_ctx = MarluObsContext { + sched_start_timestamp, + sched_duration, + name: obs_name.clone(), + phase_centre: obs_context.phase_centre, + pointing_centre: obs_context.pointing_centre, + array_pos: params.array_position, + ant_positions_enh: obs_context + .tile_xyzs + .iter() + .map(|xyz| xyz.to_enh(params.array_position.latitude_rad)) + .collect(), + ant_names: obs_context.tile_names.iter().cloned().collect(), + // TODO(dev): is there any value in adding this metadata via hyperdrive obs context? 
+ field_name: None, + project_id: None, + observer: None, + }; + + writer.initialize(&vis_ctx, &marlu_obs_ctx)?; + out_writers.push((VisOutputType::MeasurementSet, Box::new(writer))); + } + }; + } + // zip over time axis; - for (mut out_data, mut out_weights, vis_data, vis_weights) in izip!( - out_data.outer_iter_mut(), - out_weights.outer_iter_mut(), + for (&timestep, vis_data, vis_weights) in izip!( + params.timesteps.iter(), vis_data.outer_iter(), vis_weights.outer_iter(), ) { + let chunk_vis_ctx = VisContext { + start_timestamp: obs_context.timestamps[timestep], + ..vis_ctx.clone() + }; + tmp_out_data.fill(Jones::zero()); + tmp_out_weights.fill(-0.0); + // zip over baseline axis - for (mut out_data, mut out_weights, vis_data, vis_weights) in izip!( - out_data.axis_iter_mut(Axis(1)), - out_weights.axis_iter_mut(Axis(1)), + for (mut tmp_out_data, mut tmp_out_weights, vis_data, vis_weights) in izip!( + tmp_out_data.axis_iter_mut(Axis(2)), + tmp_out_weights.axis_iter_mut(Axis(2)), vis_data.axis_iter(Axis(0)), vis_weights.axis_iter(Axis(0)) ) { // merge frequency axis for ((_, out_jones, out_weight), in_jones, in_weight) in izip!( - izip!(0.., out_data.iter_mut(), out_weights.iter_mut(),) + izip!(0.., tmp_out_data.iter_mut(), tmp_out_weights.iter_mut(),) .filter(|(chan_idx, _, _)| !params.flagged_fine_chans.contains(chan_idx)), vis_data.iter(), vis_weights.iter() @@ -213,39 +285,29 @@ pub(crate) fn di_calibrate( *out_weight = *in_weight; } } - } - let obs_name = obs_context.obsid.map(|o| format!("MWA obsid {}", o)); - - for (vis_type, file) in &params.output_vis_filenames { - match vis_type { - // TODO: Make this an obs_context method? 
- VisOutputType::Uvfits => { - trace!("Writing to output uvfits"); - - let mut writer = UvfitsWriter::from_marlu( - &file, - &vis_ctx, - Some(params.array_position), - obs_context.phase_centre, - obs_name.clone(), - )?; - - writer.write_vis_marlu( - out_data.view(), - out_weights.view(), - &vis_ctx, - &obs_context.tile_xyzs, - false, - )?; + for (_, writer) in out_writers.iter_mut() { + writer.write_vis_marlu( + tmp_out_data.view(), + tmp_out_weights.view(), + &chunk_vis_ctx, + &obs_context.tile_xyzs, + false, + )?; + } + } - writer.write_uvfits_antenna_table( - &obs_context.tile_names, - &obs_context.tile_xyzs, - )?; - } // TODO(dev): Other formats + // finalize writing uvfits + for (vis_type, writer) in out_writers.into_iter() { + if matches!(vis_type, VisOutputType::Uvfits) { + // SAFETY: an entry is tagged `VisOutputType::Uvfits` only where it is pushed as + // `Box::new(writer)` with `writer` built by `UvfitsWriter::from_marlu`, so the + // allocation behind this trait object is presumed to be a `UvfitsWriter`. + let uvfits_writer = + unsafe { Box::from_raw(Box::into_raw(writer) as *mut UvfitsWriter) }; + uvfits_writer + .write_uvfits_antenna_table(&obs_context.tile_names, &obs_context.tile_xyzs)?; } - info!("Calibrated visibilities written to {}", file.display()); } } diff --git a/src/calibrate/error.rs b/src/calibrate/error.rs index d4cdfc02..1de3aaf8 100644 --- a/src/calibrate/error.rs +++ b/src/calibrate/error.rs @@ -4,7 +4,7 @@ //! Error type for all calibration-related errors. 
-use marlu::io::error::{IOError as MarluIOError, UvfitsWriteError}; +use marlu::io::error::{IOError as MarluIOError, MeasurementSetWriteError, UvfitsWriteError}; use mwalib::fitsio; use thiserror::Error; @@ -56,6 +56,9 @@ pub enum CalibrateError { #[error("Error when writing uvfits: {0}")] UviftsWrite(#[from] UvfitsWriteError), + #[error("Error when writing measurement set: {0}")] + MeasurementSetWrite(#[from] MeasurementSetWriteError), + #[error("Error when using Marlu for IO: {0}")] MarluIO(#[from] MarluIOError), diff --git a/src/calibrate/params/mod.rs b/src/calibrate/params/mod.rs index de8e85d9..b3421ac2 100644 --- a/src/calibrate/params/mod.rs +++ b/src/calibrate/params/mod.rs @@ -20,7 +20,7 @@ use filenames::InputDataTypes; use helpers::*; use std::collections::{HashMap, HashSet}; -use std::fs::OpenOptions; +use std::fs::{self, OpenOptions}; use std::ops::Deref; use std::path::{Path, PathBuf}; use std::str::FromStr; diff --git a/src/context/mod.rs b/src/context/mod.rs index 962fde78..2911c148 100644 --- a/src/context/mod.rs +++ b/src/context/mod.rs @@ -19,15 +19,15 @@ use mwa_hyperdrive_common::{hifitime, marlu, ndarray, vec1}; /// /// Tile information is ordered according to the "Antenna" column in HDU 1 of /// the observation's metafits file. -pub(crate) struct ObsContext { +pub struct ObsContext { /// The observation ID, which is also the observation's scheduled start GPS /// time (but shouldn't be used for this purpose). - pub(crate) obsid: Option, + pub obsid: Option, /// The unique timestamps in the observation. These are stored as `hifitime` /// [Epoch] structs to help keep the code flexible. These include timestamps /// that are deemed "flagged" by the observation. - pub(crate) timestamps: Vec1, + pub timestamps: Vec1, /// The *available* timestep indices of the input data. This does not /// necessarily start at 0, and is not necessarily regular (e.g. 
a valid @@ -37,7 +37,7 @@ pub(crate) struct ObsContext { /// data that also isn't regular; naively reading in a dataset with 2 /// timesteps that are separated by more than the time resolution of the /// data would give misleading results. - pub(crate) all_timesteps: Vec1, + pub all_timesteps: Vec1, /// The timestep indices of the input data that aren't totally flagged. /// @@ -107,7 +107,7 @@ pub(crate) struct ObsContext { /// /// These are kept as ints to help some otherwise error-prone calculations /// using floats. By using ints, we assume there is no sub-Hz structure. - pub(crate) fine_chan_freqs: Vec1, + pub fine_chan_freqs: Vec1, /// The flagged fine channels for each baseline in the supplied data. Zero /// indexed. diff --git a/src/data_formats/mod.rs b/src/data_formats/mod.rs index 22296684..543551b4 100644 --- a/src/data_formats/mod.rs +++ b/src/data_formats/mod.rs @@ -12,9 +12,9 @@ mod uvfits; pub(crate) use error::ReadInputDataError; pub use metafits::*; -pub(crate) use ms::{MsReadError, MS}; +pub use ms::{MsReadError, MS}; pub(crate) use raw::{RawDataReader, RawReadError}; -pub(crate) use uvfits::{UvfitsReadError, UvfitsReader}; +pub use uvfits::{UvfitsReadError, UvfitsReader}; use std::collections::{HashMap, HashSet}; @@ -27,7 +27,7 @@ use crate::context::ObsContext; use mwa_hyperdrive_common::{marlu, ndarray, vec1}; #[derive(Debug)] -pub(crate) enum VisInputType { +pub enum VisInputType { Raw, MeasurementSet, Uvfits, @@ -37,9 +37,11 @@ pub(crate) enum VisInputType { pub(crate) enum VisOutputType { #[strum(serialize = "uvfits")] Uvfits, + #[strum(serialize = "ms")] + MeasurementSet, } -pub(crate) trait InputData: Sync + Send { +pub trait InputData: Sync + Send { fn get_obs_context(&self) -> &ObsContext; fn get_input_data_type(&self) -> VisInputType; diff --git a/src/data_formats/ms/mod.rs b/src/data_formats/ms/mod.rs index ce701ac1..9c5f881a 100644 --- a/src/data_formats/ms/mod.rs +++ b/src/data_formats/ms/mod.rs @@ -32,7 +32,7 @@ use 
crate::{context::ObsContext, data_formats::metafits, time::round_hundredths_ use mwa_hyperdrive_beam::Delays; use mwa_hyperdrive_common::{hifitime, log, marlu, mwalib, ndarray}; -pub(crate) struct MS { +pub struct MS { /// Input data metadata. obs_context: ObsContext, @@ -59,7 +59,7 @@ impl MS { /// The measurement set is expected to be formatted in the way that /// cotter/Birli write measurement sets. // TODO: Handle multiple measurement sets. - pub(crate) fn new, P2: AsRef>( + pub fn new, P2: AsRef>( ms: P, metafits: Option, dipole_delays: &mut Delays, diff --git a/src/data_formats/uvfits/mod.rs b/src/data_formats/uvfits/mod.rs index 70ed8c95..1aa3ec9d 100644 --- a/src/data_formats/uvfits/mod.rs +++ b/src/data_formats/uvfits/mod.rs @@ -10,8 +10,8 @@ mod read; #[cfg(test)] mod tests; -pub(crate) use error::*; -pub(crate) use read::*; +pub use error::*; +pub use read::*; use hifitime::Epoch; diff --git a/src/data_formats/uvfits/read.rs b/src/data_formats/uvfits/read.rs index b4cd778c..3fb53c7f 100644 --- a/src/data_formats/uvfits/read.rs +++ b/src/data_formats/uvfits/read.rs @@ -25,7 +25,7 @@ use crate::{ use mwa_hyperdrive_beam::Delays; use mwa_hyperdrive_common::{log, marlu, mwalib, ndarray}; -pub(crate) struct UvfitsReader { +pub struct UvfitsReader { /// Observation metadata. pub(super) obs_context: ObsContext, @@ -48,7 +48,7 @@ impl UvfitsReader { /// /// The measurement set is expected to be formatted in the way that /// cotter/Birli write measurement sets. 
- pub(crate) fn new, P2: AsRef>( + pub fn new, P2: AsRef>( uvfits: P, metafits: Option, dipole_delays: &mut Delays, diff --git a/tests/integration/calibrate/mod.rs b/tests/integration/calibrate/mod.rs index 7910d52d..40e63032 100644 --- a/tests/integration/calibrate/mod.rs +++ b/tests/integration/calibrate/mod.rs @@ -8,12 +8,16 @@ mod cli_args; use approx::assert_abs_diff_eq; use clap::Parser; +use mwa_hyperdrive_beam::Delays; use mwalib::*; use serial_test::serial; use crate::*; -use mwa_hyperdrive::calibrate::{di_calibrate, solutions::CalibrationSolutions, CalibrateError}; -use mwa_hyperdrive_common::{clap, mwalib}; +use mwa_hyperdrive::{ + calibrate::{di_calibrate, solutions::CalibrationSolutions, CalibrateError}, + data_formats::{InputData, UvfitsReader, MS}, +}; +use mwa_hyperdrive_common::{clap, hifitime::Epoch, mwalib, setup_logging}; /// If di-calibrate is working, it should not write anything to stderr. #[test] @@ -274,13 +278,71 @@ fn test_1090008640_di_calibrate_writes_vis_uvfits_noautos_avg_freq() { gcount.parse::().unwrap(), exp_timesteps * exp_baselines ); - // let pcount: String = get_required_fits_key!(&mut out_vis, &hdu0, "PCOUNT").unwrap(); - // assert_eq!(pcount.parse::().unwrap(), 5); - // let floats_per_pol: String = get_required_fits_key!(&mut out_vis, &hdu0, "NAXIS2").unwrap(); - // assert_eq!(floats_per_pol.parse::().unwrap(), 3); - // let num_pols: String = get_required_fits_key!(&mut out_vis, &hdu0, "NAXIS3").unwrap(); - // assert_eq!(num_pols.parse::().unwrap(), 4); let num_fine_freq_chans: String = get_required_fits_key!(&mut out_vis, &hdu0, "NAXIS4").unwrap(); assert_eq!(num_fine_freq_chans.parse::().unwrap(), exp_channels); } + +#[test] +#[serial] +fn test_1090008640_di_calibrate_writes_vis_uvfits_ms_noautos() { + let tmp_dir = TempDir::new().expect("couldn't make tmp dir").into_path(); + let args = get_reduced_1090008640(true, false); + let data = args.data.unwrap(); + let metafits = &data[0]; + let gpufits = &data[1]; + let 
out_uvfits_path = tmp_dir.join("vis.uvfits"); + let out_ms_path = tmp_dir.join("vis.ms"); + let cal_model = tmp_dir.join("hyp_model.uvfits"); + + #[rustfmt::skip] + let cal_args = CalibrateUserArgs::parse_from(&[ + "di-calibrate", + "--data", metafits, gpufits, + "--source-list", &args.source_list.unwrap(), + "--outputs", + &format!("{}", out_uvfits_path.display()), + &format!("{}", out_ms_path.display()), + "--model-filename", &format!("{}", cal_model.display()), + "--ignore-autos", + ]); + + if out_ms_path.exists() { + std::fs::remove_dir_all(&out_ms_path).unwrap(); + } + + // Run di-cal and check that it succeeds + let result = di_calibrate::(Box::new(cal_args), None, false); + assert!(result.is_ok(), "result={:?} not ok", result.err().unwrap()); + + let exp_timesteps = 1; + let exp_channels = 32; + + // check uvfits file has been created, is readable + assert!(out_uvfits_path.exists(), "out vis file not written"); + + let uvfits_data = + UvfitsReader::new(&out_uvfits_path, Some(metafits), &mut Delays::None).unwrap(); + + let uvfits_ctx = uvfits_data.get_obs_context(); + + // check ms file has been created, is readable + assert!(out_ms_path.exists(), "out vis file not written"); + + let ms_data = MS::new(&out_ms_path, Some(metafits), &mut Delays::None).unwrap(); + + let ms_ctx = ms_data.get_obs_context(); + + // XXX(dev): Can't write obsid to ms file without MwaObsContext "MS obsid not available (no MWA_GPS_TIME in OBSERVATION table)" + // assert_eq!(uvfits_ctx.obsid, ms_ctx.obsid); + assert_eq!(uvfits_ctx.obsid, Some(1090008640)); + assert_eq!(uvfits_ctx.timestamps, ms_ctx.timestamps); + assert_eq!( + uvfits_ctx.timestamps, + vec![Epoch::from_gpst_seconds(1090008659.)] + ); + assert_eq!(uvfits_ctx.all_timesteps, ms_ctx.all_timesteps); + assert_eq!(uvfits_ctx.all_timesteps.len(), exp_timesteps); + assert_eq!(uvfits_ctx.fine_chan_freqs, ms_ctx.fine_chan_freqs); + assert_eq!(uvfits_ctx.fine_chan_freqs.len(), exp_channels); +}