Skip to content

Commit

Permalink
non-uniform first chunk strategy instead of zero-padding.
Browse files Browse the repository at this point in the history
  • Loading branch information
iiSeymour committed Mar 21, 2021
1 parent c0f0cbd commit 3def77d
Show file tree
Hide file tree
Showing 5 changed files with 39 additions and 40 deletions.
1 change: 0 additions & 1 deletion bonito/cli/basecaller.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,7 +61,6 @@ def main(args):
writer = Writer(
tqdm(basecalls, desc="> calling", unit=" reads", leave=False), aligner, fastq=args.fastq
)

t0 = perf_counter()
writer.start()
writer.join()
Expand Down
38 changes: 17 additions & 21 deletions bonito/crf/basecall.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,23 +9,21 @@
from functools import partial
from operator import itemgetter

import bonito
from bonito.io import Writer
from bonito.fast5 import get_reads
from bonito.aligner import Aligner, align_map
from bonito.multiprocessing import thread_map, thread_iter
from bonito.util import concat, chunk, batchify, unbatchify, half_supported


# NOTE(review): this span looks like a diff hunk rendered without +/- markers
# or indentation, so the removed and added versions of `stitch` are
# interleaved. The old/new attribution below is inferred from the call sites
# visible in this commit — confirm against the repository.
# Presumably removed: signature taking precomputed start/end offsets.
def stitch(chunks, start, end):
# Presumably added: signature taking the chunking parameters instead.
def stitch(chunks, chunksize, overlap, length, stride):
"""
Stitch chunks together with a given overlap
"""
# Dict inputs are stitched value by value through a recursive call.
if isinstance(chunks, dict):
# Presumably removed: recursion using the old (start, end) arguments.
return {k: stitch(v, start, end) for k, v in chunks.items()}

# Presumably removed: inline stitching. A single chunk is returned with its
# leading axis squeezed; otherwise the first chunk keeps [:end], middle chunks
# keep [start:end] and the last chunk keeps [start:].
if chunks.shape[0] == 1: return chunks.squeeze(0)
return concat([chunks[0, :end], *chunks[1:-1, start:end], chunks[-1, start:]])

# Presumably added: recursion using the new arguments, and delegation of the
# non-dict case to bonito.util.stitch.
return {k: stitch(v, chunksize, overlap, length, stride) for k, v in chunks.items()}
return bonito.util.stitch(chunks, chunksize, overlap, length, stride)

def compute_scores(model, batch):
"""
Expand Down Expand Up @@ -81,35 +79,33 @@ def decode_int8(scores, seqdist, scale=127/5, beamsize=40, beamcut=100.0):
except IndexError:
return ""

def split_read(read, split_read_length):
    """
    Split a read into consecutive, non-overlapping signal ranges.

    Returns a list of `(read, start, end)` tuples where `start` and `end` are
    plain Python ints delimiting `read.signal[start:end]`. Each range spans at
    most `split_read_length` samples and together the ranges cover the whole
    signal. A read that already fits (including an empty one) yields the
    single range `(read, 0, len(read.signal))`.

    Raises ValueError if `split_read_length` is not positive.
    """
    if split_read_length <= 0:
        # A zero step cannot advance and a negative one would yield no ranges,
        # silently dropping the read, so reject both explicitly.
        raise ValueError("split_read_length must be a positive integer")

    length = len(read.signal)
    if length <= split_read_length:
        return [(read, 0, length)]

    # range() steps through the start offsets; the final range is clipped to
    # the end of the signal.
    return [
        (read, start, min(start + split_read_length, length))
        for start in range(0, length, split_read_length)
    ]


def basecall(model, reads, aligner=None, beamsize=40, chunksize=4000, overlap=500, batchsize=32, qscores=False):
"""
Basecalls a set of reads.
"""
split_read_length=400000
_stitch = partial(
stitch,
start=overlap // 2 // model.stride,
end=(chunksize - overlap // 2) // model.stride,
)
_decode = partial(decode_int8, seqdist=model.seqdist, beamsize=beamsize)
reads = (
((read, i), x) for read in reads
for (i, x) in enumerate(torch.split(torch.from_numpy(read.signal), split_read_length))
)
reads = (read_chunk for read in reads for read_chunk in split_read(read, 400000))
chunks = (
((read, chunk(signal, chunksize, overlap, pad_start=True)) for (read, signal) in reads)
((read, start, end), chunk(torch.from_numpy(read.signal[start:end]), chunksize, overlap)) for (read, start, end) in reads
)
batches = (
(read, quantise_int8(compute_scores(model, batch)))
for read, batch in thread_iter(batchify(chunks, batchsize=batchsize))
(k, quantise_int8(compute_scores(model, batch)))
for k, batch in thread_iter(batchify(chunks, batchsize=batchsize))
)
stitched = ((read, _stitch(x)) for (read, x) in unbatchify(batches))
stitched = ((read, stitch(x, chunksize, overlap, (end-start), model.stride)) for ((read, start, end), x) in unbatchify(batches))

transferred = thread_map(transfer, stitched, n_thread=1)
basecalls = thread_map(_decode, transferred, n_thread=8)

basecalls = (
(read, ''.join(seq for k, seq in parts)) for read, parts in groupby(basecalls, lambda x: x[0][0])
(read, ''.join(seq for k, seq in parts)) for read, parts in groupby(basecalls, lambda x: (x[0].parent if hasattr(x[0], 'parent') else x[0]))
)
basecalls = (
(read, {'sequence': seq, 'qstring': '?' * len(seq) if qscores else '*', 'mean_qscore': 0.0})
Expand Down
2 changes: 1 addition & 1 deletion bonito/ctc/basecall.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@ def basecall(model, reads, aligner=None, beamsize=5, chunksize=0, overlap=0, bat
(k, compute_scores(model, v)) for k, v in batchify(chunks, batchsize)
)
scores = (
(read, {'scores': stitch(v, overlap, model.stride)}) for read, v in scores
(read, {'scores': stitch(v, chunksize, overlap, len(read.signal), model.stride)}) for read, v in scores
)
decoder = partial(decode, decode=model.decode, beamsize=beamsize, qscores=qscores)
basecalls = process_map(decoder, scores, n_proc=4)
Expand Down
2 changes: 1 addition & 1 deletion bonito/fast5.py
Original file line number Diff line number Diff line change
Expand Up @@ -164,4 +164,4 @@ def get_reads(directory, read_ids=None, skip=False, max_read_size=0, n_proc=1, r
yield read

if cancel is not None and cancel.is_set():
return
return
36 changes: 20 additions & 16 deletions bonito/util.py
Original file line number Diff line number Diff line change
Expand Up @@ -147,30 +147,34 @@ def column_to_set(filename, idx=0, skip_header=False):
return {line.strip().split()[idx] for line in tsv.readlines()}


# NOTE(review): this span looks like a diff hunk rendered without +/- markers
# or indentation, so the removed and added versions of `chunk` are
# interleaved. The old/new attribution below is inferred from the commit
# title and call sites visible on this page — confirm against the repository.
# Presumably removed: signature with the pad_start flag.
def chunk(signal, chunksize, overlap, pad_start=False):
# Presumably added: signature without pad_start.
def chunk(signal, chunksize, overlap):
"""
Convert a read into overlapping chunks before calling
"""
T = signal.shape[0]
# Presumably removed body: zero-pads the signal (in front when pad_start is
# set, otherwise at the end) so it unfolds into whole chunks; chunksize == 0
# returns the whole signal as one chunk.
if chunksize > 0:
padding = chunksize - T if T < chunksize else (overlap - T) % (chunksize - overlap)
padded = torch.nn.functional.pad(signal, (padding, 0) if pad_start else (0, padding))
return padded.unfold(0, chunksize, chunksize - overlap).unsqueeze(1)
return signal.unsqueeze(0).unsqueeze(0)
# Presumably added body: chunksize == 0 keeps the whole signal as one chunk.
if chunksize == 0:
chunks = signal[None, :]
# A signal shorter than one chunk is zero-padded at the front up to chunksize.
elif T < chunksize:
chunks = torch.nn.functional.pad(signal, (chunksize - T, 0))[None, :]
else:
# stub is the number of leading samples left over once the signal is tiled
# with windows of chunksize advancing by (chunksize - overlap).
stub = (T - overlap) % (chunksize - overlap)
chunks = signal[stub:].unfold(0, chunksize, chunksize - overlap)
# With a non-zero stub, an extra first chunk taken from the very start of the
# signal is prepended, so no padding is needed to cover the leading samples.
if stub > 0:
chunks = torch.cat([signal[None, :chunksize], chunks], dim=0)
# Insert a new axis at position 1 of the result.
return chunks.unsqueeze(1)


# NOTE(review): this span looks like a diff hunk rendered without +/- markers
# or indentation, so the removed and added versions of `stitch` (including
# both docstring lines) are interleaved. The old/new attribution below is
# inferred from the call sites visible on this page — confirm against the
# repository.
# Presumably removed: signature taking only overlap and an optional stride.
def stitch(predictions, overlap, stride=1):
# Presumably added: signature that also receives chunksize and signal length.
def stitch(chunks, chunksize, overlap, length, stride):
"""
Stitch predictions together with a given overlap
Stitch chunks together with a given overlap
"""
# Presumably removed body: trims half the overlap (divided by stride) from
# every interior chunk edge and concatenates the pieces.
overlap = overlap // stride // 2
if predictions.shape[0] == 1:
return predictions.squeeze(0)
stitched = [predictions[0, 0:-overlap]]
for i in range(1, predictions.shape[0] - 1):
stitched.append(predictions[i][overlap:-overlap])
stitched.append(predictions[-1][overlap:])
return concat(stitched)
# Presumably added body: a single chunk is returned with its leading axis
# squeezed.
if chunks.shape[0] == 1: return chunks.squeeze(0)

# start/end are the half-overlap trim points divided by stride.
semi_overlap = overlap//2
start, end = semi_overlap // stride, (chunksize-semi_overlap) // stride
# Same leftover formula that chunk() applies to the signal length.
stub = (length - overlap) % (chunksize - overlap)
# With a non-zero stub the first chunk contributes only up to
# (stub + semi_overlap) // stride; presumably this matches the extra first
# chunk that chunk() prepends — TODO confirm.
first_chunk_end = (stub + semi_overlap) // stride if (stub > 0) else end
return concat([chunks[0, :first_chunk_end], *chunks[1:-1, start:end], chunks[-1, start:]])


def batchify(items, batchsize, dim=0):
Expand Down

0 comments on commit 3def77d

Please sign in to comment.