Skip to content

Commit

Permalink
Use weights properly. Test uvfits/ms data reading.
Browse files Browse the repository at this point in the history
Previously, weights from visibility data were not read consistently. Now
they are, and flags are used by making visibility weights negative. This
commit adds a measurement set and uvfits file with a flag injected for
testing.

In the process of testing that MS, uvfits and raw data return exactly
the same visibilities, lots of little bugs were exposed. Raw data
timestamps and frequencies were slightly off, raw data corrections were
applied in a different order to Birli which made the visibilities not
match the MS and uvfits (now they match exactly), and the precession
info was generated on a potentially different timestamp. MS tile
positions are slightly less accurate than metafits positions, so use
those if possible.

The weights are now also used in calibration. Add a test to ensure this
is working as expected. When using CUDA, this slows calibration down by
20-30%. For reasons I don't understand, when not using CUDA, calibration
is slowed by ~70%. Maybe the code is being optimised differently when
CUDA is on or off. In any case, it may be worth trying to improve
calibration performance.

Overhaul the inner-most code to read data from MS and uvfits; these are
hopefully more efficient now.

Also alter the calibration quality test so it's easier to adjust in the
future.
  • Loading branch information
cjordan committed Mar 3, 2022
1 parent f2b9686 commit 662aaaa
Show file tree
Hide file tree
Showing 92 changed files with 863 additions and 366 deletions.
6 changes: 6 additions & 0 deletions .gitattributes
Original file line number Diff line number Diff line change
Expand Up @@ -3,3 +3,9 @@
*.fits binary
*.metafits binary
*.mwaf binary
# Measurement sets
*.dat binary
*.f0 binary
*.f0i binary
*.info binary
*.lock binary
2 changes: 2 additions & 0 deletions benches/bench.rs
Original file line number Diff line number Diff line change
Expand Up @@ -532,6 +532,7 @@ fn calibrate_benchmarks(c: &mut Criterion) {

let vis_shape = (num_timesteps, num_baselines, num_chanblocks);
let vis_data: Array3<Jones<f32>> = Array3::from_elem(vis_shape, Jones::identity() * 4.0);
let vis_weights: Array3<f32> = Array3::ones(vis_shape);
let vis_model: Array3<Jones<f32>> = Array3::from_elem(vis_shape, Jones::identity());
let baseline_weights = vec![1.0; num_baselines];

Expand All @@ -541,6 +542,7 @@ fn calibrate_benchmarks(c: &mut Criterion) {
b.iter(|| {
calibrate_timeblocks(
vis_data.view(),
vis_weights.view(),
vis_model.view(),
&timeblocks,
&chanblocks,
Expand Down
145 changes: 80 additions & 65 deletions src/calibrate/di/code/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
//! some things private so that they aren't misused.

#[cfg(test)]
mod tests;
pub(crate) mod tests;

use std::ops::Deref;

Expand Down Expand Up @@ -34,20 +34,20 @@ use crate::{
};
use mwa_hyperdrive_common::{cfg_if, hifitime, indicatif, log, marlu, ndarray, rayon};

pub(super) struct CalVis {
pub(crate) struct CalVis {
/// Visibilites read from input data.
pub(super) vis_data: Array3<Jones<f32>>,
pub(crate) vis_data: Array3<Jones<f32>>,

/// The weights on the visibilites read from input data.
pub(super) vis_weights: Array3<f32>,
pub(crate) vis_weights: Array3<f32>,

/// Visibilites generated from the sky-model source list.
pub(super) vis_model: Array3<Jones<f32>>,
pub(crate) vis_model: Array3<Jones<f32>>,
}

/// For calibration, read in unflagged visibilities and generate sky-model
/// visibilities.
pub(super) fn get_cal_vis(
pub(crate) fn get_cal_vis(
params: &CalibrateParams,
draw_progress_bar: bool,
) -> Result<CalVis, CalibrateError> {
Expand Down Expand Up @@ -580,6 +580,7 @@ impl<'a> IncompleteSolutions<'a> {
#[allow(clippy::too_many_arguments)]
pub fn calibrate_timeblocks<'a>(
vis_data: ArrayView3<Jones<f32>>,
vis_weights: ArrayView3<f32>,
vis_model: ArrayView3<Jones<f32>>,
timeblocks: &'a [Timeblock],
chanblocks: &'a [Chanblock],
Expand All @@ -590,6 +591,21 @@ pub fn calibrate_timeblocks<'a>(
draw_progress_bar: bool,
print_convergence_messages: bool,
) -> (IncompleteSolutions<'a>, Array2<CalibrationResult>) {
// Multiply the baseline weights against the visibility weights. Then, only
// the visibility weights need to be multiplied against the data and model
// visibilities.
assert_eq!(vis_weights.len_of(Axis(1)), baseline_weights.len());
let mut vis_weights = vis_weights.to_owned();
vis_weights
.axis_iter_mut(Axis(1))
.into_par_iter()
.zip(baseline_weights)
.for_each(|(mut vis_weights, &baseline_weight)| {
vis_weights.iter_mut().for_each(|vis_weight| {
*vis_weight = (*vis_weight as f64 * baseline_weight) as f32;
})
});

let num_unflagged_tiles = num_tiles_from_num_cross_correlation_baselines(vis_data.dim().1);
let num_timeblocks = timeblocks.len();
let num_chanblocks = chanblocks.len();
Expand All @@ -605,11 +621,11 @@ pub fn calibrate_timeblocks<'a>(
);
let cal_results = calibrate_timeblock(
vis_data.view(),
vis_weights.view(),
vis_model.view(),
di_jones.view_mut(),
timeblocks.first().unwrap(),
chanblocks,
baseline_weights,
max_iterations,
stop_threshold,
min_threshold,
Expand Down Expand Up @@ -657,11 +673,11 @@ pub fn calibrate_timeblocks<'a>(
);
let cal_results = calibrate_timeblock(
vis_data.view(),
vis_weights.view(),
vis_model.view(),
di_jones.view_mut(),
&timeblock,
chanblocks,
baseline_weights,
max_iterations,
stop_threshold,
min_threshold,
Expand Down Expand Up @@ -693,11 +709,11 @@ pub fn calibrate_timeblocks<'a>(
);
let mut cal_results = calibrate_timeblock(
vis_data.view(),
vis_weights.view(),
vis_model.view(),
di_jones.view_mut(),
timeblock,
chanblocks,
baseline_weights,
max_iterations,
stop_threshold,
min_threshold,
Expand Down Expand Up @@ -761,11 +777,11 @@ fn make_calibration_progress_bar(
#[allow(clippy::too_many_arguments)]
fn calibrate_timeblock(
vis_data: ArrayView3<Jones<f32>>,
vis_weights: ArrayView3<f32>,
vis_model: ArrayView3<Jones<f32>>,
mut di_jones: ArrayViewMut3<Jones<f64>>,
timeblock: &Timeblock,
chanblocks: &[Chanblock],
baseline_weights: &[f64],
max_iterations: usize,
stop_threshold: f64,
min_threshold: f64,
Expand Down Expand Up @@ -793,9 +809,9 @@ fn calibrate_timeblock(
];
let mut cal_result = calibrate(
vis_data.slice(range),
vis_weights.slice(range),
vis_model.slice(range),
di_jones,
baseline_weights,
max_iterations,
stop_threshold,
min_threshold,
Expand Down Expand Up @@ -920,9 +936,9 @@ fn calibrate_timeblock(
let range = s![timeblock.range.clone(), .., i_chanblock..i_chanblock + 1];
let mut new_cal_result = calibrate(
vis_data.slice(range),
vis_weights.slice(range),
vis_model.slice(range),
di_jones,
baseline_weights,
max_iterations,
stop_threshold,
min_threshold,
Expand Down Expand Up @@ -998,9 +1014,9 @@ pub struct CalibrationResult {
/// parallel code is inside this function.
pub(super) fn calibrate(
data: ArrayView3<Jones<f32>>,
weights: ArrayView3<f32>,
model: ArrayView3<Jones<f32>>,
mut di_jones: ArrayViewMut1<Jones<f64>>,
baseline_weights: &[f64],
max_iterations: usize,
stop_threshold: f64,
min_threshold: f64,
Expand Down Expand Up @@ -1028,8 +1044,8 @@ pub(super) fn calibrate(

calibration_loop(
data,
weights,
model,
baseline_weights,
di_jones.view(),
top.view_mut(),
bot.view_mut(),
Expand Down Expand Up @@ -1169,8 +1185,8 @@ pub(super) fn calibrate(
/// "MitchCal".
fn calibration_loop(
data: ArrayView3<Jones<f32>>,
weights: ArrayView3<f32>,
model: ArrayView3<Jones<f32>>,
baseline_weights: &[f64],
di_jones: ArrayView1<Jones<f64>>,
mut top: ArrayViewMut1<Jones<f64>>,
mut bot: ArrayViewMut1<Jones<f64>>,
Expand All @@ -1179,59 +1195,58 @@ fn calibration_loop(

// Time axis.
data.outer_iter()
.zip(weights.outer_iter())
.zip(model.outer_iter())
.for_each(|(data_time, model_time)| {
.for_each(|((data, weights), model)| {
// Unflagged baseline axis.
data_time
.outer_iter()
.zip(model_time.outer_iter())
.zip(baseline_weights.iter())
data.outer_iter()
.zip(weights.outer_iter())
.zip(model.outer_iter())
.enumerate()
.for_each(|(i_baseline, ((data_bl, model_bl), &baseline_weight))| {
// Don't do anything if the baseline weight is 0.
if baseline_weight.abs() > f64::EPSILON {
let (tile1, tile2) =
cross_correlation_baseline_to_tiles(num_tiles, i_baseline);

// Unflagged frequency chan axis.
data_bl
.iter()
.zip(model_bl.iter())
.for_each(|(j_data, j_model)| {
// Copy and promote the data and model Jones
// matrices.
let j_data: Jones<f64> = Jones::from(j_data) * baseline_weight;
let j_model: Jones<f64> = Jones::from(j_model) * baseline_weight;

// Suppress boundary checks for maximum performance!
unsafe {
let j_t1 = di_jones.uget(tile1);
let j_t2 = di_jones.uget(tile2);

let top_t1 = top.uget_mut(tile1);
let bot_t1 = bot.uget_mut(tile1);

// André's calibrate: ( D J M^H ) / ( M J^H J M^H )
// J M^H
let z = *j_t2 * j_model.h();
// D (J M^H)
*top_t1 += j_data * z;
// (J M^H)^H (J M^H)
*bot_t1 += z.h() * z;

let top_t2 = top.uget_mut(tile2);
let bot_t2 = bot.uget_mut(tile2);

// André's calibrate: ( D J M^H ) / ( M J^H J M^H )
// J (M^H)^H
let z = *j_t1 * j_model;
// D^H (J M^H)^H
*top_t2 += j_data.h() * z;
// (J M^H) (J M^H)
*bot_t2 += z.h() * z;
}
});
}
.for_each(|(i_baseline, ((data, weights), model))| {
let (tile1, tile2) = cross_correlation_baseline_to_tiles(num_tiles, i_baseline);

// Unflagged frequency chan axis.
data.iter()
.zip(weights)
.zip(model)
// Don't do anything if the weight is flagged.
.filter(|((_, weight), _)| **weight > 0.0)
.for_each(|((j_data, weight), j_model)| {
// Copy and promote the data and model Jones
// matrices.
let weight = *weight as f64;
let j_data: Jones<f64> = Jones::from(j_data) * weight;
let j_model: Jones<f64> = Jones::from(j_model) * weight;

// Suppress boundary checks for maximum performance!
unsafe {
let j_t1 = di_jones.uget(tile1);
let j_t2 = di_jones.uget(tile2);

let top_t1 = top.uget_mut(tile1);
let bot_t1 = bot.uget_mut(tile1);

// André's calibrate: ( D J M^H ) / ( M J^H J M^H )
// J M^H
let z = *j_t2 * j_model.h();
// D (J M^H)
*top_t1 += j_data * z;
// (J M^H)^H (J M^H)
*bot_t1 += z.h() * z;

let top_t2 = top.uget_mut(tile2);
let bot_t2 = bot.uget_mut(tile2);

// André's calibrate: ( D J M^H ) / ( M J^H J M^H )
// J (M^H)^H
let z = *j_t1 * j_model;
// D^H (J M^H)^H
*top_t2 += j_data.h() * z;
// (J M^H) (J M^H)
*bot_t2 += z.h() * z;
}
});
})
});
}
Loading

0 comments on commit 662aaaa

Please sign in to comment.