From 4b5f7425f7d17ae1c6b2bf7b27f2029bd145ef0a Mon Sep 17 00:00:00 2001 From: "Christopher H. Jordan" Date: Fri, 10 Mar 2023 03:18:39 +0800 Subject: [PATCH] Improve UX with a new expensive_op function. As stated in the docs: Perform this expensive operation as a normal Rust function, but if it takes more than a certain amount of time, display a message to the user that you're still waiting for this operation to complete. This function is more or less the same as one I wrote for raw-data reading ages ago, but now I've made it generic to be used for other routines. --- src/cli/di_calibrate/params.rs | 5 +- src/cli/srclist/by_beam/mod.rs | 5 +- src/cli/srclist/convert.rs | 5 +- src/cli/srclist/shift.rs | 5 +- src/cli/srclist/verify.rs | 26 ++++- src/cli/vis_utils/simulate/mod.rs | 5 +- src/cli/vis_utils/subtract/mod.rs | 5 +- src/di_calibrate/mod.rs | 23 +++- src/io/read/ms/mod.rs | 167 ++++++++++++++++-------------- src/io/read/raw/helpers.rs | 73 ------------- src/io/read/raw/mod.rs | 8 +- src/io/read/uvfits/mod.rs | 9 +- src/misc.rs | 70 +++++++++++++ 13 files changed, 236 insertions(+), 170 deletions(-) delete mode 100644 src/io/read/raw/helpers.rs diff --git a/src/cli/di_calibrate/params.rs b/src/cli/di_calibrate/params.rs index 4f3b63e0..29f2e9fd 100644 --- a/src/cli/di_calibrate/params.rs +++ b/src/cli/di_calibrate/params.rs @@ -977,7 +977,10 @@ impl DiCalParams { // kinds. 
let sl_type_specified = source_list_type.is_none(); let sl_type = source_list_type.and_then(|t| SourceListType::from_str(t.as_ref()).ok()); - let (sl, sl_type) = match read_source_list_file(sl_pb, sl_type) { + let (sl, sl_type) = match crate::misc::expensive_op( + || read_source_list_file(sl_pb, sl_type), + "Still reading source list file", + ) { Ok((sl, sl_type)) => (sl, sl_type), Err(e) => return Err(DiCalArgsError::from(e)), }; diff --git a/src/cli/srclist/by_beam/mod.rs b/src/cli/srclist/by_beam/mod.rs index f9d6debe..e23922f0 100644 --- a/src/cli/srclist/by_beam/mod.rs +++ b/src/cli/srclist/by_beam/mod.rs @@ -173,7 +173,10 @@ fn by_beam, S: AsRef>( ) -> Result<(), SrclistError> { // Read the input source list. let input_type = input_type.and_then(|t| SourceListType::from_str(t).ok()); - let (sl, sl_type) = read_source_list_file(input_path, input_type)?; + let (sl, sl_type) = crate::misc::expensive_op( + || read_source_list_file(input_path, input_type), + "Still reading source list file", + )?; if input_type.is_none() { info!( "Successfully read {} as a {}-style source list", diff --git a/src/cli/srclist/convert.rs b/src/cli/srclist/convert.rs index ced42f97..64d22db1 100644 --- a/src/cli/srclist/convert.rs +++ b/src/cli/srclist/convert.rs @@ -127,7 +127,10 @@ fn convert, S: AsRef>( } // Read the input source list. 
- let (sl, sl_type) = read_source_list_file(input_path, input_type)?; + let (sl, sl_type) = crate::misc::expensive_op( + || read_source_list_file(input_path, input_type), + "Still reading source list file", + )?; if input_type.is_none() { info!( "Successfully read {} as a {}-style source list", diff --git a/src/cli/srclist/shift.rs b/src/cli/srclist/shift.rs index 907b05e0..52216e93 100644 --- a/src/cli/srclist/shift.rs +++ b/src/cli/srclist/shift.rs @@ -128,7 +128,10 @@ fn shift, S: AsRef>( let f = BufReader::new(File::open(source_shifts_file)?); let source_shifts: BTreeMap = serde_json::from_reader(f).map_err(WriteSourceListError::from)?; - let (sl, sl_type) = read_source_list_file(source_list_file, input_type)?; + let (sl, sl_type) = crate::misc::expensive_op( + || read_source_list_file(source_list_file, input_type), + "Still reading source list file", + )?; info!( "Successfully read {} as a {}-style source list", source_list_file.display(), diff --git a/src/cli/srclist/verify.rs b/src/cli/srclist/verify.rs index 6e961af8..69a30e01 100644 --- a/src/cli/srclist/verify.rs +++ b/src/cli/srclist/verify.rs @@ -67,10 +67,22 @@ fn verify>( let (sl, sl_type) = if let Some(input_type) = input_type { let mut buf = std::io::BufReader::new(File::open(source_list)?); let result = match input_type { - SourceListType::Hyperdrive => hyperdrive::source_list_from_yaml(&mut buf), - SourceListType::AO => ao::parse_source_list(&mut buf), - SourceListType::Rts => rts::parse_source_list(&mut buf), - SourceListType::Woden => woden::parse_source_list(&mut buf), + SourceListType::Hyperdrive => crate::misc::expensive_op( + || hyperdrive::source_list_from_yaml(&mut buf), + "Still reading source list file", + ), + SourceListType::AO => crate::misc::expensive_op( + || ao::parse_source_list(&mut buf), + "Still reading source list file", + ), + SourceListType::Rts => crate::misc::expensive_op( + || rts::parse_source_list(&mut buf), + "Still reading source list file", + ), + 
SourceListType::Woden => crate::misc::expensive_op( + || woden::parse_source_list(&mut buf), + "Still reading source list file", + ), }; match result { Ok(sl) => (sl, input_type), @@ -81,7 +93,11 @@ fn verify>( } } } else { - match read_source_list_file(source_list, None) { + let source_list = source_list.as_ref(); + match crate::misc::expensive_op( + || read_source_list_file(source_list, None), + "Still reading source list file", + ) { Ok(sl) => sl, Err(e) => { info!("{}", e); diff --git a/src/cli/vis_utils/simulate/mod.rs b/src/cli/vis_utils/simulate/mod.rs index 794404d3..4dc6c48a 100644 --- a/src/cli/vis_utils/simulate/mod.rs +++ b/src/cli/vis_utils/simulate/mod.rs @@ -464,7 +464,10 @@ impl VisSimParams { }; // Read the source list. // TODO: Allow the user to specify a source list type. - let source_list = match read_source_list_file(sl_pb, None) { + let source_list: SourceList = match crate::misc::expensive_op( + || read_source_list_file(sl_pb, None), + "Still reading source list file", + ) { Ok((sl, sl_type)) => { debug!("Successfully parsed {}-style source list", sl_type); sl diff --git a/src/cli/vis_utils/subtract/mod.rs b/src/cli/vis_utils/subtract/mod.rs index fc738d3d..1f67dfd9 100644 --- a/src/cli/vis_utils/subtract/mod.rs +++ b/src/cli/vis_utils/subtract/mod.rs @@ -268,7 +268,10 @@ fn vis_subtract(args: VisSubtractArgs, dry_run: bool) -> Result<(), VisSubtractE let sl_type = source_list_type .as_ref() .and_then(|t| SourceListType::from_str(t.as_ref()).ok()); - let (sl, _) = match read_source_list_file(pb, sl_type) { + let (sl, _) = match crate::misc::expensive_op( + || read_source_list_file(pb, sl_type), + "Still reading source list file", + ) { Ok((sl, sl_type)) => (sl, sl_type), Err(e) => return Err(VisSubtractError::from(e)), }; diff --git a/src/di_calibrate/mod.rs b/src/di_calibrate/mod.rs index 5a4cd2fa..5a9a1f89 100644 --- a/src/di_calibrate/mod.rs +++ b/src/di_calibrate/mod.rs @@ -38,6 +38,7 @@ use crate::{ context::Polarisations, 
io::write::{write_vis, VisTimestep}, math::average_epoch, + misc::expensive_op, model::{self, ModellerInfo}, solutions::CalibrationSolutions, }; @@ -115,9 +116,25 @@ pub(crate) fn get_cal_vis( } debug!("Allocating memory for input data visibilities and model visibilities"); - let mut vis_data_tfb: Array3> = fallible_allocator!(Jones::default())?; - let mut vis_model_tfb: Array3> = fallible_allocator!(Jones::default())?; - let mut vis_weights_tfb: Array3 = fallible_allocator!(0.0)?; + let CalVis { + mut vis_data_tfb, + mut vis_model_tfb, + mut vis_weights_tfb, + pols: _, + } = expensive_op( + || -> Result<_, DiCalibrateError> { + let vis_data_tfb: Array3> = fallible_allocator!(Jones::default())?; + let vis_model_tfb: Array3> = fallible_allocator!(Jones::default())?; + let vis_weights_tfb: Array3 = fallible_allocator!(0.0)?; + Ok(CalVis { + vis_data_tfb, + vis_weights_tfb, + vis_model_tfb, + pols: Polarisations::default(), + }) + }, + "Still waiting to allocate visibility memory", + )?; // Sky-modelling communication channel. Used to tell the model writer when // visibilities have been generated and they're ready to be written. diff --git a/src/io/read/ms/mod.rs b/src/io/read/ms/mod.rs index c9c6863d..d4f67ced 100644 --- a/src/io/read/ms/mod.rs +++ b/src/io/read/ms/mod.rs @@ -351,51 +351,57 @@ impl MsReader { // ("unavailable tiles"). Iterate over the baselines (i.e. main table // rows) until we've seen all available antennas. 
let mut autocorrelations_present = false; - let (tile_map, unavailable_tiles): (HashMap, Vec) = { - let antenna1: Vec = main_table.get_col_as_vec("ANTENNA1")?; - let antenna2: Vec = main_table.get_col_as_vec("ANTENNA2")?; - - let mut present_tiles = HashSet::with_capacity(total_num_tiles); - for (&antenna1, &antenna2) in antenna1.iter().zip(antenna2.iter()) { - present_tiles.insert(antenna1); - present_tiles.insert(antenna2); - - if !autocorrelations_present && antenna1 == antenna2 { - autocorrelations_present = true; - } - } + let (tile_map, unavailable_tiles): (HashMap, Vec) = + crate::misc::expensive_op( + || { + let mut main_table = read_table(&ms, None)?; + let antenna1: Vec = main_table.get_col_as_vec("ANTENNA1")?; + let antenna2: Vec = main_table.get_col_as_vec("ANTENNA2")?; + + let mut present_tiles = HashSet::with_capacity(total_num_tiles); + for (&antenna1, &antenna2) in antenna1.iter().zip(antenna2.iter()) { + present_tiles.insert(antenna1); + present_tiles.insert(antenna2); + + if !autocorrelations_present && antenna1 == antenna2 { + autocorrelations_present = true; + } + } - // Ensure there aren't more tiles here than in the names or XYZs - // (names and XYZs are checked to be the same above). - if present_tiles.len() > tile_xyzs.len() { - return Err(MsReadError::MismatchNumMainTableNumXyzs { - main: present_tiles.len(), - xyzs: tile_xyzs.len(), - }); - } + // Ensure there aren't more tiles here than in the names or XYZs + // (names and XYZs are checked to be the same above). + if present_tiles.len() > tile_xyzs.len() { + return Err(MsReadError::MismatchNumMainTableNumXyzs { + main: present_tiles.len(), + xyzs: tile_xyzs.len(), + }); + } - // Ensure all MS antenna indices are positive and none are bigger - // than the number of XYZs. 
- for &i in &present_tiles { - if i < 0 { - return Err(MsReadError::AntennaNumNegative(i)); - } - if i as usize >= tile_xyzs.len() { - return Err(MsReadError::AntennaNumTooBig(i)); - } - } + // Ensure all MS antenna indices are positive and none are bigger + // than the number of XYZs. + for &i in &present_tiles { + if i < 0 { + return Err(MsReadError::AntennaNumNegative(i)); + } + if i as usize >= tile_xyzs.len() { + return Err(MsReadError::AntennaNumTooBig(i)); + } + } - let mut tile_map = HashMap::with_capacity(present_tiles.len()); - let mut unavailable_tiles = Vec::with_capacity(total_num_tiles - present_tiles.len()); - for i_tile in 0..total_num_tiles { - if let Some(v) = present_tiles.get(&(i_tile as i32)) { - tile_map.insert(*v, i_tile); - } else { - unavailable_tiles.push(i_tile); - } - } - (tile_map, unavailable_tiles) - }; + let mut tile_map = HashMap::with_capacity(present_tiles.len()); + let mut unavailable_tiles = + Vec::with_capacity(total_num_tiles - present_tiles.len()); + for i_tile in 0..total_num_tiles { + if let Some(v) = present_tiles.get(&(i_tile as i32)) { + tile_map.insert(*v, i_tile); + } else { + unavailable_tiles.push(i_tile); + } + } + Ok::<_, MsReadError>((tile_map, unavailable_tiles)) + }, + "Still waiting to determine MS antenna metadata", + )?; debug!("Autocorrelations present: {autocorrelations_present}"); debug!("Unavailable tiles: {unavailable_tiles:?}"); @@ -418,47 +424,52 @@ impl MsReader { // all flagged, and (by default) we are not interested in using any of // those data. We work out the first and last good timesteps by // inspecting the flags at each timestep. - let unflagged_timesteps: Vec = { - // The first and last good timestep indices. 
- let mut first: Option = None; - let mut last: Option = None; - - trace!("Searching for unflagged timesteps in the MS"); - for i_step in 0..(main_table.n_rows() as usize) / step { - trace!("Reading timestep {i_step}"); - let mut all_rows_for_step_flagged = true; - for i_row in 0..step { - let vis_flags: Vec = - main_table.get_cell_as_vec("FLAG", (i_step * step + i_row) as u64)?; - let all_flagged = vis_flags.into_iter().all(|f| f); - if !all_flagged { - all_rows_for_step_flagged = false; - if first.is_none() { - first = Some(i_step); - debug!("First good timestep: {i_step}"); + let unflagged_timesteps: Vec = crate::misc::expensive_op( + || { + // The first and last good timestep indices. + let mut first: Option = None; + let mut last: Option = None; + + trace!("Searching for unflagged timesteps in the MS"); + let mut main_table = read_table(&ms, None)?; + for i_step in 0..(main_table.n_rows() as usize) / step { + trace!("Reading timestep {i_step}"); + let mut all_rows_for_step_flagged = true; + for i_row in 0..step { + let vis_flags: Vec = + main_table.get_cell_as_vec("FLAG", (i_step * step + i_row) as u64)?; + let all_flagged = vis_flags.into_iter().all(|f| f); + if !all_flagged { + all_rows_for_step_flagged = false; + if first.is_none() { + first = Some(i_step); + debug!("First good timestep: {i_step}"); + } + break; } + } + if all_rows_for_step_flagged && first.is_some() { + last = Some(i_step); + debug!("Last good timestep: {}", i_step - 1); break; } } - if all_rows_for_step_flagged && first.is_some() { - last = Some(i_step); - debug!("Last good timestep: {}", i_step - 1); - break; - } - } - // Did the indices get set correctly? - match (first, last) { - (Some(f), Some(l)) => f..l, - // If there weren't any flags at the end of the MS, then the - // last timestep is fine. - (Some(f), None) => f..main_table.n_rows() as usize / step, - // All timesteps are flagged. The user can still use the MS, but - // they must specify some amount of flagged timesteps. 
- _ => 0..0, - } - } - .collect(); + // Did the indices get set correctly? + let timesteps = match (first, last) { + (Some(f), Some(l)) => f..l, + // If there weren't any flags at the end of the MS, then the + // last timestep is fine. + (Some(f), None) => f..main_table.n_rows() as usize / step, + // All timesteps are flagged. The user can still use the MS, but + // they must specify some amount of flagged timesteps. + _ => 0..0, + } + .collect(); + Ok::<_, TableError>(timesteps) + }, + "Still waiting to determine MS timesteps", + )?; // Neither Birli nor cotter utilise the "FLAG_ROW" column of the antenna // table. This is the best (only?) way to unambiguously identify flagged diff --git a/src/io/read/raw/helpers.rs b/src/io/read/raw/helpers.rs deleted file mode 100644 index 23335153..00000000 --- a/src/io/read/raw/helpers.rs +++ /dev/null @@ -1,73 +0,0 @@ -// This Source Code Form is subject to the terms of the Mozilla Public -// License, v. 2.0. If a copy of the MPL was not distributed with this -// file, You can obtain one at http://mozilla.org/MPL/2.0/. - -use std::path::PathBuf; -use std::time::Duration; - -use console::Term; -use crossbeam_channel::bounded; -use mwalib::{CorrelatorContext, MwalibError}; - -use crate::misc::is_a_tty; - -/// Wait this many seconds before printing a message that we're still waiting to -/// read gpubox files. -const READ_GPUBOX_WAIT_SECS: u64 = 2; - -/// Return a [CorrelatorContext] given the necessary files. -/// -/// It can take a while to create the correlator context because all gpubox -/// files need to be iterated over (big IO operation). To make the UX a bit -/// nicer, this function prints a message that we're waiting if the operation -/// takes a while. 
-pub(super) fn get_mwalib_correlator_context( - metafits: PathBuf, - gpuboxes: Vec, -) -> Result { - let (tx_context, rx_context) = bounded(1); - std::thread::spawn(move || tx_context.send(CorrelatorContext::new(&metafits, &gpuboxes))); - let mwalib_context = { - // Only print messages if we're in an interactive terminal. - let term = is_a_tty().then(Term::stderr); - - let mut total_wait_time = Duration::from_secs(0); - let inc_wait_time = Duration::from_millis(500); - let mut printed_wait_line = false; - // Loop forever until the context is ready. - loop { - // If the channel is full, then the context is ready. - if rx_context.is_full() { - // Clear the waiting line. - if let Some(term) = term.as_ref() { - if printed_wait_line { - term.move_cursor_up(1).expect("Couldn't move cursor up"); - term.clear_line().expect("Couldn't clear line"); - } - } - break; - } - // Otherwise we must wait longer. - std::thread::sleep(inc_wait_time); - total_wait_time += inc_wait_time; - if let Some(term) = term.as_ref() { - if total_wait_time.as_secs() >= READ_GPUBOX_WAIT_SECS { - if printed_wait_line { - term.move_cursor_up(1).expect("Couldn't move cursor up"); - term.clear_line().expect("Couldn't clear line"); - } - term.write_line(&format!( - "Still waiting to inspect all gpubox metadata: {:.2}s", - total_wait_time.as_secs_f64() - )) - .expect("Couldn't write line"); - printed_wait_line = true; - } - } - } - // Receive the context result. We can safely unwrap because we only - // break the loop when the channel is populated. - rx_context.recv().unwrap() - }; - mwalib_context -} diff --git a/src/io/read/raw/mod.rs b/src/io/read/raw/mod.rs index 40c374a3..ea7074e9 100644 --- a/src/io/read/raw/mod.rs +++ b/src/io/read/raw/mod.rs @@ -5,13 +5,11 @@ //! Code to handle reading from raw MWA files. 
mod error; -mod helpers; pub(crate) mod pfb_gains; #[cfg(test)] mod tests; pub(crate) use error::*; -use helpers::*; use std::{ collections::HashSet, @@ -168,8 +166,10 @@ impl RawDataReader { trace!("Using gpubox files: {:#?}", gpubox_pbs); trace!("Creating mwalib context"); - let mwalib_context = - get_mwalib_correlator_context(meta_pb, gpubox_pbs).map_err(Box::new)?; + let mwalib_context = crate::misc::expensive_op( + || CorrelatorContext::new(meta_pb, &gpubox_pbs).map_err(Box::new), + "Still waiting to inspect all gpubox metadata", + )?; let metafits_context = &mwalib_context.metafits_context; let is_mwax = match mwalib_context.mwa_version { diff --git a/src/io/read/uvfits/mod.rs b/src/io/read/uvfits/mod.rs index 73a3be89..ffd147d3 100644 --- a/src/io/read/uvfits/mod.rs +++ b/src/io/read/uvfits/mod.rs @@ -207,7 +207,14 @@ impl UvfitsReader { let tile_xyzs = Vec1::try_from_vec(tile_xyzs) .expect("can't be empty, non-empty tile names verified above"); - let metadata = UvfitsMetadata::new(&mut uvfits_fptr, &primary_hdu)?; + let metadata = crate::misc::expensive_op( + || { + let mut uvfits_fptr = fits_open(&uvfits)?; + let hdu = fits_open_hdu(&mut uvfits_fptr, 0)?; + UvfitsMetadata::new(&mut uvfits_fptr, &hdu) + }, + "Still waiting to inspect all uvfits metadata", + )?; // Make a nice little string for user display. uvfits always puts YY // before cross pols so we have to use some logic here. let pol_str = match metadata.pols { diff --git a/src/misc.rs b/src/misc.rs index 57948a9e..5ee3fd58 100644 --- a/src/misc.rs +++ b/src/misc.rs @@ -4,8 +4,78 @@ //! Miscellaneous things. 
+use std::thread;
+
+use console::Term;
+use crossbeam_channel::bounded;
 use is_terminal::IsTerminal;
 
 pub(crate) fn is_a_tty() -> bool {
     std::io::stdout().is_terminal() || std::io::stderr().is_terminal()
 }
+
+/// Perform this expensive operation as a normal Rust function, but if it takes
+/// more than a certain amount of time, display a message to the user that
+/// you're still waiting for this operation to complete.
+///
+/// `func` is run on a scoped worker thread while the calling thread waits for
+/// its result. The waiting message (`wait_message` followed by the elapsed
+/// time) is only written when [`is_a_tty`] reports an interactive terminal,
+/// and it is removed again before this function returns.
+///
+/// # Panics
+///
+/// If `func` panics, the panic is propagated to the caller with its original
+/// payload rather than leaving the caller waiting forever.
+pub(crate) fn expensive_op<F, R>(func: F, wait_message: &str) -> R
+where
+    F: FnOnce() -> R + Send,
+    R: Send,
+{
+    use std::time::{Duration, Instant};
+
+    use crossbeam_channel::RecvTimeoutError;
+
+    /// How long the operation may run before the first message is shown.
+    const INITIAL_WAIT_TIME: Duration = Duration::from_secs(1);
+    /// How often the message is refreshed once the operation is "slow".
+    const INC_WAIT_TIME: Duration = Duration::from_millis(250);
+
+    /// Remove the previously-printed waiting line. This is purely cosmetic,
+    /// so a terminal error here must not abort the expensive operation.
+    fn clear_wait_line(term: &Term) {
+        let _ = term.move_cursor_up(1).and_then(|()| term.clear_line());
+    }
+
+    let (tx, rx) = bounded(1);
+
+    thread::scope(|s| {
+        // `move` gives the worker ownership of the sender. If `func` panics,
+        // unwinding drops the sender and the receiver below observes a
+        // disconnect instead of polling an empty channel forever.
+        let worker = s.spawn(move || {
+            tx.send(func()).expect("receiver is not disconnected");
+        });
+
+        // Only print messages if we're in an interactive terminal.
+        let term = is_a_tty().then(Term::stderr);
+
+        let started = Instant::now();
+        let mut printed_wait_line = false;
+        // Loop until the return value is ready.
+        loop {
+            // Blocking with a timeout (rather than polling then sleeping)
+            // means a fast operation returns as soon as it is done.
+            match rx.recv_timeout(INC_WAIT_TIME) {
+                // The value has arrived; clean up before handing it back.
+                Ok(r) => {
+                    if let (Some(term), true) = (term.as_ref(), printed_wait_line) {
+                        clear_wait_line(term);
+                    }
+                    break r;
+                }
+
+                // Still running; refresh the message if we've waited long
+                // enough to bother the user with one.
+                Err(RecvTimeoutError::Timeout) => {
+                    let elapsed = started.elapsed();
+                    if let Some(term) = term.as_ref() {
+                        if elapsed >= INITIAL_WAIT_TIME {
+                            if printed_wait_line {
+                                clear_wait_line(term);
+                            }
+                            // Only remember the line if it was actually
+                            // written, so we never clear someone else's line.
+                            printed_wait_line = term
+                                .write_line(&format!(
+                                    "{wait_message}: {:.2}s",
+                                    elapsed.as_secs_f64()
+                                ))
+                                .is_ok();
+                        }
+                    }
+                }
+
+                // The sender was dropped without sending: `func` panicked.
+                Err(RecvTimeoutError::Disconnected) => {
+                    if let (Some(term), true) = (term.as_ref(), printed_wait_line) {
+                        clear_wait_line(term);
+                    }
+                    match worker.join() {
+                        Err(payload) => std::panic::resume_unwind(payload),
+                        Ok(()) => unreachable!("worker finished without sending a result"),
+                    }
+                }
+            }
+        }
+    })
+}