From 88e406f467e9dba1fae4a077a87865f92d879352 Mon Sep 17 00:00:00 2001 From: "C. Titus Brown" Date: Tue, 8 Oct 2024 11:30:09 -0700 Subject: [PATCH] MRG: fix performance regression in `manysearch` by removing unnecessary downsampling (#464) * add support for ignoring abundance * cargo fmt * avoid downsampling until we know there is overlap * change downsample to true; add panic assertion * move downsampling side guard * eliminate redundant overlap check * move calc_abund_stats * extract abundance code into own function; avoid downsampling if poss * cleanup * fmt --- src/lib.rs | 8 +- src/manysearch.rs | 181 +++++++++++------- .../sourmash_plugin_branchwater/__init__.py | 5 +- 3 files changed, 122 insertions(+), 72 deletions(-) diff --git a/src/lib.rs b/src/lib.rs index 40789191..0f653337 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -1,6 +1,5 @@ /// Python interface Rust code for sourmash_plugin_branchwater. use pyo3::prelude::*; -use singlesketch::singlesketch; #[macro_use] extern crate simple_error; @@ -24,7 +23,7 @@ mod singlesketch; use camino::Utf8PathBuf as PathBuf; #[pyfunction] -#[pyo3(signature = (querylist_path, siglist_path, threshold, ksize, scaled, moltype, output_path=None))] +#[pyo3(signature = (querylist_path, siglist_path, threshold, ksize, scaled, moltype, output_path=None, ignore_abundance=false))] fn do_manysearch( querylist_path: String, siglist_path: String, @@ -33,14 +32,18 @@ fn do_manysearch( scaled: usize, moltype: String, output_path: Option, + ignore_abundance: Option, ) -> anyhow::Result { let againstfile_path: PathBuf = siglist_path.clone().into(); let selection = build_selection(ksize, scaled, &moltype); eprintln!("selection scaled: {:?}", selection.scaled()); let allow_failed_sigpaths = true; + let ignore_abundance = ignore_abundance.unwrap_or(false); + // if siglist_path is revindex, run mastiff_manysearch; otherwise run manysearch if is_revindex_database(&againstfile_path) { + // note: mastiff_manysearch ignores abundance automatically. match mastiff_manysearch::mastiff_manysearch( querylist_path, againstfile_path, @@ -63,6 +66,7 @@ fn do_manysearch( threshold, output_path, allow_failed_sigpaths, + ignore_abundance, ) { Ok(_) => Ok(0), Err(e) => { diff --git a/src/manysearch.rs b/src/manysearch.rs index a200b52d..d343493d 100644 --- a/src/manysearch.rs +++ b/src/manysearch.rs @@ -11,8 +11,10 @@ use std::sync::atomic::AtomicUsize; use crate::utils::{csvwriter_thread, load_collection, load_sketches, ReportType, SearchResult}; use sourmash::ani_utils::ani_from_containment; +use sourmash::errors::SourmashError; use sourmash::selection::Selection; use sourmash::signature::SigsTrait; +use sourmash::sketch::minhash::KmerMinHash; pub fn manysearch( query_filepath: String, @@ -21,6 +23,7 @@ pub fn manysearch( threshold: f64, output: Option, allow_failed_sigpaths: bool, + ignore_abundance: bool, ) -> Result<()> { // Load query collection let query_collection = load_collection( @@ -71,76 +74,71 @@ pub fn manysearch( Ok(against_sig) => { if let Some(against_mh) = against_sig.minhash() { for query in query_sketchlist.iter() { - // to do - let user choose? - let calc_abund_stats = against_mh.track_abundance(); - - let against_mh_ds = against_mh.downsample_scaled(query.minhash.scaled()).unwrap(); - let overlap = - query.minhash.count_common(&against_mh_ds, false).unwrap() as f64; - + // avoid calculating details unless there is overlap + let overlap = query + .minhash + .count_common(against_mh, true) + .expect("incompatible sketches") + as f64; + + let query_size = query.minhash.size() as f64; + let containment_query_in_target = overlap / query_size; // only calculate results if we have shared hashes - if overlap > 0.0 { - let query_size = query.minhash.size() as f64; - let containment_query_in_target = overlap / query_size; - if containment_query_in_target > threshold { - let target_size = against_mh.size() as f64; - let containment_target_in_query = overlap / target_size; - - let max_containment = - containment_query_in_target.max(containment_target_in_query); - let jaccard = overlap / (target_size + query_size - overlap); - - let qani = ani_from_containment( - containment_query_in_target, - against_mh.ksize() as f64, - ); - let mani = ani_from_containment( - containment_target_in_query, - against_mh.ksize() as f64, - ); - let query_containment_ani = Some(qani); - let match_containment_ani = Some(mani); - let average_containment_ani = Some((qani + mani) / 2.); - let max_containment_ani = Some(f64::max(qani, mani)); - - let (total_weighted_hashes, n_weighted_found, average_abund, median_abund, std_abund) = if calc_abund_stats { - match query.minhash.inflated_abundances(&against_mh_ds) { - Ok((abunds, sum_weighted_overlap)) => { - let sum_all_abunds = against_mh_ds.sum_abunds() as usize; - let average_abund = sum_weighted_overlap as f64 / abunds.len() as f64; - let median_abund = median(abunds.iter().cloned()).unwrap(); - let std_abund = stddev(abunds.iter().cloned()); - (Some(sum_all_abunds), Some(sum_weighted_overlap as usize), Some(average_abund), Some(median_abund), Some(std_abund)) - } - Err(e) => { - eprintln!("Error calculating abundances for query: {}, against: {}; Error: {}", query.name, against_sig.name(), e); - continue; - } - } - } else { - (None, None, None, None, None) - }; - - results.push(SearchResult { - query_name: query.name.clone(), - query_md5: query.md5sum.clone(), - match_name: against_sig.name(), - containment: containment_query_in_target, - intersect_hashes: overlap as usize, - match_md5: Some(against_sig.md5sum()), - jaccard: Some(jaccard), - max_containment: Some(max_containment), - average_abund, - median_abund, - std_abund, - query_containment_ani, - match_containment_ani, - average_containment_ani, - max_containment_ani, - n_weighted_found, - total_weighted_hashes, - }); - } + if containment_query_in_target > threshold { + let target_size = against_mh.size() as f64; + let containment_target_in_query = overlap / target_size; + + let max_containment = + containment_query_in_target.max(containment_target_in_query); + let jaccard = overlap / (target_size + query_size - overlap); + + let qani = ani_from_containment( + containment_query_in_target, + against_mh.ksize() as f64, + ); + let mani = ani_from_containment( + containment_target_in_query, + against_mh.ksize() as f64, + ); + let query_containment_ani = Some(qani); + let match_containment_ani = Some(mani); + let average_containment_ani = Some((qani + mani) / 2.); + let max_containment_ani = Some(f64::max(qani, mani)); + + let calc_abund_stats = + against_mh.track_abundance() && !ignore_abundance; + let ( + total_weighted_hashes, + n_weighted_found, + average_abund, + median_abund, + std_abund, + ) = if calc_abund_stats { + downsample_and_inflate_abundances(&query.minhash, against_mh) + .ok()? + } else { + (None, None, None, None, None) + }; + + results.push(SearchResult { + query_name: query.name.clone(), + query_md5: query.md5sum.clone(), + match_name: against_sig.name(), + containment: containment_query_in_target, + intersect_hashes: overlap as usize, + match_md5: Some(against_sig.md5sum()), + jaccard: Some(jaccard), + max_containment: Some(max_containment), + average_abund, + median_abund, + std_abund, + query_containment_ani, + match_containment_ani, + average_containment_ani, + max_containment_ani, + n_weighted_found, + total_weighted_hashes, + }); } } } else { @@ -197,3 +195,48 @@ pub fn manysearch( Ok(()) } + +fn downsample_and_inflate_abundances( + query: &KmerMinHash, + against: &KmerMinHash, +) -> Result< + ( + Option, + Option, + Option, + Option, + Option, + ), + SourmashError, +> { + let query_scaled = query.scaled(); + let against_scaled = against.scaled(); + + let abunds: Vec; + let sum_weighted: u64; + let sum_all_abunds: usize; + + // avoid downsampling if we can + if against_scaled != query_scaled { + let against_ds = against + .downsample_scaled(query.scaled()) + .expect("cannot downsample sketch"); + (abunds, sum_weighted) = query.inflated_abundances(&against_ds)?; + sum_all_abunds = against_ds.sum_abunds() as usize; + } else { + (abunds, sum_weighted) = query.inflated_abundances(against)?; + sum_all_abunds = against.sum_abunds() as usize; + } + + let average_abund = sum_weighted as f64 / abunds.len() as f64; + let median_abund = median(abunds.iter().cloned()).expect("error"); + let std_abund = stddev(abunds.iter().cloned()); + + Ok(( + Some(sum_all_abunds), + Some(sum_weighted as usize), + Some(average_abund), + Some(median_abund), + Some(std_abund), + )) +} diff --git a/src/python/sourmash_plugin_branchwater/__init__.py b/src/python/sourmash_plugin_branchwater/__init__.py index 4280a257..2efc0bc6 100755 --- a/src/python/sourmash_plugin_branchwater/__init__.py +++ b/src/python/sourmash_plugin_branchwater/__init__.py @@ -65,6 +65,8 @@ def __init__(self, p): p.add_argument('-N', '--no-pretty-print', action='store_false', dest='pretty_print', help="do not display results (e.g. for large output)") + p.add_argument('--ignore-abundance', action='store_true', + help="do not do expensive abundance calculations") def main(self, args): print_version() @@ -80,7 +82,8 @@ def main(self, args): args.ksize, args.scaled, args.moltype, - args.output) + args.output, + args.ignore_abundance) if status == 0: notify(f"...manysearch is done! results in '{args.output}'")