Skip to content

Commit

Permalink
Refactor sequence handling into trait
Browse files Browse the repository at this point in the history
  • Loading branch information
Roderick Bovee committed Sep 10, 2019
1 parent 4621ab0 commit 6c032c4
Show file tree
Hide file tree
Showing 10 changed files with 317 additions and 247 deletions.
52 changes: 29 additions & 23 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -12,33 +12,39 @@ Needletail's goal is to be as fast as the [readfq](https://github.com/lh3/readfq

```rust
extern crate needletail;
use needletail::{parse_sequences, Sequence};
use std::env;
use std::fs::File;
use needletail::parse_sequences;

fn main() {
let filename: String = env::args().nth(1).unwrap();

let mut n_bases = 0;
let mut n_valid_kmers = 0;
parse_sequences(File::open(filename).expect("missing file"), |_| {}, |seq| {
// seq.id is the name of the record
// seq.seq is the base sequence
// seq.qual is an optional quality score

// keep track of the total number of bases
n_bases += seq.seq.len();

// keep track of the number of AAAA (or TTTT via canonicalization) in the
// file (normalize makes sure ever base is capitalized for comparison)
for (_, kmer, _) in seq.normalize(false).kmers(4, true) {
if kmer == b"AAAA" {
n_valid_kmers += 1;
}
}
}).expect("parsing failed");
println!("There are {} bases in your file.", n_bases);
println!("There are {} AAAAs in your file.", n_valid_kmers);
let filename: String = env::args().nth(1).unwrap();

let mut n_bases = 0;
let mut n_valid_kmers = 0;
parse_sequences(
File::open(filename).expect("missing file"),
|_| {},
|seq| {
// seq.id is the name of the record
// seq.seq is the base sequence
// seq.qual is an optional quality score

// keep track of the total number of bases
n_bases += seq.seq.len();

// keep track of the number of AAAA (or TTTT via canonicalization) in the
// file (normalize makes sure ever base is capitalized for comparison)
let rc = seq.reverse_complement();
for (_, kmer, _) in seq.normalize(false).canonical_kmers(4, &rc) {
if kmer == b"AAAA" {
n_valid_kmers += 1;
}
}
},
)
.expect("parsing failed");
println!("There are {} bases in your file.", n_bases);
println!("There are {} AAAAs in your file.", n_valid_kmers);
}
```

Expand Down
11 changes: 6 additions & 5 deletions benches/benchmark.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ extern crate needletail;

use criterion::Criterion;
use needletail::parse_sequences;
use needletail::seq::Sequence;
use std::fs::File;
use std::io::{Cursor, Read};

Expand All @@ -28,8 +29,10 @@ fn bench_kmer_speed(c: &mut Criterion) {
parse_sequences(
fasta_data,
|_| {},
|seq| {
for (_, _kmer, was_rc) in seq.normalize(true).kmers(ksize, true) {
|rec| {
let seq = rec.seq.normalize(true);
let rc = seq.reverse_complement();
for (_, _kmer, was_rc) in seq.canonical_kmers(ksize, &rc) {
if !was_rc {
n_canonical += 1;
}
Expand Down Expand Up @@ -192,13 +195,11 @@ fn bench_fasta_file(c: &mut Criterion) {

group.bench_function("Needletail (No Buffer)", |bench| {
use needletail::formats::{FastaParser, RecParser};
use needletail::seq::Sequence;
bench.iter(|| {
let mut reader = FastaParser::from_buffer(&data, true);
let mut n_bases = 0;
for rec in reader.by_ref() {
let seq = Sequence::from(rec.unwrap());
n_bases += seq.seq.len();
n_bases += rec.unwrap().seq.strip_returns().len();
}
assert_eq!(738_580, n_bases);
});
Expand Down
4 changes: 2 additions & 2 deletions src/bitkmer.rs
Original file line number Diff line number Diff line change
Expand Up @@ -63,7 +63,7 @@ impl<'a> BitNuclKmer<'a> {
pub fn new(slice: &'a [u8], k: u8, canonical: bool) -> BitNuclKmer<'a> {
let mut kmer = (0u64, k);
let mut start_pos = 0;
update_position(&mut start_pos, &mut kmer, slice, true);
update_position(&mut start_pos, &mut kmer, &slice, true);

BitNuclKmer {
start_pos,
Expand All @@ -78,7 +78,7 @@ impl<'a> Iterator for BitNuclKmer<'a> {
type Item = (usize, BitKmer, bool);

fn next(&mut self) -> Option<(usize, BitKmer, bool)> {
if !update_position(&mut self.start_pos, &mut self.cur_kmer, self.buffer, false) {
if !update_position(&mut self.start_pos, &mut self.cur_kmer, &self.buffer, false) {
return None;
}
self.start_pos += 1;
Expand Down
32 changes: 16 additions & 16 deletions src/formats/fasta.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,8 +4,8 @@ use std::io::Write;
use memchr::memchr;

use crate::formats::buffer::RecParser;
use crate::seq::Sequence;
use crate::util::{memchr_both_last, strip_whitespace, ParseError, ParseErrorType};
use crate::seq::{Sequence, SequenceRecord};
use crate::util::{memchr_both_last, ParseError, ParseErrorType};

#[derive(Debug)]
pub struct FastaRecord<'a> {
Expand All @@ -14,28 +14,25 @@ pub struct FastaRecord<'a> {
}

impl<'a> FastaRecord<'a> {
pub fn write(&self, writer: &mut dyn Write) -> Result<(), ParseError> {
pub fn write(&self, writer: &mut dyn Write, ending: &[u8]) -> Result<(), ParseError> {
writer.write_all(b">")?;
writer.write_all(&self.id)?;
writer.write_all(b"\n")?;
writer.write_all(ending)?;
writer.write_all(&self.seq)?;
writer.write_all(b"\n")?;
writer.write_all(ending)?;
Ok(())
}
}

impl<'a> From<FastaRecord<'a>> for Sequence<'a> {
fn from(fasta: FastaRecord<'a>) -> Sequence<'a> {
Sequence::new(fasta.id, strip_whitespace(fasta.seq), None)
impl<'a> Sequence<'a> for FastaRecord<'a> {
fn sequence(&self) -> &'a [u8] {
self.seq
}
}

impl<'a> From<&'a Sequence<'a>> for FastaRecord<'a> {
fn from(seq: &'a Sequence<'a>) -> FastaRecord<'a> {
FastaRecord {
id: &seq.id,
seq: &seq.seq,
}
impl<'a> From<FastaRecord<'a>> for SequenceRecord<'a> {
fn from(fasta: FastaRecord<'a>) -> SequenceRecord<'a> {
SequenceRecord::new(fasta.id.into(), fasta.seq.into(), None)
}
}

Expand Down Expand Up @@ -95,8 +92,11 @@ impl<'a> Iterator for FastaParser<'a> {
.context(context)));
}
let mut seq = &buf[id_end..seq_end];
if seq[seq.len() - 1] == b'\r' {
seq = &seq[..seq.len()];
if seq.len() > 0 && seq[seq.len() - 1] == b'\n' {
seq = &seq[..seq.len() - 1];
}
if seq.len() > 0 && seq[seq.len() - 1] == b'\r' {
seq = &seq[..seq.len() - 1];
}

self.pos += seq_end;
Expand Down
42 changes: 16 additions & 26 deletions src/formats/fastq.rs
Original file line number Diff line number Diff line change
@@ -1,12 +1,11 @@
use std::borrow::Cow;
use std::cmp::min;
use std::io::Write;

use memchr::memchr;

use crate::formats::buffer::RecParser;
use crate::formats::fasta::check_end;
use crate::seq::Sequence;
use crate::seq::{Sequence, SequenceRecord};
use crate::util::{memchr_both, ParseError, ParseErrorType};

#[derive(Debug)]
Expand All @@ -18,45 +17,36 @@ pub struct FastqRecord<'a> {
}

impl<'a> FastqRecord<'a> {
pub fn write(&self, writer: &mut dyn Write) -> Result<(), ParseError> {
pub fn write(&self, writer: &mut dyn Write, ending: &[u8]) -> Result<(), ParseError> {
writer.write_all(b"@")?;
writer.write_all(&self.id)?;
writer.write_all(b"\n")?;
writer.write_all(ending)?;
writer.write_all(&self.seq)?;
writer.write_all(b"\n+\n")?;
writer.write_all(ending)?;
writer.write_all(b"+")?;
writer.write_all(ending)?;
// this is kind of a hack, but we want to allow writing out sequences
// that don't have qualitys so this will mask to "good" if the quality
// slice is empty
if self.seq.len() != self.qual.len() {
writer.write_all(&vec![b'I'; self.seq.len()])?;
} else {
writer.write_all(&self.qual)?;
}
writer.write_all(b"\n")?;
writer.write_all(ending)?;
Ok(())
}
}

impl<'a> From<FastqRecord<'a>> for Sequence<'a> {
fn from(fastq: FastqRecord<'a>) -> Sequence<'a> {
let qual = if fastq.seq.len() != fastq.qual.len() {
None
} else {
Some(fastq.qual)
};
Sequence::new(fastq.id, Cow::from(fastq.seq), qual)
impl<'a> Sequence<'a> for FastqRecord<'a> {
fn sequence(&self) -> &'a [u8] {
self.seq
}
}

impl<'a> From<&'a Sequence<'a>> for FastqRecord<'a> {
fn from(seq: &'a Sequence<'a>) -> FastqRecord<'a> {
let qual = match &seq.qual {
None => &b""[..],
Some(q) => &q,
};
FastqRecord {
id: &seq.id,
seq: &seq.seq,
id2: b"",
qual,
}
impl<'a> From<FastqRecord<'a>> for SequenceRecord<'a> {
fn from(fastq: FastqRecord<'a>) -> SequenceRecord<'a> {
SequenceRecord::new(fastq.id.into(), fastq.seq.into(), Some(fastq.qual.into()))
}
}

Expand Down
12 changes: 6 additions & 6 deletions src/formats/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@ use xz2::read::XzDecoder;
pub use crate::formats::buffer::{RecBuffer, RecParser};
pub use crate::formats::fasta::{FastaParser, FastaRecord};
pub use crate::formats::fastq::{FastqParser, FastqRecord};
use crate::seq::Sequence;
use crate::seq::SequenceRecord;
use crate::util::{ParseError, ParseErrorType};

#[macro_export]
Expand Down Expand Up @@ -71,7 +71,7 @@ fn seq_reader<F, R, T>(
type_callback: &mut T,
) -> Result<(), ParseError>
where
F: for<'a> FnMut(Sequence<'a>) -> (),
F: for<'a> FnMut(SequenceRecord<'a>) -> (),
R: Read,
T: ?Sized + FnMut(&'static str) -> (),
{
Expand All @@ -91,10 +91,10 @@ where

match file_type {
"FASTA" => parse_stream!(reader, first, FastaParser, rec, {
callback(Sequence::from(rec))
callback(SequenceRecord::from(rec))
}),
"FASTQ" => parse_stream!(reader, first, FastqParser, rec, {
callback(Sequence::from(rec))
callback(SequenceRecord::from(rec))
}),
_ => panic!("A file type was inferred that could not be parsed"),
};
Expand All @@ -108,7 +108,7 @@ pub fn parse_sequences<F, R, T>(
callback: F,
) -> Result<(), ParseError>
where
F: for<'a> FnMut(Sequence<'a>) -> (),
F: for<'a> FnMut(SequenceRecord<'a>) -> (),
R: Read,
T: FnMut(&'static str) -> (),
{
Expand All @@ -124,7 +124,7 @@ pub fn parse_sequences<F, R, T>(
callback: F,
) -> Result<(), ParseError>
where
F: for<'a> FnMut(Sequence<'a>) -> (),
F: for<'a> FnMut(SequenceRecord<'a>) -> (),
R: Read,
T: FnMut(&'static str) -> (),
{
Expand Down
Loading

0 comments on commit 6c032c4

Please sign in to comment.