From 609a086d7e7077cfc3cb47940cb713df1683104b Mon Sep 17 00:00:00 2001
From: dagou
Date: Mon, 12 Aug 2024 14:52:52 +0800
Subject: [PATCH] batch size

---
 kr2r/src/args.rs          |  10 ++-
 kr2r/src/bin/annotate.rs  | 101 ++++++++++++++++++------
 kr2r/src/bin/kun.rs       |   2 +-
 kr2r/src/bin/merge_fna.rs |   2 +-
 kr2r/src/bin/resolve.rs   | 159 +++++++++++++++++++-------------------
 kr2r/src/bin/splitr.rs    |   5 --
 kr2r/src/utils.rs         |  48 ++++++++++++
 7 files changed, 212 insertions(+), 115 deletions(-)

diff --git a/kr2r/src/args.rs b/kr2r/src/args.rs
index 1be9129..f11bf9c 100644
--- a/kr2r/src/args.rs
+++ b/kr2r/src/args.rs
@@ -35,7 +35,7 @@ pub struct Build {
     pub threads: usize,
 }
 
-const BATCH_SIZE: usize = 16 * 1024 * 1024;
+const BUFFER_SIZE: usize = 16 * 1024 * 1024;
 
 /// Command line arguments for the classify program.
 ///
@@ -84,8 +84,12 @@ pub struct ClassifyArgs {
     #[clap(short = 'p', long = "num-threads", value_parser, default_value_t = num_cpus::get())]
     pub num_threads: usize,
 
-    #[clap(long, default_value_t = BATCH_SIZE)]
-    pub batch_size: usize,
+    #[clap(long, default_value_t = BUFFER_SIZE)]
+    pub buffer_size: usize,
+
+    /// The size of each batch for processing taxid match results, used to control memory usage
+    #[clap(long, default_value_t = 16)]
+    pub batch_size: u32,
 
     /// Confidence score threshold
     #[clap(
diff --git a/kr2r/src/bin/annotate.rs b/kr2r/src/bin/annotate.rs
index e7d4439..5367d6c 100644
--- a/kr2r/src/bin/annotate.rs
+++ b/kr2r/src/bin/annotate.rs
@@ -9,7 +9,7 @@ use std::path::Path;
 use std::path::PathBuf;
 use std::time::Instant;
 // Defines the number of slots processed per batch
-pub const BATCH_SIZE: usize = 8 * 1024 * 1024;
+pub const BUFFER_SIZE: usize = 8 * 1024 * 1024;
 
 /// Command line arguments for the splitr program.
 ///
@@ -30,8 +30,12 @@ pub struct Args {
     #[clap(long)]
     pub chunk_dir: PathBuf,
 
-    #[clap(long, default_value_t = BATCH_SIZE)]
-    pub batch_size: usize,
+    #[clap(long, default_value_t = BUFFER_SIZE)]
+    pub buffer_size: usize,
+
+    /// The size of each batch for processing taxid match results, used to control memory usage
+    #[clap(long, default_value_t = 16)]
+    pub batch_size: u32,
 
     /// The number of threads to use.
     #[clap(short = 'p', long = "num-threads", value_parser, default_value_t = num_cpus::get())]
@@ -57,7 +61,7 @@ fn read_chunk_header<R: Read>(reader: &mut R) -> io::Result<(usize, usize)> {
     Ok((index as usize, chunk_size as usize))
 }
 
-fn write_to_file(
+fn _write_to_file(
     file_index: u64,
     bytes: &[u8],
     last_file_index: &mut Option<u64>,
@@ -87,12 +91,56 @@ fn write_to_file(
     Ok(())
 }
 
+fn write_to_file(
+    file_index: u64,
+    seq_id_mod: u32,
+    bytes: &[u8],
+    writers: &mut HashMap<(u64, u32), BufWriter<File>>,
+    chunk_dir: &PathBuf,
+) -> io::Result<()> {
+    // Check whether a writer already exists for this file; if not, create one
+    let writer = writers.entry((file_index, seq_id_mod)).or_insert_with(|| {
+        let file_name = format!("sample_file_{}_{}.bin", file_index, seq_id_mod);
+        let file_path = chunk_dir.join(file_name);
+        let file = OpenOptions::new()
+            .create(true)
+            .append(true)
+            .open(&file_path)
+            .expect("failed to open file");
+        BufWriter::new(file)
+    });
+
+    writer.write_all(bytes)?;
+
+    Ok(())
+}
+
+fn clean_up_writers(
+    writers: &mut HashMap<(u64, u32), BufWriter<File>>,
+    current_file_index: u64,
+) -> io::Result<()> {
+    let keys_to_remove: Vec<(u64, u32)> = writers
+        .keys()
+        .cloned()
+        .filter(|(idx, _)| *idx != current_file_index)
+        .collect();
+
+    for key in keys_to_remove {
+        if let Some(mut writer) = writers.remove(&key) {
+            writer.flush()?; // flush and clean up
+        }
+    }
+
+    Ok(())
+}
+
 fn process_batch<R>(
     reader: &mut R,
     hash_config: &HashConfig,
     chtm: &CHTable,
     chunk_dir: PathBuf,
-    batch_size: usize,
+    buffer_size: usize,
+    bin_threads: u32,
     page_index: usize,
     num_threads: usize,
 ) -> std::io::Result<()>
@@ -100,8 +148,8 @@ where
     R: Read + Send,
 {
     let row_size = std::mem::size_of::<Row>();
-    let mut last_file_index: Option<u64> = None;
-    let mut writer: Option<BufWriter<File>> = None;
+    let mut writers: HashMap<(u64, u32), BufWriter<File>> = HashMap::new();
+    let mut current_file_index: Option<u64> = None;
 
     let value_mask = hash_config.get_value_mask();
     let value_bits = hash_config.get_value_bits();
@@ -111,9 +159,9 @@ where
     buffer_read_parallel(
         reader,
         num_threads,
-        batch_size,
+        buffer_size,
         |dataset: Vec<Slot<u64>>| {
-            let mut results: HashMap<u64, Vec<u8>> = HashMap::new();
+            let mut results: HashMap<(u64, u32), Vec<u8>> = HashMap::new();
             for slot in dataset {
                 let indx = slot.idx & idx_mask;
                 let compacted = slot.value.left(value_bits) as u32;
@@ -127,9 +175,10 @@ where
                 let high = u32::combined(left, taxid, value_bits);
                 let row = Row::new(high, seq_id, kmer_id as u32);
                 let value_bytes = row.as_slice(row_size);
+                let seq_id_mod = seq_id % bin_threads;
 
                 results
-                    .entry(file_index)
+                    .entry((file_index, seq_id_mod))
                     .or_insert_with(Vec::new)
                     .extend(value_bytes);
             }
@@ -138,19 +187,19 @@ where
         },
         |result| {
             while let Some(Some(res)) = result.next() {
-                let mut file_indices: Vec<_> = res.keys().cloned().collect();
-                file_indices.sort_unstable(); // sort by file_index
-
-                for file_index in file_indices {
-                    if let Some(bytes) = res.get(&file_index) {
-                        write_to_file(
-                            file_index,
-                            bytes,
-                            &mut last_file_index,
-                            &mut writer,
-                            &chunk_dir,
-                        )
-                        .expect("write to file error");
+                let mut file_keys: Vec<_> = res.keys().cloned().collect();
+                file_keys.sort_unstable(); // sort by (file_index, seq_id_mod)
+
+                for (file_index, seq_id_mod) in file_keys {
+                    if let Some(bytes) = res.get(&(file_index, seq_id_mod)) {
+                        // If the file_index being processed has changed, clean up the non-current writers
+                        if current_file_index != Some(file_index) {
+                            clean_up_writers(&mut writers, file_index).expect("clean writer");
+                            current_file_index = Some(file_index);
+                        }
+
+                        write_to_file(file_index, seq_id_mod, bytes, &mut writers, &chunk_dir)
+                            .expect("write to file error");
                     }
                 }
             }
@@ -158,8 +207,9 @@ where
     )
     .expect("failed");
 
-    if let Some(w) = writer.as_mut() {
-        w.flush()?;
+    // After the final batch has been processed, flush all writers
+    for writer in writers.values_mut() {
+        writer.flush()?;
     }
 
     Ok(())
@@ -190,6 +240,7 @@ fn process_chunk_file<P: AsRef<Path>>(
         &config,
         &chtm,
         args.chunk_dir.clone(),
+        args.buffer_size,
        args.batch_size,
        page_index,
        args.num_threads,
diff --git a/kr2r/src/bin/kun.rs b/kr2r/src/bin/kun.rs
index e21f2d4..d54f0f5 100644
--- a/kr2r/src/bin/kun.rs
+++ b/kr2r/src/bin/kun.rs
@@ -81,6 +81,7 @@ impl From<ClassifyArgs> for annotate::Args {
             database: item.database,
             chunk_dir: item.chunk_dir,
             batch_size: item.batch_size,
+            buffer_size: item.buffer_size,
             num_threads: item.num_threads,
         }
     }
@@ -91,7 +92,6 @@ impl From<ClassifyArgs> for resolve::Args {
         Self {
             database: item.database,
             chunk_dir: item.chunk_dir,
-            batch_size: item.batch_size,
             confidence_threshold: item.confidence_threshold,
             minimum_hit_groups: item.minimum_hit_groups,
             kraken_output_dir: item.kraken_output_dir,
diff --git a/kr2r/src/bin/merge_fna.rs b/kr2r/src/bin/merge_fna.rs
index 710c868..f65ffd4 100644
--- a/kr2r/src/bin/merge_fna.rs
+++ b/kr2r/src/bin/merge_fna.rs
@@ -16,7 +16,7 @@ use std::time::Instant;
 #[clap(version, about = "A tool for processing genomic files")]
 pub struct Args {
     /// Directory to store downloaded files
-    #[arg(short, long, default_value = "lib")]
+    #[arg(short, long, required = true)]
     pub download_dir: PathBuf,
 
     /// ncbi library fna database directory
diff --git a/kr2r/src/bin/resolve.rs b/kr2r/src/bin/resolve.rs
index 2daac9d..775b9a6 100644
--- a/kr2r/src/bin/resolve.rs
+++ b/kr2r/src/bin/resolve.rs
@@ -4,7 +4,7 @@ use kr2r::compact_hash::{HashConfig, Row};
 use kr2r::readcounts::{TaxonCounters, TaxonCountersDash};
 use kr2r::report::report_kraken_style;
 use kr2r::taxonomy::Taxonomy;
-use kr2r::utils::{find_and_trans_files, open_file};
+use kr2r::utils::{find_and_trans_bin_files, find_and_trans_files, open_file};
 use kr2r::HitGroup;
 // use rayon::prelude::*;
 use seqkmer::{buffer_map_parallel, trim_pair_info, OptionPair};
@@ -15,8 +15,6 @@ use std::path::{Path, PathBuf};
 use std::sync::atomic::{AtomicUsize, Ordering};
 use std::time::Instant;
 
-const BATCH_SIZE: usize = 16 * 1024 * 1024;
-
 pub fn read_id_to_seq_map<P: AsRef<Path>>(
     filename: P,
 ) -> Result<HashMap<u32, (String, String, usize, Option<usize>)>> {
@@ -96,9 +94,6 @@ pub struct Args {
     )]
     pub minimum_hit_groups: usize,
 
-    #[clap(long, default_value_t = BATCH_SIZE)]
-    pub batch_size: usize,
-
     /// The number of threads to use.
     #[clap(short = 'p', long = "num-threads", value_parser, default_value_t = num_cpus::get())]
     pub num_threads: usize,
@@ -119,7 +114,7 @@ fn read_rows_from_file<P: AsRef<Path>>(file_path: P) -> io::Result<HashMap<u32,
 }
 
 fn process_batch<P: AsRef<Path>>(
-    sample_file: P,
+    sample_files: &Vec<P>,
     args: &Args,
     taxonomy: &Taxonomy,
     id_map: &HashMap<u32, (String, String, usize, Option<usize>)>,
@@ -130,57 +125,59 @@ fn process_batch<P: AsRef<Path>>(
     let confidence_threshold = args.confidence_threshold;
     let minimum_hit_groups = args.minimum_hit_groups;
 
-    let hit_counts: HashMap<u32, Vec<Row>> = read_rows_from_file(sample_file)?;
-
     let classify_counter = AtomicUsize::new(0);
     let cur_taxon_counts = TaxonCountersDash::new();
 
-    buffer_map_parallel(
-        &hit_counts,
-        num_cpus::get(),
-        |(k, rows)| {
-            if let Some(item) = id_map.get(&k) {
-                let mut rows = rows.to_owned();
-                rows.sort_unstable();
-                let dna_id = trim_pair_info(&item.0);
-                let range =
-                    OptionPair::from(((0, item.2), item.3.map(|size| (item.2, size + item.2))));
-                let hits = HitGroup::new(rows, range);
-
-                let hit_data = process_hitgroup(
-                    &hits,
-                    taxonomy,
-                    &classify_counter,
-                    hits.required_score(confidence_threshold),
-                    minimum_hit_groups,
-                    value_mask,
-                );
-
-                hit_data.3.iter().for_each(|(key, value)| {
-                    cur_taxon_counts
-                        .entry(*key)
-                        .or_default()
-                        .merge(value)
-                        .unwrap();
-                });
+    for sample_file in sample_files {
+        let hit_counts: HashMap<u32, Vec<Row>> = read_rows_from_file(sample_file)?;
+
+        buffer_map_parallel(
+            &hit_counts,
+            num_cpus::get(),
+            |(k, rows)| {
+                if let Some(item) = id_map.get(&k) {
+                    let mut rows = rows.to_owned();
+                    rows.sort_unstable();
+                    let dna_id = trim_pair_info(&item.0);
+                    let range =
+                        OptionPair::from(((0, item.2), item.3.map(|size| (item.2, size + item.2))));
+                    let hits = HitGroup::new(rows, range);
+
+                    let hit_data = process_hitgroup(
+                        &hits,
+                        taxonomy,
+                        &classify_counter,
+                        hits.required_score(confidence_threshold),
+                        minimum_hit_groups,
+                        value_mask,
+                    );
 
-                // Use a lock to synchronize writes
-                let output_line = format!(
-                    "{}\t{}\t{}\t{}\t{}\n",
-                    hit_data.0, dna_id, hit_data.1, item.1, hit_data.2
-                );
-                Some(output_line)
-            } else {
-                None
-            }
-        },
-        |result| {
-            while let Some(Some(res)) = result.next() {
-                writer.write_all(res.as_bytes()).unwrap();
-            }
-        },
-    )
-    .expect("failed");
+                    hit_data.3.iter().for_each(|(key, value)| {
+                        cur_taxon_counts
+                            .entry(*key)
+                            .or_default()
+                            .merge(value)
+                            .unwrap();
+                    });
+
+                    // Use a lock to synchronize writes
+                    let output_line = format!(
+                        "{}\t{}\t{}\t{}\t{}\n",
+                        hit_data.0, dna_id, hit_data.1, item.1, hit_data.2
+                    );
+                    Some(output_line)
+                } else {
+                    None
+                }
+            },
+            |result| {
+                while let Some(Some(res)) = result.next() {
+                    writer.write_all(res.as_bytes()).unwrap();
+                }
+            },
+        )
+        .expect("failed");
+    }
 
     Ok((
         cur_taxon_counts,
@@ -194,7 +191,7 @@ pub fn run(args: Args) -> Result<()> {
     let taxonomy_filename = k2d_dir.join("taxo.k2d");
     let taxo = Taxonomy::from_file(taxonomy_filename)?;
 
-    let sample_files = find_and_trans_files(&args.chunk_dir, "sample_file", ".bin", false)?;
+    let sample_files = find_and_trans_bin_files(&args.chunk_dir, "sample_file", r".bin", false)?;
     let sample_id_files = find_and_trans_files(&args.chunk_dir, "sample_id", ".map", false)?;
 
     // let partition = sample_files.len();
@@ -209,9 +206,7 @@ pub fn run(args: Args) -> Result<()> {
     let start = Instant::now();
     println!("resolve start...");
 
-    for (i, sample_file) in &sample_files {
-        // for i in 0..partition {
-        //     let sample_file = &sample_files[i];
+    for (i, sam_files) in &sample_files {
         let sample_id_map = read_id_to_seq_map(&sample_id_files[i])?;
 
         let thread_sequences = sample_id_map.len();
@@ -223,8 +218,8 @@ pub fn run(args: Args) -> Result<()> {
             }
             None => Box::new(BufWriter::new(io::stdout())) as Box<dyn Write + Send>,
         };
-        let (thread_taxon_counts, thread_classified, hit_seq_set) = process_batch::<&PathBuf>(
-            &sample_file,
+        let (thread_taxon_counts, thread_classified, hit_seq_set) = process_batch::<PathBuf>(
+            sam_files,
             &args,
             &taxo,
             &sample_id_map,
@@ -283,25 +278,27 @@ pub fn run(args: Args) -> Result<()> {
     }
 
     if let Some(output) = &args.kraken_output_dir {
-        let min = &sample_files.keys().min().cloned().unwrap();
-        let max = &sample_files.keys().max().cloned().unwrap();
-
-        if max > min {
-            let filename = output.join(format!("output_{}-{}.kreport2", min, max));
-            report_kraken_style(
-                filename,
-                args.report_zero_counts,
-                args.report_kmer_data,
-                &taxo,
-                &total_taxon_counts,
-                total_seqs as u64,
-                total_unclassified as u64,
-            )?;
-        }
+        if !sample_files.is_empty() {
+            let min = &sample_files.keys().min().cloned().unwrap();
+            let max = &sample_files.keys().max().cloned().unwrap();
+
+            if max > min {
+                let filename = output.join(format!("output_{}-{}.kreport2", min, max));
+                report_kraken_style(
+                    filename,
+                    args.report_zero_counts,
+                    args.report_kmer_data,
+                    &taxo,
+                    &total_taxon_counts,
+                    total_seqs as u64,
+                    total_unclassified as u64,
+                )?;
+            }
 
-        let source_sample_file = args.chunk_dir.join("sample_file.map");
-        let to_sample_file = output.join("sample_file.txt");
-        std::fs::copy(source_sample_file, to_sample_file)?;
+            let source_sample_file = args.chunk_dir.join("sample_file.map");
+            let to_sample_file = output.join("sample_file.txt");
+            std::fs::copy(source_sample_file, to_sample_file)?;
+        };
     }
 
     // Compute the elapsed time
@@ -309,8 +306,10 @@ pub fn run(args: Args) -> Result<()> {
     // Print the run time
     println!("resolve took: {:?}", duration);
 
-    for (_, sample_file) in &sample_files {
-        let _ = std::fs::remove_file(sample_file);
+    for (_, sam_files) in &sample_files {
+        for sample_file in sam_files {
+            let _ = std::fs::remove_file(sample_file);
+        }
     }
 
     for (_, sample_file) in sample_id_files {
diff --git a/kr2r/src/bin/splitr.rs b/kr2r/src/bin/splitr.rs
index cc2aa77..0594613 100644
--- a/kr2r/src/bin/splitr.rs
+++ b/kr2r/src/bin/splitr.rs
@@ -225,11 +225,6 @@ where
         )?;
         file_writer.flush().unwrap();
 
-        create_sample_file(
-            args.chunk_dir
-                .join(format!("sample_file_{}.bin", file_index)),
-        );
-
         action(file_index, path_pair)?;
     }
 
diff --git a/kr2r/src/utils.rs b/kr2r/src/utils.rs
index eb414b6..66e75d6 100644
--- a/kr2r/src/utils.rs
+++ b/kr2r/src/utils.rs
@@ -190,6 +190,54 @@ pub fn create_sample_file<P: AsRef<Path>>(filename: P) -> BufWriter<fs::File> {
 
 use regex::Regex;
 
+pub fn find_and_trans_bin_files(
+    directory: &Path,
+    prefix: &str,
+    suffix: &str,
+    check: bool,
+) -> io::Result<Map<usize, Vec<PathBuf>>> {
+    // Aggregate the paths of files that share the same leading number
+    // Build a regex that matches the first number in the file name
+    let pattern = format!(r"{}_(\d+)_\d+{}", prefix, suffix);
+    let re = Regex::new(&pattern).expect("Invalid regex pattern");
+
+    // Read all entries in the given directory
+    let mut map_entries = Map::new();
+    for entry in fs::read_dir(directory)? {
+        let path = entry?.path();
+
+        if path.is_file() {
+            if let Some(file_name) = path.file_name().and_then(|name| name.to_str()) {
+                // Match the file name against the regex and extract the first number
+                if let Some(cap) = re.captures(file_name) {
+                    if let Some(m) = cap.get(1) {
+                        if let Ok(num) = m.as_str().parse::<usize>() {
+                            map_entries.entry(num).or_insert_with(Vec::new).push(path);
+                        }
+                    }
+                }
+            }
+        }
+    }
+
+    if check {
+        // Check that the numbers are consecutive starting from 1
+        let mut keys: Vec<_> = map_entries.keys().cloned().collect();
+        keys.sort_unstable();
+        for (i, &key) in keys.iter().enumerate() {
+            if i + 1 != key {
+                return Err(io::Error::new(
+                    io::ErrorKind::NotFound,
+                    "File numbers are not continuous starting from 1.",
+                ));
+            }
+        }
+    }
+
+    // Return the aggregated file paths
+    Ok(map_entries)
+}
+
 pub fn find_and_trans_files(
     directory: &Path,
     prefix: &str,
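
A minimal, standalone Rust sketch of the sharding scheme introduced above, not part of the patch: the new write_to_file in annotate.rs emits one shard per (file_index, seq_id % batch_size), and resolve.rs regroups those shards by their leading index through find_and_trans_bin_files. The shard_name helper, the main function, and the sample values below are illustrative assumptions; only the file-name pattern and the regex crate are taken from the patch.

// Standalone sketch, not from the patch; it only reuses the file-name pattern and the regex crate.
use regex::Regex;
use std::collections::HashMap;

// One shard file per (file_index, seq_id % batch_size), matching the new write_to_file naming.
fn shard_name(file_index: u64, seq_id: u32, batch_size: u32) -> String {
    format!("sample_file_{}_{}.bin", file_index, seq_id % batch_size)
}

fn main() {
    let batch_size = 4; // --batch-size; the patch defaults this to 16
    let names: Vec<String> = [(1u64, 7u32), (1, 8), (1, 10), (2, 7)]
        .iter()
        .map(|&(file_index, seq_id)| shard_name(file_index, seq_id, batch_size))
        .collect();

    // Regroup shards by their leading number, the way find_and_trans_bin_files does.
    let re = Regex::new(r"sample_file_(\d+)_\d+\.bin").unwrap();
    let mut grouped: HashMap<usize, Vec<String>> = HashMap::new();
    for name in &names {
        if let Some(cap) = re.captures(name) {
            let idx: usize = cap[1].parse().unwrap();
            grouped.entry(idx).or_default().push(name.clone());
        }
    }
    println!("{:?}", grouped); // e.g. {1: ["sample_file_1_3.bin", ...], 2: ["sample_file_2_3.bin"]}
}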