Skip to content

Commit

Permalink
add --save-ctc flag to store basecalled results as training data
Browse files Browse the repository at this point in the history
  • Loading branch information
iiSeymour committed Sep 1, 2020
1 parent 82907bd commit 0b7ad17
Show file tree
Hide file tree
Showing 2 changed files with 121 additions and 19 deletions.
36 changes: 29 additions & 7 deletions bonito/basecaller.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,33 +7,52 @@
from datetime import timedelta
from argparse import ArgumentParser, ArgumentDefaultsHelpFormatter

from bonito.util import load_model, chunk, stitch
from bonito.io import DecoderWriterPool, PreprocessReader
from bonito.util import load_model, chunk, stitch, half_supported
from bonito.io import DecoderWriterPool, PreprocessReader, CTCWriter

import torch
import numpy as np
from mappy import Aligner


def main(args):

if args.save_ctc and not args.reference:
sys.stderr.write("> a reference is needed to output ctc training data\n")
exit(1)

if args.save_ctc:
args.overlap = 900
args.chunksize = 3600

sys.stderr.write("> loading model\n")

model = load_model(
args.model_directory, args.device, weights=int(args.weights),
half=args.half, chunksize=args.chunksize, use_rt=args.cudart,
)

if args.reference:
sys.stderr.write("> loading reference\n")
aligner = Aligner(args.reference, preset='ont-map')
if not aligner:
sys.stderr.write("> failed to load/build index\n")
sys.exit(1)
else:
aligner = None

samples = 0
num_reads = 0
max_read_size = 4e6
dtype = np.float16 if args.half else np.float32
ctc_writer = CTCWriter(model, aligner)
reader = PreprocessReader(args.reads_directory)
writer = DecoderWriterPool(model, beamsize=args.beamsize, fastq=args.fastq, reference=args.reference)
writer = DecoderWriterPool(model, beamsize=args.beamsize, fastq=args.fastq, aligner=aligner)

t0 = time.perf_counter()
sys.stderr.write("> calling\n")

with writer, reader, torch.no_grad():
with writer, ctc_writer, reader, torch.no_grad():

while True:

Expand All @@ -51,10 +70,12 @@ def main(args):
raw_data = torch.tensor(read.signal.astype(dtype))
chunks = chunk(raw_data, args.chunksize, args.overlap)

posteriors = model(chunks.to(args.device)).cpu().numpy()
posteriors = stitch(posteriors, args.overlap // model.stride // 2)
posteriors_ = model(chunks.to(args.device)).cpu().numpy()
posteriors = stitch(posteriors_, args.overlap // model.stride // 2)

writer.queue.put((read, posteriors[:raw_data.shape[0]]))
if args.save_ctc and len(raw_data) > args.chunksize:
ctc_writer.queue.put((chunks.numpy(), posteriors_))

duration = time.perf_counter() - t0

Expand All @@ -77,7 +98,8 @@ def argparser():
parser.add_argument("--beamsize", default=5, type=int)
parser.add_argument("--chunksize", default=0, type=int)
parser.add_argument("--overlap", default=0, type=int)
parser.add_argument("--half", action="store_true", default=False)
parser.add_argument("--half", action="store_true", default=half_supported())
parser.add_argument("--fastq", action="store_true", default=False)
parser.add_argument("--cudart", action="store_true", default=False)
parser.add_argument("--save-ctc", action="store_true", default=False)
return parser
104 changes: 92 additions & 12 deletions bonito/io.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,14 +7,16 @@
from glob import glob
from warnings import warn
from logging import getLogger
from os.path import realpath, splitext
from os.path import realpath, splitext, dirname
from multiprocessing import Process, Queue, Lock, cpu_count

import numpy as np
from tqdm import tqdm
from mappy import Aligner, revcomp
from mappy import revcomp

import bonito
from bonito.training import ChunkDataSet
from bonito.convert import filter_chunks
from bonito.util import get_raw_data, mean_qscore_from_qstring


Expand Down Expand Up @@ -214,25 +216,103 @@ def stop(self):
self.join()


class CTCWriter(Process):
"""
CTC writer process that writes output numpy training data
"""
def __init__(self, model, aligner, min_coverage=0.90, min_accuracy=0.90):
super().__init__()
self.model = model
self.queue = Queue()
self.aligner = aligner
self.min_coverage = min_coverage
self.min_accuracy = min_accuracy

def __enter__(self):
self.start()
return self

def __exit__(self, exc_type, exc_val, exc_tb):
self.queue.put(None)
self.stop()

def run(self):

chunks = []
targets = []
target_lens = []

while True:

job = self.queue.get()
if job is None: break
chunks_, predictions = job

# convert logprobs to probs
predictions = np.exp(predictions.astype(np.float32))

for chunk, pred in zip(chunks_, predictions):

sequence = self.model.decode(pred)

if not sequence:
continue

for mapping in self.aligner.map(sequence):
cov = (mapping.q_en - mapping.q_st) / len(sequence)
acc = mapping.mlen / mapping.blen
refseq = self.aligner.seq(mapping.ctg, mapping.r_st + 1, mapping.r_en)
if 'N' in refseq: continue
if mapping.strand == -1: refseq = revcomp(refseq)
break
else:
continue

if acc > self.min_accuracy and cov > self.min_accuracy:
chunks.append(chunk.squeeze())
targets.append([
int(x) for x in refseq.translate({65: '1', 67: '2', 71: '3', 84: '4'})
])
target_lens.append(len(refseq))

chunks = np.array(chunks, dtype=np.float32)
chunk_lens = np.full(chunks.shape[0], chunks.shape[1], dtype=np.int16)

targets_ = np.zeros((chunks.shape[0], max(target_lens)), dtype=np.uint8)
for idx, target in enumerate(targets): targets_[idx, :len(target)] = target
target_lens = np.array(target_lens, dtype=np.uint16)

training = ChunkDataSet(chunks, chunk_lens, targets_, target_lens)
training = filter_chunks(training)

output_directory = '.' if sys.stdout.isatty() else dirname(realpath('/dev/fd/1'))
np.save(os.path.join(output_directory, "chunks.npy"), training.chunks.squeeze(1))
np.save(os.path.join(output_directory, "chunk_lengths.npy"), training.chunk_lengths)
np.save(os.path.join(output_directory, "references.npy"), training.targets)
np.save(os.path.join(output_directory, "reference_lengths.npy"), training.target_lengths)

sys.stderr.write("> written ctc training data\n")
sys.stderr.write(" - chunks.npy with shape (%s)\n" % ','.join(map(str, training.chunks.squeeze(1).shape)))
sys.stderr.write(" - chunk_lengths.npy with shape (%s)\n" % ','.join(map(str, training.chunk_lengths.shape)))
sys.stderr.write(" - references.npy with shape (%s)\n" % ','.join(map(str, training.targets.shape)))
sys.stderr.write(" - reference_lengths.npy shape (%s)\n" % ','.join(map(str, training.target_lengths.shape)))

def stop(self):
self.join()


class DecoderWriterPool:
"""
Simple pool of decoder writers
"""
def __init__(self, model, procs=4, reference=None, **kwargs):
def __init__(self, model, procs=4, aligner=None, **kwargs):
self.lock = Lock()
self.queue = Queue()
self.procs = procs if procs else cpu_count()
self.aligner = aligner
self.decoders = []

if reference:
sys.stderr.write("> loading reference\n")
aligner = Aligner(reference, preset='ont-map')
if not aligner:
sys.stderr.write("> failed to load/build index\n")
sys.exit(1)
write_sam_header(aligner)
else:
aligner = None
if aligner: write_sam_header(aligner)

with open(summary_file(), 'w') as summary:
write_summary_header(summary, alignment=aligner)
Expand Down

0 comments on commit 0b7ad17

Please sign in to comment.