MRG: fix gather memory usage issue by not accumulating GatherResult (#2962)

This is kind of a patch-fix for #2950 for `sourmash gather` specifically.

This PR changes `sourmash gather` and `sourmash multigather` so that
they no longer store any `GatherResult` objects, thus decreasing memory
usage substantially.

The solution is hacky at several levels, including storing a CSV file in
memory rather than writing it progressively. But I think it's an
important fix to get in, since `gather` is one of our main use cases and
it's causing people some problems (including me) :(.
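
For illustration, here is a minimal sketch of that in-memory CSV pattern (the field names and stand-in results are placeholders, not sourmash's real `GatherResult` schema):

```python
import csv
import io
from types import SimpleNamespace

# stand-in results; in sourmash these are GatherResult objects
results = [SimpleNamespace(name="genome_A", intersect_bp=500000),
           SimpleNamespace(name="genome_B", intersect_bp=120000)]

csv_buf = io.StringIO()
writer = None

for result in results:
    # reduce each result to plain values as it streams by...
    row = {"name": result.name, "intersect_bp": result.intersect_bp}
    if writer is None:
        writer = csv.DictWriter(csv_buf, fieldnames=list(row))
        writer.writeheader()
    writer.writerow(row)  # ...so the result object itself is never retained

# after the loop, dump the buffered CSV to its real destination
with open("gather_output.csv", "w", newline="") as fp:
    fp.write(csv_buf.getvalue())
```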

The PR also changes `--save-matches` so that it writes out sketches as
they are encountered. This breaks our semantic-versioning promises a
little bit: the target file for `--save-matches` is now opened before
any matches are found, and thus may be left empty, and may also
overwrite existing files unnecessarily.
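
In miniature, the new `--save-matches` flow is the pattern below (a sketch based on the diff; the `sourmash.sourmash_args` import path is my assumption, while `SaveSignaturesToLocation`, `.__enter__()`, `.add()`, and `.close()` are used as in the diff):

```python
from sourmash.sourmash_args import SaveSignaturesToLocation

def stream_matches(gather_iter, save_path):
    """Write each match out as soon as it is found."""
    saver = SaveSignaturesToLocation(save_path)
    save_sig = saver.__enter__()  # opened eagerly, before any matches exist
    try:
        for result in gather_iter:
            save_sig.add(result.match)  # stream each match; keep no list
    finally:
        saver.close()  # close even on early exit or error
```

Calling `__enter__()` directly instead of using a `with` block lets the saver be opened in one place and closed in a `finally` elsewhere, at the cost of the empty-file behavior noted above.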

Ultimately, a better fix is needed - probably one that changes up the
dataclasses so that they don't store MinHashes - but such a fix is
beyond me at the moment.
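
For concreteness only, such a slimmed-down record might look something like the hypothetical dataclass below; this is not code from the PR or from sourmash.

```python
from dataclasses import dataclass

# hypothetical: a result record holding only plain scalars, so that no
# MinHash objects are pinned in memory for the lifetime of the results
@dataclass
class SlimGatherResult:
    match_name: str
    match_md5: str
    intersect_bp: int
    f_unique_to_query: float
    f_unique_weighted: float
    f_match: float
```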

## benchmarking

with latest @ e2c199f: 645 MB max RSS

```
        Command being timed: "sourmash gather /home/ctbrown/transfer/SRR606249.trim.k31.sig.gz /home/ctbrown/transfer/podar-ref.zip -o xxx.csv"
        User time (seconds): 48.51
        System time (seconds): 1.15
        Percent of CPU this job got: 99%
        Elapsed (wall clock) time (h:mm:ss or m:ss): 0:49.91
        Average shared text size (kbytes): 0
        Average unshared data size (kbytes): 0
        Average stack size (kbytes): 0
        Average total size (kbytes): 0
        Maximum resident set size (kbytes): 644900
        Average resident set size (kbytes): 0
        Major (requiring I/O) page faults: 156
        Minor (reclaiming a frame) page faults: 254494
        Voluntary context switches: 2412
        Involuntary context switches: 2749
        Swaps: 0
        File system inputs: 31488
        File system outputs: 64
        Socket messages sent: 0
        Socket messages received: 0
        Signals delivered: 0
        Page size (bytes): 4096
        Exit status: 0
```
with this branch: 215 MB max RSS
```
        Command being timed: "sourmash gather /home/ctbrown/transfer/SRR606249.trim.k31.sig.gz /home/ctbrown/transfer/podar-ref.zip -o xxx.csv"
        User time (seconds): 43.38
        System time (seconds): 0.89
        Percent of CPU this job got: 97%
        Elapsed (wall clock) time (h:mm:ss or m:ss): 0:45.58
        Average shared text size (kbytes): 0
        Average unshared data size (kbytes): 0
        Average stack size (kbytes): 0
        Average total size (kbytes): 0
        Maximum resident set size (kbytes): 215560
        Average resident set size (kbytes): 0
        Major (requiring I/O) page faults: 773
        Minor (reclaiming a frame) page faults: 148722
        Voluntary context switches: 3884
        Involuntary context switches: 6174
        Swaps: 0
        File system inputs: 151648
        File system outputs: 160
        Socket messages sent: 0
        Socket messages received: 0
        Signals delivered: 0
        Page size (bytes): 4096
        Exit status: 0
```
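
Output in this format comes from GNU time's verbose mode, so (assuming `/usr/bin/time` here is GNU time) each run was presumably measured along these lines; the 645 MB and 215 MB figures above are the "Maximum resident set size" values converted from kbytes:

```
/usr/bin/time -v sourmash gather /home/ctbrown/transfer/SRR606249.trim.k31.sig.gz \
    /home/ctbrown/transfer/podar-ref.zip -o xxx.csv
```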
ctb authored Jan 31, 2024
1 parent e2c199f commit 0827e8a
Showing 1 changed file with 107 additions and 72 deletions.

src/sourmash/commands.py
@@ -6,6 +6,7 @@
import os.path
import sys
import shutil
import io

import screed
from .compare import (compare_all_pairs, compare_serial_containment,
@@ -829,7 +830,7 @@ def gather(args):
## ok! now do gather -
notify("Doing gather to generate minimum metagenome cover.")

found = []
found = 0
weighted_missed = 1
is_abundance = query.minhash.track_abundance and not args.ignore_abundance
orig_query_mh = query.minhash
@@ -845,39 +846,71 @@ def gather(args):
screen_width = _get_screen_width()
sum_f_uniq_found = 0.
result = None
for result in gather_iter:
sum_f_uniq_found += result.f_unique_to_query

if not len(found): # first result? print header.
if is_abundance:
print_results("")
print_results("overlap p_query p_match avg_abund")
print_results("--------- ------- ------- ---------")
else:
print_results("")
print_results("overlap p_query p_match")
print_results("--------- ------- -------")
### open output handles as needed for (1) saving CSV (2) saving matches

# save matching signatures?
if args.save_matches:
notify(f"saving all matches to '{args.save_matches}'")
save_sig_obj = SaveSignaturesToLocation(args.save_matches)
save_sig = save_sig_obj.__enter__()
else:
save_sig_obj = None
save_sig = None

# save CSV?
csv_outfp = io.StringIO()
csv_writer = None

# print interim result & save in `found` list for later use
pct_query = '{:.1f}%'.format(result.f_unique_weighted*100)
pct_genome = '{:.1f}%'.format(result.f_match*100)
try:
for result in gather_iter:
found += 1
sum_f_uniq_found += result.f_unique_to_query

if is_abundance:
name = result.match._display_name(screen_width - 41)
average_abund ='{:.1f}'.format(result.average_abund)
print_results('{:9} {:>7} {:>7} {:>9} {}',
format_bp(result.intersect_bp), pct_query, pct_genome,
average_abund, name)
else:
name = result.match._display_name(screen_width - 31)
print_results('{:9} {:>7} {:>7} {}',
format_bp(result.intersect_bp), pct_query, pct_genome,
name)
found.append(result)
if found == 1: # first result? print header.
if is_abundance:
print_results("")
print_results("overlap p_query p_match avg_abund")
print_results("--------- ------- ------- ---------")
else:
print_results("")
print_results("overlap p_query p_match")
print_results("--------- ------- -------")

if args.num_results and len(found) >= args.num_results:
break

# print interim result & save in `found` list for later use
pct_query = '{:.1f}%'.format(result.f_unique_weighted*100)
pct_genome = '{:.1f}%'.format(result.f_match*100)

if is_abundance:
name = result.match._display_name(screen_width - 41)
average_abund ='{:.1f}'.format(result.average_abund)
print_results('{:9} {:>7} {:>7} {:>9} {}',
format_bp(result.intersect_bp), pct_query, pct_genome,
average_abund, name)
else:
name = result.match._display_name(screen_width - 31)
print_results('{:9} {:>7} {:>7} {}',
format_bp(result.intersect_bp), pct_query, pct_genome,
name)

# write out CSV
if args.output:
if csv_writer is None:
csv_writer = result.init_dictwriter(csv_outfp)
result.write(csv_writer)

# save matches?
if save_sig is not None:
save_sig.add(result.match)

if args.num_results and found >= args.num_results:
break
finally:
if save_sig_obj:
save_sig_obj.close()
save_sig_obj = None
save_sig = None

# report on thresholding -
if gather_iter.query:
@@ -886,8 +919,8 @@

# basic reporting:
if found:
print_results(f'\nfound {len(found)} matches total;')
if len(found) == args.num_results:
print_results(f'\nfound {found} matches total;')
if found == args.num_results:
print_results(f'(truncated gather because --num-results={args.num_results})')
else:
display_bp = format_bp(args.threshold_bp)
@@ -908,18 +941,7 @@
# save CSV?
if (found and args.output) or args.create_empty_results:
with FileOutputCSV(args.output) as fp:
w = None
for result in found:
if w is None:
w = result.init_dictwriter(fp)
result.write(w)

# save matching signatures?
if found and args.save_matches:
notify(f"saving all matches to '{args.save_matches}'")
with SaveSignaturesToLocation(args.save_matches) as save_sig:
for sr in found:
save_sig.add(sr.match)
fp.write(csv_outfp.getvalue())

# save unassigned hashes?
if args.output_unassigned:
@@ -1027,7 +1049,7 @@ def multigather(args):
noident_mh.remove_many(union_found)
ident_mh.add_many(union_found)

found = []
found = 0
weighted_missed = 1
is_abundance = query.minhash.track_abundance and not args.ignore_abundance
orig_query_mh = query.minhash
@@ -1040,9 +1062,32 @@
screen_width = _get_screen_width()
sum_f_uniq_found = 0.
result = None

query_filename = query.filename
if not query_filename:
# use md5sum if query.filename not properly set
query_filename = query.md5sum()

output_base = os.path.basename(query_filename)
if args.output_dir:
output_base = os.path.join(args.output_dir, output_base)
output_csv = output_base + '.csv'

output_matches = output_base + '.matches.sig'
save_sig_obj = SaveSignaturesToLocation(output_matches)
save_sig = save_sig_obj.__enter__()
notify(f"saving all matching signatures to '{output_matches}'")

# track matches
notify(f'saving all CSV matches to "{output_csv}"')
csv_out_obj = FileOutputCSV(output_csv)
csv_outfp = csv_out_obj.__enter__()
csv_writer = None

for result in gather_iter:
found += 1
sum_f_uniq_found += result.f_unique_to_query
if not len(found): # first result? print header.
if found == 1: # first result? print header.
if is_abundance:
print_results("")
print_results("overlap p_query p_match avg_abund")
@@ -1068,7 +1113,13 @@
print_results('{:9} {:>7} {:>7} {}',
format_bp(result.intersect_bp), pct_query, pct_genome,
name)
found.append(result)

## @CTB
if csv_writer is None:
csv_writer = result.init_dictwriter(csv_outfp)
result.write(csv_writer)

save_sig.add(result.match)

# check for size estimation accuracy, which impacts ANI estimation
if not size_may_be_inaccurate and result.size_may_be_inaccurate:
@@ -1080,7 +1131,14 @@
notify(f'found less than {format_bp(args.threshold_bp)} in common. => exiting')

# basic reporting
print_results('\nfound {} matches total;', len(found))
print_results('\nfound {} matches total;', found)

# close saving etc.
save_sig_obj.close()
save_sig_obj = save_sig = None

csv_out_obj.close()
csv_out_obj = csv_outfp = csv_writer = None

if is_abundance and result:
p_covered = result.sum_weighted_found / result.total_weighted_hashes
@@ -1090,33 +1148,10 @@
print_results(f'the recovered matches hit {sum_f_uniq_found*100:.1f}% of the query k-mers (unweighted).')
print_results('')

if not found:
if found == 0:
notify('nothing found... skipping.')
continue

query_filename = query.filename
if not query_filename:
# use md5sum if query.filename not properly set
query_filename = query.md5sum()

output_base = os.path.basename(query_filename)
if args.output_dir:
output_base = os.path.join(args.output_dir, output_base)
output_csv = output_base + '.csv'

notify(f'saving all CSV matches to "{output_csv}"')
w = None
with FileOutputCSV(output_csv) as fp:
for result in found:
if w is None:
w = result.init_dictwriter(fp)
result.write(w)

output_matches = output_base + '.matches.sig'
with SaveSignaturesToLocation(output_matches) as save_sig:
notify(f"saving all matching signatures to '{output_matches}'")
save_sig.add_many([ r.match for r in found ])

output_unassigned = output_base + '.unassigned.sig'
with open(output_unassigned, 'wt') as fp:
remaining_query = gather_iter.query
@@ -1129,7 +1164,7 @@
abund_query_mh = remaining_query.minhash.inflate(orig_query_mh)
remaining_query.minhash = abund_query_mh

if not found:
if found == 0:
notify('nothing found - entire query signature unassigned.')
elif not remaining_query:
notify('no unassigned hashes! not saving.')
