Skip to content

Commit

Permalink
hyperloglogplus
Browse files Browse the repository at this point in the history
  • Loading branch information
eric committed Jan 24, 2024
1 parent eea61ad commit 58934ff
Show file tree
Hide file tree
Showing 3 changed files with 118 additions and 60 deletions.
2 changes: 2 additions & 0 deletions kr2r/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,8 @@ clap = { version = "4.4.10", features = ["derive"] }
seq_io = "0.3.2"
hyperloglogplus = { version = "*", features = ["const-loop"] }
seahash = "4.1.0"
serde = { version = "1.0", features = ["derive"] }
serde_json = "1.0"

walkdir = "2"
rayon = "1.8"
Expand Down
169 changes: 112 additions & 57 deletions kr2r/src/bin/estimate_capacity.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,20 +5,26 @@ use kr2r::utils::{expand_spaced_seed_mask, find_library_fna_files};
use kr2r::{sea_hash, KBuildHasher};
use seq_io::fasta::{Reader, Record};
use seq_io::parallel::read_parallel;
use serde_json;
use std::collections::HashSet;
use std::path::PathBuf;
use std::sync::atomic::{AtomicUsize, Ordering};
use std::fs::File;
use std::io::{Read, Write};
use std::path::{Path, PathBuf};

#[derive(Parser, Debug)]
#[derive(Parser, Debug, Clone)]
#[clap(
version,
about = "estimate capacity",
long_about = "Estimates the size of the Kraken 2 hash table."
)]
struct Args {
/// 构建数据库的目录
#[arg(long = "db", default_value = "lib")]
database: PathBuf,
/// build database directory or file
#[arg(long, default_value = "lib")]
source: PathBuf,

/// estimate capacity from cache if exists
#[arg(long, default_value = "false")]
cache: bool,

/// Set length of k-mers, k must be positive integer, k=35, k cannot be less than l
#[clap(short, long, value_parser = clap::value_parser!(u64).range(1..), required = true)]
Expand Down Expand Up @@ -56,6 +62,86 @@ fn parse_binary(src: &str) -> Result<u64, std::num::ParseIntError> {
const RANGE_SECTIONS: u64 = 1024;
const RANGE_MASK: u64 = RANGE_SECTIONS - 1;

/// Derives a sibling output path from `input_path` by keeping its parent
/// directory and file stem and replacing the extension with `extension`.
///
/// Example: `build_output_path("lib/library.fna", "hllp.json")` yields
/// `"lib/library.hllp.json"` — `set_extension` happily takes a multi-dot
/// extension, appending it after the stem.
fn build_output_path(input_path: &str, extension: &str) -> String {
    let path = Path::new(input_path);
    // Fall back gracefully for degenerate inputs: no parent -> current dir,
    // no stem (e.g. "" or "..") -> use the whole path as the stem.
    let parent_dir = path.parent().unwrap_or_else(|| Path::new(""));
    let stem = path.file_stem().unwrap_or_else(|| path.as_os_str());

    let mut output_path = parent_dir.join(stem);
    output_path.set_extension(extension);

    // `to_string_lossy` avoids the panic that `to_str().unwrap()` caused on
    // non-UTF-8 paths; lossy replacement is acceptable for a cache filename.
    output_path.to_string_lossy().into_owned()
}

/// Builds (or loads from cache) a HyperLogLog++ sketch of the distinct,
/// subsampled minimizer hashes found in `fna_file`.
///
/// NOTE(review): the name keeps the historical "sequece" typo because callers
/// reference it; renaming would break the external interface.
///
/// When `args.cache` is set and a previously serialized sketch
/// (`<file stem>.hllp.json`, next to the input) exists, it is deserialized
/// and returned directly. Otherwise the FASTA file is scanned in parallel,
/// the fresh sketch is written to that JSON cache file, and returned.
fn process_sequece(
    fna_file: &str,
    args: Args,
) -> HyperLogLogPlus<u64, KBuildHasher> {
    // Expected path of the serialized-sketch cache file.
    let json_path = build_output_path(fna_file, "hllp.json");

    // Only honor an existing cache file when the user asked for it with
    // `--cache`; previously the cache was read unconditionally whenever the
    // file existed, which made the `--cache` flag a no-op.
    if args.cache && Path::new(&json_path).exists() {
        let mut file = File::open(&json_path).unwrap();
        let mut serialized_hllp = String::new();
        file.read_to_string(&mut serialized_hllp).unwrap();
        let hllp: HyperLogLogPlus<u64, KBuildHasher> =
            serde_json::from_str(&serialized_hllp).unwrap();

        return hllp;
    }

    let k_mer = args.k_mer as usize;
    let l_mer = args.l_mer as usize;
    let mut hllp: HyperLogLogPlus<u64, _> =
        HyperLogLogPlus::new(16, KBuildHasher::default()).unwrap();

    let reader = Reader::from_path(fna_file).unwrap();
    read_parallel(
        reader,
        args.threads as u32,
        // Queue length. The original `args.threads - 2 as usize` parses as
        // `args.threads - 2`, which underflows (panics in debug builds) for
        // threads < 2; saturate and keep at least one queue slot.
        args.threads.saturating_sub(2).max(1),
        |record_set| {
            // Worker closure: scan one record set, collecting the subsampled
            // minimizer hashes into a local set to avoid shared-state
            // contention across workers.
            let mut scanner = MinimizerScanner::default(k_mer, l_mer);
            scanner.set_spaced_seed_mask(args.spaced_seed_mask);
            if let Some(toggle_mask) = args.toggle_mask {
                scanner.set_toggle_mask(toggle_mask);
            }
            let mut minimizer_set = HashSet::new();
            for record in record_set.into_iter() {
                let seq = record.seq();
                scanner.set_seq_end(seq);
                while let Some(minimizer) = scanner.next_minimizer(seq) {
                    let hash_v = sea_hash(minimizer);
                    // Subsample: keep only hashes falling in the first `n` of
                    // the RANGE_SECTIONS (1024) hash sections; the caller
                    // scales the final estimate back up by RANGE_SECTIONS / n.
                    if hash_v & RANGE_MASK < args.n as u64 {
                        minimizer_set.insert(hash_v);
                    }
                }
                scanner.reset();
            }
            minimizer_set
        },
        |record_sets| {
            // Single-threaded consumer: fold every worker's hash set into
            // the sketch.
            while let Some(Ok((_, m_set))) = record_sets.next() {
                for minimizer in m_set {
                    hllp.insert(&minimizer);
                }
            }
        },
    );

    // Persist the sketch so later runs with `--cache` can skip the scan.
    let serialized_hllp = serde_json::to_string(&hllp).unwrap();
    let mut file = File::create(&json_path).unwrap();
    file.write_all(serialized_hllp.as_bytes()).unwrap();

    hllp
}

fn main() {
let mut args = Args::parse();
if args.k_mer < args.l_mer as u64 {
Expand All @@ -66,63 +152,32 @@ fn main() {
args.spaced_seed_mask =
expand_spaced_seed_mask(args.spaced_seed_mask, BITS_PER_CHAR as u64);
}
let fna_files = find_library_fna_files(args.database);
let mut hllp: HyperLogLogPlus<u64, _> =

let mut hllp: HyperLogLogPlus<u64, KBuildHasher> =
HyperLogLogPlus::new(16, KBuildHasher::default()).unwrap();

let counter = AtomicUsize::new(0); // 初始化原子计数器
let source: PathBuf = args.source.clone();
let fna_files = if source.is_file() {
vec![source.to_string_lossy().to_string()]
} else {
find_library_fna_files(args.source)
};

// let sets: HashSet<u64> = HashSet::new();
for fna_file in fna_files {
println!("fna_file {:?}", fna_file);
let reader = Reader::from_path(fna_file).unwrap();
let k_mer = args.k_mer as usize;
let l_mer = args.l_mer as usize;

read_parallel(
reader,
args.threads as u32,
args.threads - 2 as usize,
|record_set| {
let mut count = 0;
let mut scanner = MinimizerScanner::default(k_mer, l_mer);
scanner.set_spaced_seed_mask(args.spaced_seed_mask);
if let Some(toggle_mask) = args.toggle_mask {
scanner.set_toggle_mask(toggle_mask);
}
let mut minimizer_set = HashSet::new();
for record in record_set.into_iter() {
let seq = record.seq();
scanner.set_seq_end(seq);
while let Some(minimizer) = scanner.next_minimizer(seq) {
count += 1;

let hash_v = sea_hash(minimizer);
if hash_v & RANGE_MASK < args.n as u64 {
minimizer_set.insert(hash_v);
}
}
scanner.reset();
}
(minimizer_set, count)
},
|record_sets| {
while let Some(Ok((_, (m_set, count)))) = record_sets.next() {
for minimizer in m_set {
hllp.insert(&minimizer);
}
// sets.extend(m_set);

counter.fetch_add(count, Ordering::SeqCst);
}
},
);
let args_clone = Args {
source: source.clone(),
..args
};
let local_hllp = process_sequece(&fna_file, args_clone);
if let Err(e) = hllp.merge(&local_hllp) {
println!("hllp merge err {:?}", e);
}
}

let final_count = counter.load(Ordering::SeqCst); // 读取计数器的最终值
// let final_count = counter.load(Ordering::SeqCst); // 读取计数器的最终值

let hllp_count = hllp.count();
// println!("sets {:?}", sets.len() * 1024 / args.n);
println!("Final count: {:?}", final_count);
println!("HLLP count: {:?}", hllp_count * 1024f64 / args.n as f64);
let hllp_count = (hllp.count() * RANGE_SECTIONS as f64 / args.n as f64).round() as u64;
// println!("Final count: {:?}", final_count);
println!("estimate count: {:?}", hllp_count);
}
7 changes: 4 additions & 3 deletions kr2r/src/kv_store.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
use seahash::SeaHasher;
use serde::{Deserialize, Serialize};
use std::hash::{BuildHasher, Hasher};

// 定义 KeyValueStore trait
Expand Down Expand Up @@ -65,7 +66,7 @@ pub fn sea_hash(key: u64) -> u64 {
/// representing a `u64`. Using this hasher with input that is not 8 bytes,
/// or not properly representing a `u64`, may lead to undefined behavior including
/// but not limited to memory safety violations
#[derive(Default)]
#[derive(Default, Serialize, Deserialize)]
pub struct KHasher {
hash: u64,
}
Expand Down Expand Up @@ -100,7 +101,7 @@ impl Hasher for KHasher {
}
}

#[derive(Default)]
#[derive(Default, Serialize, Deserialize)]
pub struct KBuildHasher;

impl BuildHasher for KBuildHasher {
Expand All @@ -111,7 +112,7 @@ impl BuildHasher for KBuildHasher {
}
}

#[derive(Default)]
#[derive(Default, Serialize, Deserialize)]
pub struct SBuildHasher;

impl BuildHasher for SBuildHasher {
Expand Down

0 comments on commit 58934ff

Please sign in to comment.