
Commit 78b0e01

Merge pull request #86 from sirelkhatim/MemcpyAsync
Asynchronous DtoH copying of the inference results
2 parents 1d22a03 + 82dc81d commit 78b0e01

1 file changed: bonito/crf/basecall.py (+7, -5)
@@ -60,10 +60,12 @@ def transfer(x):
     """
     Device to host transfer using pinned memory.
     """
-    return {
-        k: torch.empty(v.shape, pin_memory=True, dtype=v.dtype).copy_(v).numpy()
-        for k, v in x.items()
-    }
+    torch.cuda.synchronize()
+    with torch.cuda.stream(torch.cuda.Stream()):
+        return {
+            k: torch.empty(v.shape, pin_memory=True, dtype=v.dtype).copy_(v).numpy()
+            for k, v in x.items()
+        }
 
 
 def decode_int8(scores, seqdist, scale=127/5, beamsize=40, beamcut=100.0):
@@ -103,7 +105,7 @@ def basecall(model, reads, aligner=None, beamsize=40, chunksize=4000, overlap=50
         for read, batch in thread_iter(batchify(chunks, batchsize=batchsize))
     )
     stitched = ((read, _stitch(x)) for (read, x) in unbatchify(batches))
-    transferred = thread_map(transfer, stitched, n_thread=8, preserve_order=True)
+    transferred = thread_map(transfer, stitched, n_thread=1, preserve_order=True)
     basecalls = thread_map(_decode, transferred, n_thread=8, preserve_order=True)
 
     basecalls = (
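
Below is a minimal, self-contained sketch of the transfer pattern this commit introduces, assuming a CUDA-capable PyTorch build. The function name transfer_to_host and the example input dict are illustrative only, not part of bonito's API.

# Sketch of the pinned-memory device-to-host transfer shown in the diff above.
# `transfer_to_host` is a hypothetical name used here for illustration.
import torch


def transfer_to_host(results):
    """Copy a dict of CUDA tensors into pinned host memory and return numpy arrays."""
    # Wait for the kernels that produced these tensors to finish.
    torch.cuda.synchronize()
    # Issue the copies on a fresh side stream so they are not queued behind
    # later work submitted to the default stream.
    with torch.cuda.stream(torch.cuda.Stream()):
        return {
            k: torch.empty(v.shape, pin_memory=True, dtype=v.dtype).copy_(v).numpy()
            for k, v in results.items()
        }


if __name__ == "__main__" and torch.cuda.is_available():
    scores = {"scores": torch.randn(1000, 64, device="cuda")}
    host = transfer_to_host(scores)
    print({k: (v.shape, v.dtype) for k, v in host.items()})

Note that the second hunk also drops the transfer stage from 8 worker threads to 1 (n_thread=1) while decoding keeps 8, so the host copies are now issued from a single thread.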
