Skip to content

Commit

Permalink
bug fix
Browse files Browse the repository at this point in the history
  • Loading branch information
dagou committed Jul 3, 2024
1 parent 6a2fff1 commit 9d11486
Show file tree
Hide file tree
Showing 8 changed files with 246 additions and 81 deletions.
2 changes: 1 addition & 1 deletion kr2r/Cargo.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
[package]
name = "kr2r"
version = "0.5.9"
version = "0.6.0"
edition = "2021"
authors = ["eric9n@gmail.com"]

Expand Down
4 changes: 2 additions & 2 deletions kr2r/src/bin/estimate_capacity.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@ use kr2r::args::KLMTArgs;
use kr2r::utils::{find_library_fna_files, format_bytes, open_file};
use kr2r::KBuildHasher;

use seqkmer::{read_parallel, FastaReader};
use seqkmer::{read_parallel, BufferFastaReader};
use serde_json;
use std::collections::HashSet;
use std::fs::File;
Expand Down Expand Up @@ -83,7 +83,7 @@ fn process_sequence<P: AsRef<Path>>(
let mut hllp: HyperLogLogPlus<u64, _> =
HyperLogLogPlus::new(16, KBuildHasher::default()).unwrap();

let mut reader = FastaReader::from_path(fna_file, 1)
let mut reader = BufferFastaReader::from_path(fna_file, 1)
.expect("Failed to open the FASTA file with FastaReader");
let range_n = args.n as u64;
read_parallel(
Expand Down
5 changes: 4 additions & 1 deletion kr2r/src/bin/hashshard.rs
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,10 @@ fn mmap_read_write<P: AsRef<Path>, Q: AsRef<Path>>(
}

#[derive(Parser, Debug, Clone)]
#[clap(version, about = "split hash file", long_about = "split hash file")]
#[clap(
version,
about = "Convert Kraken2 database files to Kun-peng database format for efficient processing and analysis."
)]
pub struct Args {
/// The database directory for the Kraken 2 index. contains index files(hash.k2d opts.k2d taxo.k2d)
#[clap(long = "db", value_parser, required = true)]
Expand Down
96 changes: 65 additions & 31 deletions kr2r/src/bin/merge_fna.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,10 +2,11 @@ use clap::Parser;

use flate2::read::GzDecoder;
use kr2r::utils::{find_files, open_file};
use std::collections::HashMap;
use rayon::prelude::*;
use std::fs::{create_dir_all, File, OpenOptions};
use std::io::{BufRead, BufReader, BufWriter, Result, Write};
use std::io::{BufRead, BufReader, BufWriter, Read, Result, Write};
use std::path::PathBuf;
use std::sync::atomic::{AtomicBool, Ordering};
use std::time::Instant;

#[derive(Parser, Debug, Clone)]
Expand All @@ -23,8 +24,8 @@ pub struct Args {
// pub id_to_taxon_map_filename: Option<PathBuf>,
}

fn parse_assembly_fna(assembly_file: &PathBuf, site: &str) -> Result<HashMap<String, String>> {
let mut gz_files: HashMap<String, String> = HashMap::new();
fn parse_assembly_fna(assembly_file: &PathBuf, site: &str) -> Result<Vec<(String, String)>> {
let mut gz_files = Vec::new();
let file = open_file(&assembly_file)?;
let reader = BufReader::new(file);
let lines = reader.lines();
Expand Down Expand Up @@ -57,7 +58,7 @@ fn parse_assembly_fna(assembly_file: &PathBuf, site: &str) -> Result<HashMap<Str
site,
ftp_path.split('/').last().unwrap_or_default()
);
gz_files.insert(fna_file_name, taxid.into());
gz_files.push((fna_file_name, taxid.into()));
}
}
Ok(gz_files)
Expand Down Expand Up @@ -119,61 +120,94 @@ fn process_gz_file(
const PREFIX: &'static str = "assembly_summary";
const SUFFIX: &'static str = "txt";

fn merge_fna(assembly_files: &Vec<PathBuf>, database: &PathBuf) -> Result<()> {
fn merge_fna_parallel(assembly_files: &Vec<PathBuf>, database: &PathBuf) -> Result<()> {
let pattern = format!(r"{}_(\S+)\.{}", PREFIX, SUFFIX);
let file_site = regex::Regex::new(&pattern).unwrap();

let library_fna_path = database.join("library.fna");
let seqid2taxid_path = database.join("seqid2taxid.map");
let mut fna_writer = BufWriter::new(
OpenOptions::new()
.create(true)
.write(true)
.open(&library_fna_path)?,
);
let mut map_writer = BufWriter::new(
OpenOptions::new()
.create(true)
.write(true)
.open(&seqid2taxid_path)?,
);

let fna_start: regex::Regex = regex::Regex::new(r"^>(\S+)").unwrap();
let mut is_empty = true;
let is_empty = AtomicBool::new(true);
for assembly_file in assembly_files {
if let Some(caps) = file_site.captures(assembly_file.to_string_lossy().as_ref()) {
if let Some(matched) = caps.get(1) {
let gz_files = parse_assembly_fna(assembly_file, matched.as_str())?;

for (gz_path, taxid) in gz_files {
gz_files.par_iter().for_each(|(gz_path, taxid)| {
let gz_file = PathBuf::from(&gz_path);
if !gz_file.exists() {
// eprintln!("{} does not exist", gz_file.to_string_lossy());
continue;
return;
}
let thread_index = rayon::current_thread_index().unwrap_or(0);
let library_fna_path = database.join(format!("library_{}.fna", thread_index));
let seqid2taxid_path =
database.join(format!("seqid2taxid_{}.map", thread_index));
let mut fna_writer = BufWriter::new(
OpenOptions::new()
.create(true)
.append(true)
.write(true)
.open(&library_fna_path)
.unwrap(),
);
let mut map_writer = BufWriter::new(
OpenOptions::new()
.create(true)
.write(true)
.append(true)
.open(&seqid2taxid_path)
.unwrap(),
);

is_empty = false;
process_gz_file(
&gz_file,
&mut map_writer,
&mut fna_writer,
&fna_start,
&taxid,
)?;
}
)
.unwrap();

fna_writer.flush()?;
map_writer.flush()?;
fna_writer.flush().unwrap();
map_writer.flush().unwrap();
is_empty.fetch_and(false, Ordering::Relaxed);
});
}
}
}

if is_empty {
let fna_files = find_files(database, "library_", "fna");
let seqid_files = find_files(database, "seqid2taxid_", "map");
let library_fna_path = database.join("library.fna");
let seqid2taxid_path = database.join("seqid2taxid.map");
merge_files(&fna_files, &library_fna_path)?;
merge_files(&seqid_files, &seqid2taxid_path)?;
if is_empty.load(Ordering::Relaxed) {
panic!("genimics fna files is empty! please check download dir");
}
Ok(())
}

/// Concatenates every file in `paths` into `output_path`, in order,
/// deleting each source file once its bytes have been copied.
///
/// Used to fold the per-thread `library_N.fna` / `seqid2taxid_N.map`
/// shards into the final single output files.
///
/// # Errors
/// Returns any I/O error from creating the output, reading a source,
/// writing, or removing a source file.
fn merge_files(paths: &Vec<PathBuf>, output_path: &PathBuf) -> Result<()> {
    let mut output = BufWriter::new(File::create(output_path)?);
    for path in paths {
        // `io::copy` streams in chunks internally; no hand-rolled
        // fixed-size buffer loop is needed.
        let mut input = BufReader::new(File::open(path)?);
        std::io::copy(&mut input, &mut output)?;
        // Remove the shard only after its contents are fully copied.
        std::fs::remove_file(path)?;
    }
    output.flush()?;
    Ok(())
}

pub fn run(args: Args) -> Result<()> {
// 开始计时
let start = Instant::now();
Expand Down Expand Up @@ -213,7 +247,7 @@ pub fn run(args: Args) -> Result<()> {
}
let assembly_files = find_files(&download_dir, &PREFIX, &SUFFIX);

merge_fna(&assembly_files, &args.database)?;
merge_fna_parallel(&assembly_files, &args.database)?;

// 计算持续时间
let duration = start.elapsed();
Expand Down
6 changes: 3 additions & 3 deletions kr2r/src/db.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,12 +2,11 @@
use crate::compact_hash::{Compact, HashConfig, Slot};
// use crate::mmscanner::MinimizerScanner;
use crate::taxonomy::{NCBITaxonomy, Taxonomy};
use seqkmer::Meros;
use seqkmer::{read_parallel, BufferFastaReader, Meros};

use crate::utils::open_file;
use byteorder::{LittleEndian, WriteBytesExt};
use rayon::prelude::*;
use seqkmer::{read_parallel, FastaReader};
use std::collections::HashMap;
use std::fs::File;
use std::io::{BufReader, BufWriter, Read, Result as IOResult, Write};
Expand Down Expand Up @@ -199,7 +198,7 @@ pub fn convert_fna_to_k2_format<P: AsRef<Path>>(
chunk_size: usize,
threads: usize,
) {
let mut reader = FastaReader::from_path(fna_file, 1).unwrap();
let mut reader = BufferFastaReader::from_path(fna_file, 1).unwrap();
let value_bits = hash_config.value_bits;
let cell_size = std::mem::size_of::<Slot<u32>>();

Expand Down Expand Up @@ -230,6 +229,7 @@ pub fn convert_fna_to_k2_format<P: AsRef<Path>>(
}
});
}

Some(k2_cell_list)
},
|record_sets| {
Expand Down
148 changes: 127 additions & 21 deletions seqkmer/src/fasta.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ use std::io::{BufRead, BufReader, Read, Result};
use std::path::Path;

const SEQ_LIMIT: u64 = u64::pow(2, 32);

/// FastaReader
pub struct FastaReader<R>
where
Expand Down Expand Up @@ -40,27 +41,6 @@ where
}
}

pub fn read_next_entry<'a>(&'a mut self) -> Result<Option<(&'a Vec<u8>, &'a Vec<u8>)>> {
// 清空header和seq缓冲区
self.header.clear();
self.seq.clear();

// 读取header部分
if self.reader.read_until(b'\n', &mut self.header)? == 0 {
return Ok(None);
}
trim_end(&mut self.header);

// 读取seq部分
if self.reader.read_until(b'>', &mut self.seq)? == 0 {
return Ok(None);
}
trim_end(&mut self.seq);

// 返回header和seq的引用
Ok(Some((&self.header, &self.seq)))
}

pub fn read_next(&mut self) -> Result<Option<()>> {
// 读取fastq文件header部分
self.header.clear();
Expand Down Expand Up @@ -149,3 +129,129 @@ impl<R: Read + Send> Reader for FastaReader<R> {
Ok(if seqs.is_empty() { None } else { Some(seqs) })
}
}

/// BufferFastaReader
///
/// A FASTA reader that emits a record's sequence in chunks of at most
/// `batch_size` lines per call, so a very long sequence does not have to
/// be materialized in one piece. Chunks of the same record share an id
/// (see the `_next` implementation below).
pub struct BufferFastaReader<R>
where
    R: Read + Send,
{
    // Buffered source the FASTA text is read from.
    reader: BufReader<R>,
    // Index of the input file; copied into every emitted `SeqHeader`.
    file_index: usize,
    // Running count of emitted chunks; used as `SeqHeader::reads_index`.
    reads_index: usize,
    // Most recently read header line (the line starting with '>').
    header: Vec<u8>,
    // Accumulated sequence bytes for the chunk currently being built.
    seq: Vec<u8>,
    // Lines consumed so far; incremented but not otherwise read here.
    line_num: usize,

    // Batch reading: number of sequence lines gathered per `_next` call.
    batch_size: usize,
}

impl<R> BufferFastaReader<R>
where
    R: Read + Send,
{
    /// Creates a reader with the default buffer capacity (`BUFSIZE`) and a
    /// batch size of 60 sequence lines per emitted chunk.
    pub fn new(reader: R, file_index: usize) -> Self {
        Self::with_capacity(reader, file_index, BUFSIZE, 60)
    }

    /// Creates a reader with an explicit buffer `capacity` and
    /// `batch_size` (sequence lines accumulated per `_next` call).
    ///
    /// # Panics
    /// Panics if `capacity < 3`.
    pub fn with_capacity(reader: R, file_index: usize, capacity: usize, batch_size: usize) -> Self {
        assert!(capacity >= 3);
        Self {
            reader: BufReader::with_capacity(capacity, reader),
            file_index,
            reads_index: 0,
            line_num: 0,
            header: Vec::new(),
            seq: Vec::new(),
            batch_size,
        }
    }

    /// Appends one sequence line to `self.seq`; returns `Ok(None)` at EOF.
    ///
    /// If the line read is actually the next record's header (starts with
    /// '>'), it is stashed in `self.header` and one line of the new record
    /// is read instead.
    ///
    /// Fix: the original tested `self.seq.starts_with(b">")` — the start of
    /// the whole accumulated buffer — so a header line appearing after the
    /// first accumulated line was silently treated as sequence data. The
    /// check now applies only to the newly read line.
    pub fn read_next(&mut self) -> Result<Option<()>> {
        // First call (or right after a record boundary): pull the header.
        if self.header.is_empty() {
            if self.reader.read_until(b'\n', &mut self.header)? == 0 {
                return Ok(None);
            }
        }

        let line_start = self.seq.len();
        if self.reader.read_until(b'\n', &mut self.seq)? == 0 {
            return Ok(None);
        }
        if self.seq[line_start..].starts_with(b">") {
            // The line just read begins the next record: adopt it as the
            // new header and read one sequence line belonging to it.
            self.header = self.seq.split_off(line_start);
            if self.reader.read_until(b'\n', &mut self.seq)? == 0 {
                return Ok(None);
            }
        }
        self.line_num += 1;
        trim_end(&mut self.seq);
        Ok(Some(()))
    }

    /// Returns the next chunk of a record — up to `batch_size` sequence
    /// lines under the current header — or `None` once input is exhausted.
    /// A record longer than `batch_size` lines is emitted as several
    /// `Base` values sharing the same id.
    ///
    /// Fixes two defects of the original implementation:
    /// 1. A final partial batch at EOF is emitted instead of being dropped
    ///    (the original returned `None` whenever EOF was reached mid-batch,
    ///    losing the tail of every file).
    /// 2. A header line inside the batch now ends the current chunk and is
    ///    kept for the next call (the original concatenated it, and the
    ///    following record's bases, into the current sequence).
    /// It also trims the header before deriving the id, so a header with
    /// no description no longer yields an id with a trailing newline.
    pub fn _next(&mut self) -> Result<Option<Base<Vec<u8>>>> {
        self.seq.clear();
        let seq_id = loop {
            // Load the current record's header if we do not have one yet.
            if self.header.is_empty() {
                if self.reader.read_until(b'\n', &mut self.header)? == 0 {
                    return Ok(None);
                }
            }

            let mut hit_eof = false;
            let mut next_header: Option<Vec<u8>> = None;
            let mut line: Vec<u8> = Vec::new();
            for _ in 0..self.batch_size {
                line.clear();
                if self.reader.read_until(b'\n', &mut line)? == 0 {
                    hit_eof = true; // EOF: emit whatever was accumulated
                    break;
                }
                self.line_num += 1;
                if line.starts_with(b">") {
                    // Record boundary: keep the new header for next call.
                    next_header = Some(std::mem::take(&mut line));
                    break;
                }
                trim_end(&mut line);
                self.seq.extend_from_slice(&line);
            }

            // Derive the id from the *current* header before a record
            // boundary replaces it: text after '>' up to the first space.
            trim_end(&mut self.header);
            let id = {
                let body = if self.header.starts_with(b">") {
                    &self.header[1..]
                } else {
                    &self.header[..]
                };
                let end = body
                    .iter()
                    .position(|&c| c == b' ')
                    .unwrap_or(body.len());
                String::from_utf8_lossy(&body[..end]).into_owned()
            };
            let at_boundary = next_header.is_some();
            if let Some(h) = next_header {
                self.header = h;
            }

            if !self.seq.is_empty() || !(hit_eof || at_boundary) {
                break id;
            }
            if hit_eof {
                // Nothing accumulated and no more input.
                return Ok(None);
            }
            // Empty chunk right at a record boundary: loop again and read
            // the new record instead of emitting an empty Base.
        };

        // Guard kept from the original implementation: lengths beyond
        // 2^32 are not representable downstream.
        if self.seq.len() as u64 > SEQ_LIMIT {
            eprintln!("Sequence length exceeds 2^32, which is not handled.");
            return Ok(None);
        }

        self.reads_index += 1;
        let seq_header = SeqHeader {
            file_index: self.file_index,
            reads_index: self.reads_index,
            format: SeqFormat::Fasta,
            id: seq_id,
        };
        Ok(Some(Base::new(
            seq_header,
            OptionPair::Single(self.seq.to_owned()),
        )))
    }
}

impl BufferFastaReader<Box<dyn Read + Send>> {
    /// Opens `path` via the crate's `dyn_reader` helper and wraps it in a
    /// reader with default capacity and batch size for file `file_index`.
    #[inline]
    pub fn from_path<P: AsRef<Path>>(path: P, file_index: usize) -> Result<Self> {
        dyn_reader(path).map(|reader| Self::new(reader, file_index))
    }
}

impl<R: Read + Send> Reader for BufferFastaReader<R> {
    /// Adapts `_next` to the `Reader` trait: a single record chunk is
    /// wrapped in a one-element `Vec`; `None` once input is exhausted.
    fn next(&mut self) -> Result<Option<Vec<Base<Vec<u8>>>>> {
        Ok(self._next()?.map(|base| vec![base]))
    }
}
Loading

0 comments on commit 9d11486

Please sign in to comment.