Skip to content

Commit

Permalink
recursive fast5 discovery support
Browse files Browse the repository at this point in the history
  • Loading branch information
iiSeymour committed Nov 13, 2020
1 parent a39bf52 commit 2c51ec5
Show file tree
Hide file tree
Showing 2 changed files with 8 additions and 6 deletions.
4 changes: 3 additions & 1 deletion bonito/cli/basecaller.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,8 @@ def main(args):
aligner = None

reads = get_reads(
args.reads_directory, n_proc=8, skip=args.skip, read_ids=column_to_set(args.read_ids)
args.reads_directory, n_proc=8, recursive=args.recursive,
read_ids=column_to_set(args.read_ids), skip=args.skip,
)

ctc_data = load_symbol(args.model_directory, "ctc_data")
Expand Down Expand Up @@ -83,6 +84,7 @@ def argparser():
parser.add_argument("--skip", action="store_true", default=False)
parser.add_argument("--fastq", action="store_true", default=False)
parser.add_argument("--save-ctc", action="store_true", default=False)
parser.add_argument("--recursive", action="store_true", default=False)
parser.add_argument("--ctc-min-coverage", default=0.9, type=float)
parser.add_argument("--ctc-min-accuracy", default=0.9, type=float)
return parser
10 changes: 5 additions & 5 deletions bonito/fast5.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,9 +2,8 @@
Bonito Fast5 Utils
"""

import os
import sys
from glob import glob
from pathlib import Path
from functools import partial
from multiprocessing import Pool
from itertools import chain, starmap
Expand All @@ -19,8 +18,8 @@ class Read:
def __init__(self, read, filename):

self.read_id = read.read_id
self.filename = filename.name
self.run_id = read.get_run_id().decode()
self.filename = os.path.basename(read.filename)

read_attrs = read.handle[read.raw_dataset_group_name].attrs
channel_info = read.handle[read.global_key + 'channel_id'].attrs
Expand Down Expand Up @@ -114,13 +113,14 @@ def get_raw_data_for_read(info):
return Read(f5_fh.get_read(read_id), filename)


def get_reads(directory, read_ids=None, skip=False, max_read_size=0, n_proc=1):
def get_reads(directory, read_ids=None, skip=False, max_read_size=0, n_proc=1, recursive=False):
"""
Get all reads in a given `directory`.
"""
pattern = "**/*.fast5" if recursive else "*.fast5"
get_filtered_reads = partial(get_read_ids, read_ids=read_ids, skip=skip)
with Pool(n_proc) as pool:
for job in chain(pool.imap(get_filtered_reads, glob("%s/*.fast5" % directory))):
for job in chain(pool.imap(get_filtered_reads, Path(directory).glob(pattern))):
for read in pool.imap(get_raw_data_for_read, job):
if max_read_size > 0 and len(read.signal) > max_read_size:
sys.stderr.write(
Expand Down

0 comments on commit 2c51ec5

Please sign in to comment.