
[MRG] update gather to calculate fraction of match that was in original query #938

Merged
merged 11 commits on Apr 15, 2020
54 changes: 50 additions & 4 deletions doc/classifying-signatures.md
@@ -59,11 +59,12 @@ genome; it then subtracts that match from the metagenome, and repeats.
At the end it reports how much of the metagenome remains unknown. The
[basic sourmash
tutorial](http://sourmash.readthedocs.io/en/latest/tutorials.html#what-s-in-my-metagenome)
has some sample output from using gather with GenBank.
has some sample output from using gather with GenBank. See the appendix at
the bottom of this page for more technical details.

Our preliminary benchmarking suggests that `gather` is the most accurate
method available for doing strain-level resolution of genomes. More on that
as we move forward!
Some benchmarking on CAMI suggests that `gather` is a very accurate
method for doing strain-level resolution of genomes. More on
that as we move forward!

## To do taxonomy, or not to do taxonomy?

@@ -116,3 +117,48 @@ We suggest the following approach:
This helps us figure out what people are actually interested in doing, and
any help we provide via the issue tracker will eventually be added into the
documentation.

## Appendix: how `sourmash gather` works

The sourmash gather algorithm works as follows:

* find the best match in the database, based on containment;
* subtract that match from the query;
* repeat.
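
To make this concrete, here is a minimal sketch of the loop over plain
Python sets. The `gather_sketch` function, `query_hashes`, and `db` are
hypothetical stand-ins; the real implementation works on MinHash sketches
and indexed databases (see `gather_databases` in `sourmash/search.py` in
the diff below), and its containment scoring differs in detail.

```
# A minimal sketch of the greedy gather loop over plain Python sets.
# `db` maps match names to (non-empty) sets of hashes. "Containment" here
# is the fraction of a match's hashes present in the remaining query.
def gather_sketch(query_hashes, db, threshold=1):
    remaining = set(query_hashes)
    found = []
    while remaining and db:
        # find the best match in the database, based on containment
        name, match = max(db.items(),
                          key=lambda kv: len(kv[1] & remaining) / len(kv[1]))
        overlap = match & remaining
        if len(overlap) < threshold:
            break
        found.append((name, overlap))
        # subtract that match from the query, and repeat
        remaining -= match
    return found, remaining
```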

The output below shows gather's CSV output for a fictional metagenome.

The first column, `f_unique_to_query`, is the fraction of the original
query that is covered by the part of this match not already assigned to
a previous match. It will always decrease as you get more matches.

The second column, `f_match_orig`, is how much of the match is in the
_original_ query. For this fictional metagenome, each match is
entirely contained in the original query. This is the number you would
get by running `sourmash search --containment <match> <metagenome>`.

The third column, `f_match`, is how much of the match is in the remaining
query metagenome, after all of the previous matches have been removed.

The fourth column, `f_orig_query`, is how much of the original query
belongs to the match. This is the number you'd get by running
`sourmash search --containment <metagenome> <match>`.
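
For reference, both containment numbers can also be computed through the
Python API. This is a sketch with placeholder file names and ksize; the
`load_one_signature` and `contained_by` calls are the same ones used in
the test added at the bottom of this diff.

```
import sourmash

# Placeholder file names and ksize -- substitute your own signatures.
metagenome = sourmash.load_one_signature('metagenome.sig', ksize=31)
match = sourmash.load_one_signature('genome.sig', ksize=31)

# f_match_orig: how much of the match is contained in the original query.
print(match.contained_by(metagenome))

# f_orig_query: how much of the original query is contained in the match.
print(metagenome.contained_by(match))
```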

```
f_unique_to_query f_match_orig f_match f_orig_query
0.3321964529331514 1.0 1.0 0.3321964529331514
0.13096862210095497 1.0 1.0 0.13096862210095497
0.11527967257844475 1.0 0.898936170212766 0.12824010914051842
0.10709413369713507 1.0 1.0 0.10709413369713507
0.10368349249658936 1.0 0.3134020618556701 0.33083219645293316
```
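
Written out over sets of hashes, the four columns correspond to the
following ratios (see `gather_databases` below). This is a toy example,
with small integers standing in for MinHash values:

```
# Toy sets of integer "hashes"; real sketches hold 64-bit MinHash values.
orig_query = set(range(100))           # the original metagenome query
remaining_query = set(range(40, 100))  # the query minus earlier matches
match = set(range(30, 80))             # the current database match

f_unique_to_query = len(match & remaining_query) / len(orig_query)  # 0.4
f_match_orig = len(match & orig_query) / len(match)                 # 1.0
f_match = len(match & remaining_query) / len(match)                 # 0.8
f_orig_query = len(match & orig_query) / len(orig_query)            # 0.5
```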

A few quick notes for the algorithmic folk out there --

* the key innovation for gather is that it looks for **groups** of
k-mers in the databases, and picks the best group (based on
containment). It does not treat k-mers individually.
* because of this, gather does not saturate as databases grow in size,
and in fact should only become more sensitive and specific as we
increase database size. (Although of course it may get a lot
slower...)
8 changes: 5 additions & 3 deletions sourmash/commands.py
@@ -647,7 +647,8 @@ def gather(args):
if found and args.output:
fieldnames = ['intersect_bp', 'f_orig_query', 'f_match',
'f_unique_to_query', 'f_unique_weighted',
'average_abund', 'median_abund', 'std_abund', 'name', 'filename', 'md5']
'average_abund', 'median_abund', 'std_abund', 'name',
'filename', 'md5', 'f_match_orig']

with FileOutput(args.output, 'wt') as fp:
w = csv.DictWriter(fp, fieldnames=fieldnames)
@@ -783,8 +784,9 @@ def multigather(args):
output_csv = output_base + '.csv'

fieldnames = ['intersect_bp', 'f_orig_query', 'f_match',
'f_unique_to_query', 'f_unique_weighted',
'average_abund', 'median_abund', 'std_abund', 'name', 'filename', 'md5']
'f_unique_to_query', 'f_unique_weighted',
'average_abund', 'median_abund', 'std_abund', 'name',
'filename', 'md5', 'f_match_orig']
with open(output_csv, 'wt') as fp:
w = csv.DictWriter(fp, fieldnames=fieldnames)
w.writeheader()
52 changes: 33 additions & 19 deletions sourmash/search.py
Expand Up @@ -59,7 +59,7 @@ def search_databases(query, databases, threshold, do_containment, best_only,
###

GatherResult = namedtuple('GatherResult',
'intersect_bp, f_orig_query, f_match, f_unique_to_query, f_unique_weighted, average_abund, median_abund, std_abund, filename, name, md5, match')
'intersect_bp, f_orig_query, f_match, f_unique_to_query, f_unique_weighted, average_abund, median_abund, std_abund, filename, name, md5, match,f_match_orig')


# build a new query object, subtracting found mins and downsampling
@@ -101,21 +101,27 @@ def _find_best(dblist, query, threshold_bp):
return best_cont, best_match, best_filename


def _filter_max_hash(values, max_hash):
for v in values:
if v < max_hash:
yield v


def gather_databases(query, databases, threshold_bp, ignore_abundance):
"""
Iteratively find the best containment of `query` in all the `databases`,
until we find fewer than `threshold_bp` (estimated) bp in common.
"""
# track original query information for later usage.
track_abundance = query.minhash.track_abundance and not ignore_abundance
orig_mh = query.minhash
orig_mins = orig_mh.get_hashes()
orig_abunds = { k: 1 for k in orig_mins }
orig_query_mh = query.minhash
orig_query_mins = orig_query_mh.get_hashes()

# do we pay attention to abundances?
orig_query_abunds = { k: 1 for k in orig_query_mins }
if track_abundance:
import numpy as np
orig_abunds = orig_mh.get_mins(with_abundance=True)
orig_query_abunds = orig_query_mh.get_mins(with_abundance=True)

cmp_scaled = query.minhash.scaled # initialize with resolution of query
while query.minhash:
@@ -142,15 +148,15 @@ def gather_databases(query, databases, threshold_bp, ignore_abundance):
# (CTB note: this means that if a high scaled/low res signature is
# found early on, resolution will be low from then on.)
new_max_hash = get_max_hash_for_scaled(cmp_scaled)
query_mins = set([ i for i in query_mins if i < new_max_hash ])
found_mins = set([ i for i in found_mins if i < new_max_hash ])
orig_mins = set([ i for i in orig_mins if i < new_max_hash ])
sum_abunds = sum([ v for (k,v) in orig_abunds.items() if k < new_max_hash ])
query_mins = set(_filter_max_hash(query_mins, new_max_hash))
found_mins = set(_filter_max_hash(found_mins, new_max_hash))
orig_query_mins = set(_filter_max_hash(orig_query_mins, new_max_hash))
sum_abunds = sum(( v for (k,v) in orig_query_abunds.items() if k < new_max_hash ))

# calculate intersection:
# calculate intersection with query mins:
intersect_mins = query_mins.intersection(found_mins)
intersect_orig_mins = orig_mins.intersection(found_mins)
intersect_bp = cmp_scaled * len(intersect_orig_mins)
intersect_orig_query_mins = orig_query_mins.intersection(found_mins)
intersect_bp = cmp_scaled * len(intersect_orig_query_mins)

if intersect_bp < threshold_bp: # hard cutoff for now
notify('found less than {} in common. => exiting',
@@ -160,21 +166,28 @@ def gather_databases(query, databases, threshold_bp, ignore_abundance):
# calculate fractions wrt first denominator - genome size
genome_n_mins = len(found_mins)
f_match = len(intersect_mins) / float(genome_n_mins)
f_orig_query = len(intersect_orig_mins) / float(len(orig_mins))
f_orig_query = len(intersect_orig_query_mins) / \
float(len(orig_query_mins))

# calculate fractions wrt second denominator - metagenome size
orig_mh = orig_mh.downsample_scaled(cmp_scaled)
query_n_mins = len(orig_mh)
orig_query_mh = orig_query_mh.downsample_scaled(cmp_scaled)
query_n_mins = len(orig_query_mh)
f_unique_to_query = len(intersect_mins) / float(query_n_mins)

# calculate fraction of subject match with orig query
f_match_orig = best_match.minhash.contained_by(orig_query_mh,
downsample=True)

# calculate scores weighted by abundances
f_unique_weighted = sum((orig_abunds[k] for k in intersect_mins)) \
/ sum_abunds
f_unique_weighted = sum((orig_query_abunds[k] for k in intersect_mins))
f_unique_weighted /= sum_abunds

# calculate stats on abundances, if desired.
average_abund, median_abund, std_abund = 0, 0, 0
if track_abundance:
intersect_abunds = list((orig_abunds[k] for k in intersect_mins))
intersect_abunds = (orig_query_abunds[k] for k in intersect_mins)
intersect_abunds = list(intersect_abunds)

average_abund = np.mean(intersect_abunds)
median_abund = np.median(intersect_abunds)
std_abund = np.std(intersect_abunds)
@@ -183,6 +196,7 @@ def gather_databases(query, databases, threshold_bp, ignore_abundance):
result = GatherResult(intersect_bp=intersect_bp,
f_orig_query=f_orig_query,
f_match=f_match,
f_match_orig=f_match_orig,
f_unique_to_query=f_unique_to_query,
f_unique_weighted=f_unique_weighted,
average_abund=average_abund,
@@ -198,7 +212,7 @@ def gather_databases(query, databases, threshold_bp, ignore_abundance):

# compute weighted_missed:
query_mins -= set(found_mins)
weighted_missed = sum((orig_abunds[k] for k in query_mins)) \
weighted_missed = sum((orig_query_abunds[k] for k in query_mins)) \
/ sum_abunds

yield result, weighted_missed, new_max_hash, query
61 changes: 61 additions & 0 deletions tests/test_sourmash.py
@@ -2500,6 +2500,67 @@ def test_gather_file_output():
assert '910,1.0,1.0' in output


@utils.in_tempdir
def test_gather_f_match_orig(c):
import copy

testdata_combined = utils.get_test_data('gather/combined.sig')
testdata_glob = utils.get_test_data('gather/GCF*.sig')
testdata_sigs = glob.glob(testdata_glob)

c.run_sourmash('gather', testdata_combined, '-o', 'out.csv',
*testdata_sigs)

combined_sig = sourmash.load_one_signature(testdata_combined, ksize=21)
remaining_mh = copy.copy(combined_sig.minhash)

def approx_equal(a, b, n=5):
return round(a, n) == round(b, n)

with open(c.output('out.csv'), 'rt') as fp:
r = csv.DictReader(fp)
for n, row in enumerate(r):
print(n, row['f_match'], row['f_match_orig'])

# each match is completely in the original query
assert row['f_match_orig'] == "1.0"

# double check -- should match 'search --containment'.
# (this is kind of useless for a 1.0 contained_by, I guess)
filename = row['filename']
match = sourmash.load_one_signature(filename, ksize=21)
assert match.contained_by(combined_sig) == 1.0

# check other fields, too.
f_orig_query = float(row['f_orig_query'])
f_match_orig = float(row['f_match_orig'])
f_match = float(row['f_match'])
f_unique_to_query = float(row['f_unique_to_query'])

# f_orig_query is the containment of the query by the match.
# (note, this only works because containment is 100% in combined).
assert approx_equal(combined_sig.contained_by(match), f_orig_query)

# just redoing above, for completeness; this is always 1.0 for
# this data set.
assert approx_equal(match.contained_by(combined_sig), f_match_orig)

# f_match is how much of the match is in the unallocated hashes
assert approx_equal(match.minhash.contained_by(remaining_mh),
f_match)

# f_unique_to_query is how much of the match is unique wrt
# the original query.
a = set(remaining_mh.get_mins())
b = set(match.minhash.get_mins())
n_intersect = len(a.intersection(b))
f_intersect = n_intersect / float(len(combined_sig.minhash))
assert approx_equal(f_unique_to_query, f_intersect)

# now, subtract current match from remaining... and iterate!
remaining_mh.remove_many(match.minhash.get_mins())


def test_gather_nomatch():
with utils.TempDirectory() as location:
testdata_query = utils.get_test_data('gather/GCF_000006945.2_ASM694v2_genomic.fna.gz.sig')