diff --git a/searcher/src/main.rs b/searcher/src/main.rs index f6b8fc0..995bdbb 100644 --- a/searcher/src/main.rs +++ b/searcher/src/main.rs @@ -38,6 +38,58 @@ struct Cli { output: Option, } +fn check_compatible_downsample( + me: &KmerMinHash, + other: &KmerMinHash, +) -> Result<(), sourmash::Error> { + /* + if self.num != other.num { + return Err(Error::MismatchNum { + n1: self.num, + n2: other.num, + } + .into()); + } + */ + use sourmash::Error; + + if me.ksize() != other.ksize() { + return Err(Error::MismatchKSizes); + } + if me.hash_function() != other.hash_function() { + // TODO: fix this error + return Err(Error::MismatchDNAProt); + } + if me.max_hash() < other.max_hash() { + return Err(Error::MismatchScaled); + } + if me.seed() != other.seed() { + return Err(Error::MismatchSeed); + } + Ok(()) +} + +fn prepare_query(search_sig: &Signature, template: &Sketch) -> Option { + let mut search_mh = None; + if let Some(Sketch::MinHash(mh)) = search_sig.select_sketch(template) { + search_mh = Some(mh.clone()); + } else { + // try to find one that can be downsampled + if let Sketch::MinHash(template_mh) = template { + for sketch in search_sig.sketches() { + if let Sketch::MinHash(ref_mh) = sketch { + if check_compatible_downsample(&ref_mh, template_mh).is_ok() { + let max_hash = max_hash_for_scaled(template_mh.scaled()); + let mh = ref_mh.downsample_max_hash(max_hash).unwrap(); + return Some(mh); + } + } + } + } + } + search_mh +} + fn search>( querylist: P, siglist: P, @@ -70,7 +122,7 @@ fn search>( let mut query = None; for sig in &query_sig { - if let Some(Sketch::MinHash(mh)) = sig.select_sketch(&template) { + if let Some(mh) = prepare_query(sig, &template) { query = Some((sig.name(), mh.clone())); } } @@ -125,7 +177,8 @@ fn search>( let mut search_mh = None; let search_sig = &Signature::from_path(&filename) .unwrap_or_else(|_| panic!("Error processing {:?}", filename))[0]; - if let Some(Sketch::MinHash(mh)) = search_sig.select_sketch(&template) { + + if let Some(mh) = prepare_query(search_sig, &template) { search_mh = Some(mh); } let search_mh = search_mh.unwrap(); @@ -135,7 +188,7 @@ fn search>( for (name, query) in &queries { let containment = - query.count_common(search_mh, false).unwrap() as f64 / query.size() as f64; + query.count_common(&search_mh, false).unwrap() as f64 / query.size() as f64; if containment > threshold { results.push((name.clone(), match_fn.clone(), containment)) } diff --git a/searcher/tests/searcher_cmd.rs b/searcher/tests/searcher_cmd.rs index bfb54ea..fa2ea2c 100644 --- a/searcher/tests/searcher_cmd.rs +++ b/searcher/tests/searcher_cmd.rs @@ -34,6 +34,33 @@ fn search() -> Result<(), Box> { Ok(()) } +#[test] +fn search_downsample() -> Result<(), Box> { + let mut cmd = Command::cargo_bin("searcher")?; + + let mut queries = NamedTempFile::new()?; + writeln!(queries, "tests/data/genome-s10.fa.gz.sig")?; + + let mut catalog = NamedTempFile::new()?; + writeln!(catalog, "tests/data/genome-s10.fa.gz.sig")?; + writeln!(catalog, "tests/data/genome-s11.fa.gz.sig")?; + writeln!(catalog, "tests/data/genome-s12.fa.gz.sig")?; + + cmd.args(&["--threshold", "0"]) + .args(&["-k", "31"]) + .args(&["--scaled", "20000"]) + .arg(queries.path()) + .arg(catalog.path()) + .assert() + .success() + .stdout(contains("query,Run,containment")) + .stdout(contains( + "../genome-s10.fa.gz','tests/data/genome-s10.fa.gz.sig',1", + )); + + Ok(()) +} + #[test] fn search_empty_query() -> Result<(), Box> { let mut cmd = Command::cargo_bin("searcher")?;