Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

MRG: fix performance regression in manysearch by removing unnecessary downsampling #464

Merged
merged 10 commits into from
Oct 8, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 6 additions & 2 deletions src/lib.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
/// Python interface Rust code for sourmash_plugin_branchwater.
use pyo3::prelude::*;
use singlesketch::singlesketch;

#[macro_use]
extern crate simple_error;
Expand All @@ -24,7 +23,7 @@ mod singlesketch;
use camino::Utf8PathBuf as PathBuf;

#[pyfunction]
#[pyo3(signature = (querylist_path, siglist_path, threshold, ksize, scaled, moltype, output_path=None))]
#[pyo3(signature = (querylist_path, siglist_path, threshold, ksize, scaled, moltype, output_path=None, ignore_abundance=false))]
fn do_manysearch(
querylist_path: String,
siglist_path: String,
Expand All @@ -33,14 +32,18 @@ fn do_manysearch(
scaled: usize,
moltype: String,
output_path: Option<String>,
ignore_abundance: Option<bool>,
) -> anyhow::Result<u8> {
let againstfile_path: PathBuf = siglist_path.clone().into();
let selection = build_selection(ksize, scaled, &moltype);
eprintln!("selection scaled: {:?}", selection.scaled());
let allow_failed_sigpaths = true;

let ignore_abundance = ignore_abundance.unwrap_or(false);

// if siglist_path is revindex, run mastiff_manysearch; otherwise run manysearch
if is_revindex_database(&againstfile_path) {
// note: mastiff_manysearch ignores abundance automatically.
match mastiff_manysearch::mastiff_manysearch(
querylist_path,
againstfile_path,
Expand All @@ -63,6 +66,7 @@ fn do_manysearch(
threshold,
output_path,
allow_failed_sigpaths,
ignore_abundance,
) {
Ok(_) => Ok(0),
Err(e) => {
Expand Down
181 changes: 112 additions & 69 deletions src/manysearch.rs
Original file line number Diff line number Diff line change
Expand Up @@ -11,8 +11,10 @@ use std::sync::atomic::AtomicUsize;

use crate::utils::{csvwriter_thread, load_collection, load_sketches, ReportType, SearchResult};
use sourmash::ani_utils::ani_from_containment;
use sourmash::errors::SourmashError;
use sourmash::selection::Selection;
use sourmash::signature::SigsTrait;
use sourmash::sketch::minhash::KmerMinHash;

pub fn manysearch(
query_filepath: String,
Expand All @@ -21,6 +23,7 @@ pub fn manysearch(
threshold: f64,
output: Option<String>,
allow_failed_sigpaths: bool,
ignore_abundance: bool,
) -> Result<()> {
// Load query collection
let query_collection = load_collection(
Expand Down Expand Up @@ -71,76 +74,71 @@ pub fn manysearch(
Ok(against_sig) => {
if let Some(against_mh) = against_sig.minhash() {
for query in query_sketchlist.iter() {
// to do - let user choose?
let calc_abund_stats = against_mh.track_abundance();

let against_mh_ds = against_mh.downsample_scaled(query.minhash.scaled()).unwrap();
let overlap =
query.minhash.count_common(&against_mh_ds, false).unwrap() as f64;

// avoid calculating details unless there is overlap
let overlap = query
.minhash
.count_common(against_mh, true)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

ah, count_common handles downsampling only if needed. Didn't realized downsample_scaled always downsampled, even if not needed!

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

yep!

.expect("incompatible sketches")
as f64;

let query_size = query.minhash.size() as f64;
let containment_query_in_target = overlap / query_size;
// only calculate results if we have shared hashes
if overlap > 0.0 {
let query_size = query.minhash.size() as f64;
let containment_query_in_target = overlap / query_size;
if containment_query_in_target > threshold {
let target_size = against_mh.size() as f64;
let containment_target_in_query = overlap / target_size;

let max_containment =
containment_query_in_target.max(containment_target_in_query);
let jaccard = overlap / (target_size + query_size - overlap);

let qani = ani_from_containment(
containment_query_in_target,
against_mh.ksize() as f64,
);
let mani = ani_from_containment(
containment_target_in_query,
against_mh.ksize() as f64,
);
let query_containment_ani = Some(qani);
let match_containment_ani = Some(mani);
let average_containment_ani = Some((qani + mani) / 2.);
let max_containment_ani = Some(f64::max(qani, mani));

let (total_weighted_hashes, n_weighted_found, average_abund, median_abund, std_abund) = if calc_abund_stats {
match query.minhash.inflated_abundances(&against_mh_ds) {
Ok((abunds, sum_weighted_overlap)) => {
let sum_all_abunds = against_mh_ds.sum_abunds() as usize;
let average_abund = sum_weighted_overlap as f64 / abunds.len() as f64;
let median_abund = median(abunds.iter().cloned()).unwrap();
let std_abund = stddev(abunds.iter().cloned());
(Some(sum_all_abunds), Some(sum_weighted_overlap as usize), Some(average_abund), Some(median_abund), Some(std_abund))
}
Err(e) => {
eprintln!("Error calculating abundances for query: {}, against: {}; Error: {}", query.name, against_sig.name(), e);
continue;
}
}
} else {
(None, None, None, None, None)
};

results.push(SearchResult {
query_name: query.name.clone(),
query_md5: query.md5sum.clone(),
match_name: against_sig.name(),
containment: containment_query_in_target,
intersect_hashes: overlap as usize,
match_md5: Some(against_sig.md5sum()),
jaccard: Some(jaccard),
max_containment: Some(max_containment),
average_abund,
median_abund,
std_abund,
query_containment_ani,
match_containment_ani,
average_containment_ani,
max_containment_ani,
n_weighted_found,
total_weighted_hashes,
});
}
if containment_query_in_target > threshold {
let target_size = against_mh.size() as f64;
let containment_target_in_query = overlap / target_size;

let max_containment =
containment_query_in_target.max(containment_target_in_query);
let jaccard = overlap / (target_size + query_size - overlap);

let qani = ani_from_containment(
containment_query_in_target,
against_mh.ksize() as f64,
);
let mani = ani_from_containment(
containment_target_in_query,
against_mh.ksize() as f64,
);
let query_containment_ani = Some(qani);
let match_containment_ani = Some(mani);
let average_containment_ani = Some((qani + mani) / 2.);
let max_containment_ani = Some(f64::max(qani, mani));

let calc_abund_stats =
against_mh.track_abundance() && !ignore_abundance;
let (
total_weighted_hashes,
n_weighted_found,
average_abund,
median_abund,
std_abund,
) = if calc_abund_stats {
downsample_and_inflate_abundances(&query.minhash, against_mh)
.ok()?
} else {
(None, None, None, None, None)
};

results.push(SearchResult {
query_name: query.name.clone(),
query_md5: query.md5sum.clone(),
match_name: against_sig.name(),
containment: containment_query_in_target,
intersect_hashes: overlap as usize,
match_md5: Some(against_sig.md5sum()),
jaccard: Some(jaccard),
max_containment: Some(max_containment),
average_abund,
median_abund,
std_abund,
query_containment_ani,
match_containment_ani,
average_containment_ani,
max_containment_ani,
n_weighted_found,
total_weighted_hashes,
});
}
}
} else {
Expand Down Expand Up @@ -197,3 +195,48 @@ pub fn manysearch(

Ok(())
}

fn downsample_and_inflate_abundances(
query: &KmerMinHash,
against: &KmerMinHash,
) -> Result<
(
Option<usize>,
Option<usize>,
Option<f64>,
Option<f64>,
Option<f64>,
),
SourmashError,
> {
let query_scaled = query.scaled();
let against_scaled = against.scaled();

let abunds: Vec<u64>;
let sum_weighted: u64;
let sum_all_abunds: usize;

// avoid downsampling if we can
if against_scaled != query_scaled {
let against_ds = against
.downsample_scaled(query.scaled())
.expect("cannot downsample sketch");
(abunds, sum_weighted) = query.inflated_abundances(&against_ds)?;
sum_all_abunds = against_ds.sum_abunds() as usize;
} else {
(abunds, sum_weighted) = query.inflated_abundances(against)?;
sum_all_abunds = against.sum_abunds() as usize;
}

let average_abund = sum_weighted as f64 / abunds.len() as f64;
let median_abund = median(abunds.iter().cloned()).expect("error");
let std_abund = stddev(abunds.iter().cloned());

Ok((
Some(sum_all_abunds),
Some(sum_weighted as usize),
Some(average_abund),
Some(median_abund),
Some(std_abund),
))
}
5 changes: 4 additions & 1 deletion src/python/sourmash_plugin_branchwater/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -65,6 +65,8 @@ def __init__(self, p):
p.add_argument('-N', '--no-pretty-print', action='store_false',
dest='pretty_print',
help="do not display results (e.g. for large output)")
p.add_argument('--ignore-abundance', action='store_true',
help="do not do expensive abundance calculations")

def main(self, args):
print_version()
Expand All @@ -80,7 +82,8 @@ def main(self, args):
args.ksize,
args.scaled,
args.moltype,
args.output)
args.output,
args.ignore_abundance)
if status == 0:
notify(f"...manysearch is done! results in '{args.output}'")

Expand Down
Loading