[MRG] substantial refactoring of CounterGather and related Index code. #2116

Merged
merged 45 commits on Jul 16, 2022
241dbc5
move most CounterGather tests over to index protocol tests
ctb Jul 8, 2022
66490a4
add LinearIndex wrapper
ctb Jul 8, 2022
ebb00ea
getting closer
ctb Jul 8, 2022
a8a4dd9
fix a bunch of the tests
ctb Jul 8, 2022
fdc8d4f
Merge branch 'latest' of https://github.com/sourmash-bio/sourmash int…
ctb Jul 8, 2022
ba114dd
Merge branch 'latest' of https://github.com/sourmash-bio/sourmash int…
ctb Jul 9, 2022
b444a68
fix call to 'peek'
ctb Jul 9, 2022
f87c9d4
adjust 'counter.add' call signature
ctb Jul 9, 2022
68458cf
add CounterGather_LCA
ctb Jul 9, 2022
b835c96
move CounterGather.calc_threshold into search.py
ctb Jul 9, 2022
1903920
minor refactoring
ctb Jul 9, 2022
5099d5a
resolve downsampling for linear index wrapper
ctb Jul 9, 2022
a8125b4
fix downsampling for LCA-based CounterGather
ctb Jul 9, 2022
1760ada
fix location foo
ctb Jul 9, 2022
5c9748a
fix remaining test
ctb Jul 9, 2022
c2d2637
minor cleanup
ctb Jul 10, 2022
6f9eb78
add doc
ctb Jul 10, 2022
f82e1d7
test multiple identical matches
ctb Jul 10, 2022
d9472ed
adjust LinearIndex implementation to skip identical matches
ctb Jul 10, 2022
3e1c1ae
switch to dictionaries for CounterGather
ctb Jul 11, 2022
4c14e01
cleanup protocol tests
ctb Jul 11, 2022
3df8c66
revert LCA_Database fix
ctb Jul 11, 2022
36d4c2c
Merge branch 'latest' into refactor/counter_gather_tests
ctb Jul 11, 2022
846c0ba
Merge branch 'refactor/counter_gather_tests' into update/counter_gather
ctb Jul 11, 2022
39835fc
restore CounterGather_LCA
ctb Jul 11, 2022
1a4e01b
cleanup
ctb Jul 11, 2022
ee0fd18
Merge branch 'refactor/counter_gather_tests' of https://github.com/so…
ctb Jul 11, 2022
dbabfe9
Merge branch 'latest' into refactor/counter_gather_tests
ctb Jul 11, 2022
b7c37bd
Merge branch 'refactor/counter_gather_tests' into update/counter_gather
ctb Jul 11, 2022
a676a69
Merge branch 'refactor/counter_gather_tests' into update/counter_gather
ctb Jul 11, 2022
402dbc6
fix or ignore most errors ;)
ctb Jul 11, 2022
0e4ca95
rename make_gather_query to make_containment_query
ctb Jul 12, 2022
9f7a20e
rename Index.gather to Index.best_containment
ctb Jul 12, 2022
cb2efd7
consolidate threshold_bp => threshold calc code
ctb Jul 12, 2022
22aa74c
change best_containment to return None or a result object, not a list
ctb Jul 12, 2022
e9022c7
add flatten_and_* utility functions
ctb Jul 13, 2022
2ace44f
add .signatures() method to CounterGather class
ctb Jul 13, 2022
430ef9d
change CounterGather to take SourmashSignature instead of Minhash
ctb Jul 13, 2022
f97b8e8
fix test_index tests for counter
ctb Jul 13, 2022
c6078a6
Merge branch 'latest' into refactor/counter_gather_tests
ctb Jul 13, 2022
db87d5e
lightly clean up LCA_Database based counter
ctb Jul 13, 2022
889e731
Merge branch 'refactor/counter_gather_tests' into update/counter_gather
ctb Jul 13, 2022
b5e497d
Merge branch 'latest' of https://github.com/sourmash-bio/sourmash int…
ctb Jul 16, 2022
f8e2edc
add comment and test re duplicate signatures, per @bluegenes
ctb Jul 16, 2022
61624fc
fix typo
ctb Jul 16, 2022
6 changes: 3 additions & 3 deletions src/sourmash/commands.py
@@ -733,9 +733,9 @@ def gather(args):
else:
raise # re-raise other errors, if no picklist.

save_prefetch.add_many(counter.siglist)
save_prefetch.add_many(counter.signatures())
# subtract found hashes as we can.
for found_sig in counter.siglist:
for found_sig in counter.signatures():
noident_mh.remove_many(found_sig.minhash)

# optionally calculate and save prefetch csv
@@ -935,7 +935,7 @@ def multigather(args):
counters = []
for db in databases:
counter = db.counter_gather(prefetch_query, args.threshold_bp)
for found_sig in counter.siglist:
for found_sig in counter.signatures():
noident_mh.remove_many(found_sig.minhash)
counters.append(counter)

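The commands.py hunks above replace direct reads of `counter.siglist` with the new `counter.signatures()` accessor, which is backed by an md5-keyed dict rather than a list. A minimal runnable sketch of why that collapses duplicate matches (the stub classes and names are illustrative only, not sourmash's real internals):

```python
import hashlib

class StubSignature:
    """Minimal stand-in for a SourmashSignature (illustrative only)."""
    def __init__(self, name, hashes):
        self.name = name
        self.hashes = frozenset(hashes)

    def md5sum(self):
        # hash only the sketch content, so identical sketches share an md5
        data = ",".join(str(h) for h in sorted(self.hashes))
        return hashlib.md5(data.encode()).hexdigest()

class StubCounter:
    """Sketch of the dict-backed match storage used after this PR."""
    def __init__(self):
        self.siglist = {}                    # md5 -> signature (was: a list)

    def add(self, ss):
        self.siglist[ss.md5sum()] = ss       # duplicate md5s collapse here

    def signatures(self):
        "Return all signatures."
        yield from self.siglist.values()

counter = StubCounter()
a = StubSignature("genome_a", [1, 2, 3])
b = StubSignature("genome_a_copy", [1, 2, 3])    # identical sketch content
counter.add(a)
counter.add(b)
matches = list(counter.signatures())
print(len(matches))    # the two identical sketches yield a single entry
```

Because `signatures()` is a generator, callers like `save_prefetch.add_many(...)` can stream matches without materializing a list.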
142 changes: 63 additions & 79 deletions src/sourmash/index/__init__.py
@@ -39,11 +39,15 @@
from abc import abstractmethod, ABC
from collections import namedtuple, Counter

from sourmash.search import (make_jaccard_search_query, make_gather_query,
from sourmash.search import (make_jaccard_search_query,
make_containment_query,
calc_threshold_from_bp)
from sourmash.manifest import CollectionManifest
from sourmash.logging import debug_literal
from sourmash.signature import load_signatures, save_signatures
from sourmash.minhash import (flatten_and_downsample_scaled,
flatten_and_downsample_num,
flatten_and_intersect_scaled)

# generic return tuple for Index.search and Index.gather
IndexSearchResult = namedtuple('Result', 'score, signature, location')
@@ -108,7 +112,7 @@ def find(self, search_fn, query, **kwargs):

search_fn follows the protocol in JaccardSearch objects.

Returns a list.
Generator. Returns 0 or more IndexSearchResult objects.
"""
# first: is this query compatible with this search?
search_fn.check_is_compatible(query)
@@ -124,50 +128,19 @@ def find(self, search_fn, query, **kwargs):
query_scaled = query_mh.scaled

def prepare_subject(subj_mh):
assert subj_mh.scaled
if subj_mh.track_abundance:
subj_mh = subj_mh.flatten()

# downsample subject to highest scaled
subj_scaled = subj_mh.scaled
if subj_scaled < query_scaled:
return subj_mh.downsample(scaled=query_scaled)
else:
return subj_mh
return flatten_and_downsample_scaled(subj_mh, query_scaled)

def prepare_query(query_mh, subj_mh):
assert subj_mh.scaled

# downsample query to highest scaled
subj_scaled = subj_mh.scaled
if subj_scaled > query_scaled:
return query_mh.downsample(scaled=subj_scaled)
else:
return query_mh
return flatten_and_downsample_scaled(query_mh, subj_mh.scaled)

else: # num
query_num = query_mh.num

def prepare_subject(subj_mh):
assert subj_mh.num
if subj_mh.track_abundance:
subj_mh = subj_mh.flatten()

# downsample subject to smallest num
subj_num = subj_mh.num
if subj_num > query_num:
return subj_mh.downsample(num=query_num)
else:
return subj_mh
return flatten_and_downsample_num(subj_mh, query_num)

def prepare_query(query_mh, subj_mh):
assert subj_mh.num
# downsample query to smallest num
subj_num = subj_mh.num
if subj_num < query_num:
return query_mh.downsample(num=subj_num)
else:
return query_mh
return flatten_and_downsample_num(query_mh, subj_mh.num)

# now, do the search!
for subj, location in self.signatures_with_location():
@@ -195,7 +168,7 @@ def prepare_query(query_mh, subj_mh):
yield IndexSearchResult(score, subj, location)

def search_abund(self, query, *, threshold=None, **kwargs):
"""Return set of matches with angular similarity above 'threshold'.
"""Return list of IndexSearchResult with angular similarity above 'threshold'.

Results will be sorted by similarity, highest to lowest.
"""
@@ -223,7 +196,7 @@ def search_abund(self, query, *, threshold=None, **kwargs):
def search(self, query, *, threshold=None,
do_containment=False, do_max_containment=False,
best_only=False, **kwargs):
"""Return set of matches with similarity above 'threshold'.
"""Return list of IndexSearchResult with similarity above 'threshold'.

Results will be sorted by similarity, highest to lowest.

@@ -239,50 +212,55 @@ def search(self, query, *, threshold=None,
threshold = float(threshold)

search_obj = make_jaccard_search_query(do_containment=do_containment,
do_max_containment=do_max_containment,
do_max_containment=do_max_containment,
best_only=best_only,
threshold=threshold)

# do the actual search:
matches = []

for sr in self.find(search_obj, query, **kwargs):
matches.append(sr)
matches = list(self.find(search_obj, query, **kwargs))

# sort!
matches.sort(key=lambda x: -x.score)
return matches

def prefetch(self, query, threshold_bp, **kwargs):
"Return all matches with minimum overlap."
"""Return all matches with minimum overlap.

Generator. Returns 0 or more IndexSearchResult namedtuples.
"""
if not self: # empty database? quit.
raise ValueError("no signatures to search")

search_fn = make_gather_query(query.minhash, threshold_bp,
best_only=False)
# default best_only to False
best_only = kwargs.get('best_only', False)

search_fn = make_containment_query(query.minhash, threshold_bp,
best_only=best_only)

for sr in self.find(search_fn, query, **kwargs):
yield sr

def gather(self, query, threshold_bp=None, **kwargs):
"Return the match with the best Jaccard containment in the Index."
def best_containment(self, query, threshold_bp=None, **kwargs):
"""Return the match with the best Jaccard containment in the Index.

results = []
for result in self.prefetch(query, threshold_bp, **kwargs):
results.append(result)
Returns an IndexSearchResult namedtuple or None.
"""

# sort results by best score.
results.sort(reverse=True,
key=lambda x: (x.score, x.signature.md5sum()))
results = self.prefetch(query, threshold_bp, best_only=True, **kwargs)
results = sorted(results,
key=lambda x: (-x.score, x.signature.md5sum()))

return results[:1]
try:
return next(iter(results))
except StopIteration:
return None

def peek(self, query_mh, *, threshold_bp=0):
"""Mimic CounterGather.peek() on top of Index.

This is implemented for situations where we don't want to use
'prefetch' functionality. It is a light wrapper around the
'gather'/search-by-containment method.
'best_containment(...)' method.
"""
from sourmash import SourmashSignature

@@ -291,22 +269,18 @@ def peek(self, query_mh, *, threshold_bp=0):

# run query!
try:
result = self.gather(query_ss, threshold_bp=threshold_bp)
result = self.best_containment(query_ss, threshold_bp=threshold_bp)
except ValueError:
result = None

if not result:
return []

# if matches, calculate intersection & return.
sr = result[0]
match_mh = sr.signature.minhash
scaled = max(query_mh.scaled, match_mh.scaled)
match_mh = match_mh.downsample(scaled=scaled).flatten()
query_mh = query_mh.downsample(scaled=scaled)
intersect_mh = match_mh & query_mh
intersect_mh = flatten_and_intersect_scaled(result.signature.minhash,
query_mh)

return [sr, intersect_mh]
return [result, intersect_mh]

def consume(self, intersect_mh):
"Mimic CounterGather.consume on top of Index. Yes, this is backwards."
@@ -326,7 +300,7 @@ def counter_gather(self, query, threshold_bp, **kwargs):
prefetch_query.minhash = prefetch_query.minhash.flatten()

# find all matches and construct a CounterGather object.
counter = CounterGather(prefetch_query.minhash)
counter = CounterGather(prefetch_query)
for result in self.prefetch(prefetch_query, threshold_bp, **kwargs):
counter.add(result.signature, location=result.location)

@@ -721,9 +695,14 @@ class CounterGather:
This particular implementation maintains a collections.Counter that
is used to quickly find the best match when 'peek' is called, but
other implementations are possible ;).

Note that redundant matches (SourmashSignature objects) with
duplicate md5s are collapsed inside the class, because we use the
md5sum as a key into the dictionary used to store matches.
"""
def __init__(self, query_mh):
"Constructor - takes a query FracMinHash."
def __init__(self, query):
"Constructor - takes a query SourmashSignature."
query_mh = query.minhash
if not query_mh.scaled:
raise ValueError('gather requires scaled signatures')

@@ -732,8 +711,8 @@ def __init__(self, query_mh):
self.scaled = query_mh.scaled

# use these to track loaded matches & their locations
self.siglist = []
self.locations = []
self.siglist = {}
self.locations = {}

# ...and also track overlaps with the progressive query
self.counter = Counter()
@@ -749,11 +728,11 @@ def add(self, ss, *, location=None, require_overlap=True):
# upon insertion, count & track overlap with the specific query.
overlap = self.orig_query_mh.count_common(ss.minhash, True)
if overlap:
i = len(self.siglist)
md5 = ss.md5sum()

self.counter[i] = overlap
self.siglist.append(ss)
self.locations.append(location)
self.counter[md5] = overlap
self.siglist[md5] = ss
self.locations[md5] = location

# note: scaled will be max of all matches.
self.downsample(ss.minhash.scaled)
@@ -766,6 +745,11 @@ def downsample(self, scaled):
self.scaled = scaled
return self.scaled

def signatures(self):
"Return all signatures."
for ss in self.siglist.values():
yield ss

def peek(self, cur_query_mh, *, threshold_bp=0):
"Get next 'gather' result for this database, w/o changing counters."
self.query_started = 1
@@ -789,11 +773,11 @@ def peek(self, cur_query_mh, *, threshold_bp=0):
raise ValueError("current query not a subset of original query")

# are we setting a threshold?
threshold, n_threshold_hashes = calc_threshold_from_bp(threshold_bp,
scaled,
len(cur_query_mh))
# is it too high to ever match? if so, exit.
if threshold > 1.0:
try:
x = calc_threshold_from_bp(threshold_bp, scaled, len(cur_query_mh))
threshold, n_threshold_hashes = x
except ValueError:
# too high to ever match => exit
return []

# Find the best match using the internal Counter.
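The index/__init__.py hunks above switch `CounterGather` from parallel lists with integer indices to dicts keyed by each match's md5sum, with a `collections.Counter` tracking overlap per md5 so `peek` can find the best remaining match cheaply. A runnable sketch of that bookkeeping using hypothetical md5 strings and overlap counts in place of real FracMinHash comparisons:

```python
from collections import Counter

# md5-keyed bookkeeping, mirroring CounterGather.add() after this PR
overlaps = Counter()     # md5 -> hashes shared with the original query
siglist = {}             # md5 -> signature object
locations = {}           # md5 -> where the match was loaded from

# (md5, location, overlap) triples; values here are made up for illustration
matches = [
    ("md5_a", "genome_a.sig", 150),
    ("md5_b", "genome_b.sig", 300),
    ("md5_c", "genome_c.sig", 75),
]

for md5, location, overlap in matches:
    if overlap:                      # only track matches that overlap the query
        overlaps[md5] = overlap
        siglist[md5] = f"<signature {md5}>"
        locations[md5] = location

# peek(): Counter.most_common(1) yields the best match without mutating state;
# consume() would then decrement counts as query hashes are subtracted.
best_md5, best_overlap = overlaps.most_common(1)[0]
print(best_md5, best_overlap)
```

Keying on md5 is what makes redundant signatures collapse, as the new class docstring notes; the list-based version would have stored both copies under different integer indices.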
33 changes: 33 additions & 0 deletions src/sourmash/minhash.py
@@ -103,6 +103,39 @@ def translate_codon(codon):
raise ValueError(e.message)


def flatten_and_downsample_scaled(mh, *scaled_vals):
"Flatten MinHash object and downsample to max of scaled values."
assert mh.scaled
assert all( (x > 0 for x in scaled_vals) )

mh = mh.flatten()
scaled = max(scaled_vals)
if scaled > mh.scaled:
return mh.downsample(scaled=scaled)
return mh


def flatten_and_downsample_num(mh, *num_vals):
"Flatten MinHash object and downsample to min of num values."
assert mh.num
assert all( (x > 0 for x in num_vals) )

mh = mh.flatten()
num = min(num_vals)
if num < mh.num:
return mh.downsample(num=num)
return mh


def flatten_and_intersect_scaled(mh1, mh2):
"Flatten and downsample two scaled MinHash objs, then return intersection."
scaled = max(mh1.scaled, mh2.scaled)
mh1 = mh1.flatten().downsample(scaled=scaled)
mh2 = mh2.flatten().downsample(scaled=scaled)

return mh1 & mh2


class _HashesWrapper(Mapping):
"A read-only view of the hashes contained by a MinHash object."
def __init__(self, h):
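The new minhash.py helpers centralize the flatten-then-downsample pattern that `Index.find` previously spelled out inline for both query and subject. Here is the scaled variant made runnable against a stub, since constructing a real sourmash MinHash needs more setup; `StubMinHash` is purely illustrative:

```python
class StubMinHash:
    """Tiny stand-in for a scaled MinHash sketch (illustrative only)."""
    def __init__(self, scaled, track_abundance=False):
        self.scaled = scaled
        self.track_abundance = track_abundance

    def flatten(self):
        # drop abundance tracking, keep the same resolution
        return StubMinHash(self.scaled, track_abundance=False)

    def downsample(self, *, scaled):
        # scaled sketches can only be downsampled to a coarser resolution
        assert scaled >= self.scaled
        return StubMinHash(scaled, self.track_abundance)

def flatten_and_downsample_scaled(mh, *scaled_vals):
    "Flatten MinHash object and downsample to max of scaled values."
    assert mh.scaled
    assert all(x > 0 for x in scaled_vals)

    mh = mh.flatten()
    scaled = max(scaled_vals)
    if scaled > mh.scaled:
        return mh.downsample(scaled=scaled)
    return mh

subj = StubMinHash(scaled=1000, track_abundance=True)
out = flatten_and_downsample_scaled(subj, 2000)
print(out.scaled, out.track_abundance)    # coarsened to 2000, abundance dropped
```

Taking the max of the scaled values is the key invariant: two scaled sketches are only comparable at the coarser (larger-scaled) resolution, which is why `flatten_and_intersect_scaled` does the same before intersecting.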
26 changes: 9 additions & 17 deletions src/sourmash/search.py
@@ -20,13 +20,19 @@ def calc_threshold_from_bp(threshold_bp, scaled, query_size):
n_threshold_hashes = 0

if threshold_bp:
if threshold_bp < 0:
raise TypeError("threshold_bp must be non-negative")

# if we have a threshold_bp of N, then that amounts to N/scaled
# hashes:
n_threshold_hashes = float(threshold_bp) / scaled

# that then requires the following containment:
threshold = n_threshold_hashes / query_size

# is it too high to ever match?
if threshold > 1.0:
raise ValueError("requested threshold_bp is unattainable with this query")
return threshold, n_threshold_hashes


@@ -62,8 +68,8 @@ def make_jaccard_search_query(*,
return search_obj


def make_gather_query(query_mh, threshold_bp, *, best_only=True):
"Make a search object for gather."
def make_containment_query(query_mh, threshold_bp, *, best_only=True):
"Make a search object for containment, with threshold_bp."
if not query_mh:
raise ValueError("query is empty!?")

@@ -72,21 +78,7 @@ def make_gather_query(query_mh, threshold_bp, *, best_only=True):
raise TypeError("query signature must be calculated with scaled")

# are we setting a threshold?
threshold = 0
if threshold_bp:
if threshold_bp < 0:
raise TypeError("threshold_bp must be non-negative")

# if we have a threshold_bp of N, then that amounts to N/scaled
# hashes:
n_threshold_hashes = threshold_bp / scaled

# that then requires the following containment:
threshold = n_threshold_hashes / len(query_mh)

# is it too high to ever match? if so, exit.
if threshold > 1.0:
raise ValueError("requested threshold_bp is unattainable with this query")
threshold, _ = calc_threshold_from_bp(threshold_bp, scaled, len(query_mh))

if best_only:
search_obj = JaccardSearchBestOnly(SearchType.CONTAINMENT,
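The search.py hunks consolidate the threshold arithmetic into `calc_threshold_from_bp`, which now raises `ValueError` when the requested `threshold_bp` is unattainable instead of each caller checking `threshold > 1.0` separately. A self-contained copy of the arithmetic following the function body shown above, with a worked example:

```python
def calc_threshold_from_bp(threshold_bp, scaled, query_size):
    "Convert a threshold in bp into a containment threshold for this query."
    threshold = 0.0
    n_threshold_hashes = 0

    if threshold_bp:
        if threshold_bp < 0:
            raise TypeError("threshold_bp must be non-negative")

        # a threshold of N bp amounts to N/scaled hashes:
        n_threshold_hashes = float(threshold_bp) / scaled

        # ...which requires this containment fraction of the query:
        threshold = n_threshold_hashes / query_size

        # is it too high to ever match? if so, error out here,
        # rather than in every caller:
        if threshold > 1.0:
            raise ValueError(
                "requested threshold_bp is unattainable with this query")
    return threshold, n_threshold_hashes

# worked example: 50 kb at scaled=1000 => 50 hashes; against a
# 100-hash query that requires containment of 0.5
print(calc_threshold_from_bp(50_000, 1000, 100))    # (0.5, 50.0)
```

This is why the new `CounterGather.peek` wraps the call in `try/except ValueError` and returns `[]`: the exception now encodes "no possible match at this threshold".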