Merge pull request #34 from TRON-Bioinformatics/develop
Split the computation of the cooccurrence matrix into a separate operation
priesgo authored Jun 4, 2022
2 parents ced1680 + f3f20d7 commit d337730
Showing 12 changed files with 195 additions and 66 deletions.
2 changes: 1 addition & 1 deletion covigator/__init__.py
@@ -1,4 +1,4 @@
VERSION = "v0.6.6"
VERSION = "v0.7.0"
ANALYSIS_PIPELINE_VERSION = "v0.9.3"

MISSENSE_VARIANT = "missense_variant"
24 changes: 24 additions & 0 deletions covigator/command_line.py
@@ -1,5 +1,6 @@
from argparse import ArgumentParser

from covigator.precomputations.load_cooccurrences import CooccurrenceMatrixLoader
from covigator.precomputations.loader import PrecomputationsLoader
from dask.distributed import Client
from dask_jobqueue import SLURMCluster
@@ -171,3 +172,26 @@ def precompute_queries():
logger.info("Starting precomputation...")
loader.load()
logger.info("Done precomputing")


def cooccurrence():
parser = ArgumentParser(description="Precompute cooccurrence of mutations")
parser.add_argument(
"--source",
dest="data_source",
help="Specify data source. This can be either ENA or GISAID",
required=True
)
parser.add_argument(
"--maximum-mutation-length",
dest="maximum_length",
help="Only mutations with this maximum size will be included in the cooccurence matrix",
default=10
)
args = parser.parse_args()

database = Database(initialize=True, config=Configuration())
loader = CooccurrenceMatrixLoader(session=database.get_database_session())
logger.info("Starting precomputation...")
loader.load(data_source=args.data_source, maximum_length=args.maximum_length)
logger.info("Done precomputing")
13 changes: 13 additions & 0 deletions covigator/database/model.py
@@ -19,6 +19,7 @@ def get_table_versioned_name(basename, config: Configuration):
LOG_TABLE_NAME = get_table_versioned_name('log', config=config)
LAST_UPDATE_TABLE_NAME = get_table_versioned_name('last_update', config=config)
VARIANT_COOCCURRENCE_TABLE_NAME = get_table_versioned_name('variant_cooccurrence', config=config)
GISAID_VARIANT_COOCCURRENCE_TABLE_NAME = get_table_versioned_name('gisaid_variant_cooccurrence', config=config)
VARIANT_OBSERVATION_TABLE_NAME = get_table_versioned_name('variant_observation', config=config)
SUBCLONAL_VARIANT_OBSERVATION_TABLE_NAME = get_table_versioned_name('subclonal_variant_observation', config=config)
LOW_FREQUENCY_VARIANT_OBSERVATION_TABLE_NAME = get_table_versioned_name('low_frequency_variant_observation', config=config)
@@ -811,6 +812,18 @@ class VariantCooccurrence(Base):
ForeignKeyConstraint([variant_id_two], [Variant.variant_id])


class GisaidVariantCooccurrence(Base):

__tablename__ = GISAID_VARIANT_COOCCURRENCE_TABLE_NAME

variant_id_one = Column(String, primary_key=True)
variant_id_two = Column(String, primary_key=True)
count = Column(Integer, default=0)

ForeignKeyConstraint([variant_id_one], [GisaidVariant.variant_id])
ForeignKeyConstraint([variant_id_two], [GisaidVariant.variant_id])


class CovigatorModule(enum.Enum):

__constraint_name__ = COVIGATOR_MODULE_CONSTRAINT_NAME
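With `GisaidVariantCooccurrence` mirroring the ENA table, GISAID co-occurrence counts can be queried directly. A minimal sketch, assuming an open SQLAlchemy `session` against the CoVigator schema (not part of this commit):

```python
# Minimal sketch: the ten most frequent co-occurring GISAID mutation pairs,
# excluding the diagonal (a variant paired with itself). Assumes an open
# SQLAlchemy `session`.
from sqlalchemy import desc
from covigator.database.model import GisaidVariantCooccurrence

top_pairs = (
    session.query(GisaidVariantCooccurrence)
    .filter(GisaidVariantCooccurrence.variant_id_one != GisaidVariantCooccurrence.variant_id_two)
    .order_by(desc(GisaidVariantCooccurrence.count))
    .limit(10)
    .all()
)
for pair in top_pairs:
    print(pair.variant_id_one, pair.variant_id_two, pair.count)
```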
98 changes: 90 additions & 8 deletions covigator/database/queries.py
@@ -2,6 +2,7 @@
from typing import List, Union
import pandas as pd
from logzero import logger
import sqlalchemy
from sqlalchemy import and_, desc, asc, func, String, DateTime
from sqlalchemy.engine.default import DefaultDialect
from sqlalchemy.orm import Session, aliased
@@ -12,7 +13,7 @@
SubclonalVariantObservation, PrecomputedVariantsPerSample, PrecomputedSubstitutionsCounts, PrecomputedIndelLength, \
VariantType, PrecomputedAnnotation, PrecomputedOccurrence, PrecomputedTableCounts, \
PrecomputedVariantAbundanceHistogram, PrecomputedSynonymousNonSynonymousCounts, RegionType, Domain, \
GisaidVariantObservation, GisaidVariant, LastUpdate
GisaidVariantObservation, GisaidVariant, LastUpdate, GisaidVariantCooccurrence
from covigator.exceptions import CovigatorQueryException, CovigatorDashboardMissingPrecomputedData


@@ -51,6 +52,16 @@ def get_sample_klass(source: str):
raise CovigatorQueryException("Bad data source: {}".format(source))
return klass

@staticmethod
def get_variant_cooccurrence_klass(source: str):
if source == DataSource.ENA.name:
klass = VariantCooccurrence
elif source == DataSource.GISAID.name:
klass = GisaidVariantCooccurrence
else:
raise CovigatorQueryException("Bad data source: {}".format(source))
return klass
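`get_variant_cooccurrence_klass` dispatches on the data source name, like the existing `get_sample_klass`. A minimal usage sketch (the helper is a `@staticmethod`, so no session is required):

```python
# Minimal sketch of the new dispatch helper; "ENA" and "GISAID" are the
# DataSource enum names used throughout the codebase.
from covigator.database.queries import Queries

klass = Queries.get_variant_cooccurrence_klass("GISAID")   # -> GisaidVariantCooccurrence
klass = Queries.get_variant_cooccurrence_klass("ENA")      # -> VariantCooccurrence
```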

def find_job_by_accession_and_status(
self, run_accession: str, status: JobStatus, data_source: DataSource) -> Union[SampleEna, SampleGisaid]:
klass = self.get_sample_klass(source=data_source.name)
@@ -324,16 +335,33 @@ def get_non_synonymous_variants_by_region(self, start, end, source) -> pd.DataFr
return pd.read_sql(
self.session.query(subquery).filter(subquery.c.count_occurrences > 1).statement, self.session.bind)

def get_variants_by_sample(self, sample_id, source: str) -> List[VariantObservation]:
def get_variant_ids_by_sample(self, sample_id, source: str, maximum_length: int) -> List[str]:
klass = self.get_variant_observation_klass(source=source)
return self.session.query(klass) \
.filter(klass.sample == sample_id).order_by(klass.position, klass.reference, klass.alternate).all()
return self.session.query(klass.variant_id) \
.filter(and_(klass.sample == sample_id, klass.length < maximum_length, klass.length > -maximum_length)) \
.order_by(klass.position, klass.reference, klass.alternate) \
.all()

def get_variant_cooccurrence(self, variant_one: Variant, variant_two: Variant) -> VariantCooccurrence:
return self.session.query(VariantCooccurrence) \
.filter(and_(VariantCooccurrence.variant_id_one == variant_one.variant_id,
VariantCooccurrence.variant_id_two == variant_two.variant_id)) \
def increment_variant_cooccurrence(
self, variant_id_one: str, variant_id_two: str, source: str):

klazz = self.get_variant_cooccurrence_klass(source=source)
variant_cooccurrence = self.session.query(klazz) \
.filter(and_(klazz.variant_id_one == variant_id_one,
klazz.variant_id_two == variant_id_two)) \
.first()
if variant_cooccurrence is None:
variant_cooccurrence = klazz(
variant_id_one=variant_id_one,
variant_id_two=variant_id_two,
count=1)
self.session.add(variant_cooccurrence)
else:
# NOTE: it is important to increase the counter like this to avoid race conditions
# the increase happens in the database server and not in python
# see https://stackoverflow.com/questions/2334824/how-to-increase-a-counter-in-sqlalchemy
variant_cooccurrence.count = klazz.count + 1
self.session.commit()
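The comment above is the key design point: assigning the column expression `klazz.count + 1` makes SQLAlchemy render the addition inside the UPDATE statement, so concurrent workers cannot overwrite each other's increments. A minimal sketch of the two styles, assuming a persistent `VariantCooccurrence` row and its `session`:

```python
# Minimal sketch contrasting the two increment styles; `row` is assumed to be
# a VariantCooccurrence instance already loaded in `session`.
from covigator.database.model import VariantCooccurrence

# Race-prone: the current value is read into Python and value + 1 is written
# back, so two concurrent workers can both read 5 and both write 6.
row.count = row.count + 1

# Server-side: a SQL expression is assigned, so the UPDATE is rendered roughly
# as "SET count = count + 1" and the addition happens in the database.
row.count = VariantCooccurrence.count + 1
session.commit()
```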

def count_samples(self, source: str, cache=True) -> int:
self._assert_data_source(source)
@@ -700,3 +728,57 @@ class LiteralDialect(DefaultDialect):
logger.info(query.statement.compile(
dialect=LiteralDialect(),
compile_kwargs={'literal_binds': True}).string)

def column_windows(self, session, column, windowsize):
"""Return a series of WHERE clauses against
a given column that break it into windows.
The result is an iterable of WHERE clauses, one per window.
Requires a database that supports window functions,
e.g. PostgreSQL, SQL Server or Oracle.
Enhance this yourself! Add a "where" argument
so that windows of just a subset of rows can
be computed.
"""
def int_for_range(start_id, end_id):
if end_id:
return and_(
column >= start_id,
column < end_id
)
else:
return column >= start_id

q = session.query(
column,
func.row_number(). \
over(order_by=column). \
label('rownum')
). \
from_self(column)
if windowsize > 1:
q = q.filter(sqlalchemy.text("rownum %% %d=1" % windowsize))

intervals = [id for id, in q]

while intervals:
start = intervals.pop(0)
if intervals:
end = intervals[0]
else:
end = None
yield int_for_range(start, end)

def windowed_query(self, query, column, windowsize):
""""
Break a Query into windows on a given column.
This magic comes from here: https://github.com/sqlalchemy/sqlalchemy/wiki/RangeQuery-and-WindowedRangeQuery
"""
for whereclause in self.column_windows(query.session, column, windowsize):
for row in query.filter(whereclause).order_by(column):
yield row
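`column_windows` and `windowed_query` follow the SQLAlchemy wiki recipe referenced in the docstring: every `windowsize`-th ordered key becomes a window boundary, and the base query is re-run once per window so only one chunk of rows is held in memory at a time. A plain-Python sketch of the boundary logic (no database, toy ids):

```python
# Toy sketch of the windowing semantics: take every `windowsize`-th ordered id
# as a boundary and yield half-open (start, end) ranges, mirroring what
# column_windows does with row_number() in SQL.
def toy_column_windows(ordered_ids, windowsize):
    boundaries = ordered_ids[::windowsize]        # ids at rownum 1, 1+w, 1+2w, ...
    for i, start in enumerate(boundaries):
        end = boundaries[i + 1] if i + 1 < len(boundaries) else None
        yield (start, end)                        # last window is open-ended

print(list(toy_column_windows(list(range(1, 11)), windowsize=4)))
# [(1, 5), (5, 9), (9, None)]
```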
43 changes: 4 additions & 39 deletions covigator/pipeline/cooccurrence_matrix.py
@@ -1,56 +1,21 @@
from itertools import combinations
from typing import Union

from sqlalchemy.orm import Session
from covigator.database.model import VariantCooccurrence, DataSource
from logzero import logger
from covigator.database.queries import Queries
from sqlalchemy.exc import IntegrityError


class CooccurrenceMatrixException(Exception):
pass


class CooccurrenceMatrix:

def compute(self, run_accession: str, source: DataSource, session: Session):
def compute(self, run_accession: str, source: str, session: Session, maximum_length: int = 10):

assert run_accession is not None or run_accession == "", "Missing sample identifier"
assert session is not None, "Missing DB session"

queries = Queries(session=session)
sample_id = run_accession
logger.info("Processing cooccurrent variants for sample {}".format(sample_id))

# the order by position is important to ensure we store only one half of the matrix, and always the same half
variants = queries.get_variants_by_sample(sample_id, source=source.name)
failed_variants = []
variant_ids = queries.get_variant_ids_by_sample(sample_id, source=source, maximum_length=maximum_length)

# process all pairwise combinations without repetition, including the diagonal
for (variant_one, variant_two) in list(combinations(variants, 2)) + list(zip(variants, variants)):
try:
variant_cooccurrence = queries.get_variant_cooccurrence(variant_one, variant_two)
if variant_cooccurrence is None:
variant_cooccurrence = VariantCooccurrence(
variant_id_one=variant_one.variant_id,
variant_id_two=variant_two.variant_id,
count=1
)
session.add(variant_cooccurrence)
session.commit()
else:
# NOTE: it is important to increase the counter like this to avoid race conditions
# the increase happens in the database server and not in python
# see https://stackoverflow.com/questions/2334824/how-to-increase-a-counter-in-sqlalchemy
variant_cooccurrence.count = VariantCooccurrence.count + 1
except IntegrityError:
session.rollback()
failed_variants.append((variant_one, variant_two))

# tries again the failed variants as these are expected to be there now
for (variant_one, variant_two) in failed_variants:
variant_cooccurrence = queries.get_variant_cooccurrence(variant_one, variant_two)
if variant_cooccurrence is None:
raise CooccurrenceMatrixException("Some cooccurrent variants failed to be persisted twice")
variant_cooccurrence.count = VariantCooccurrence.count + 1
for (variant_id_one, variant_id_two) in list(combinations(variant_ids, 2)) + list(zip(variant_ids, variant_ids)):
queries.increment_variant_cooccurrence(variant_id_one, variant_id_two, source)
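The pair generation covers every unordered pair once plus the diagonal, and because the variant ids arrive ordered by position only one half of the matrix is ever written. A minimal sketch with hypothetical variant ids, including the n*(n-1)/2 + n count that the unit tests assert:

```python
# Minimal sketch of the pair generation in compute(); the variant ids are
# hypothetical and already ordered, as returned by get_variant_ids_by_sample.
from itertools import combinations

variant_ids = ["C241T", "C14408T", "A23403G"]
pairs = list(combinations(variant_ids, 2)) + list(zip(variant_ids, variant_ids))

n = len(variant_ids)
assert len(pairs) == n * (n - 1) // 2 + n   # 3 off-diagonal pairs + 3 diagonal = 6
for variant_id_one, variant_id_two in pairs:
    print(variant_id_one, variant_id_two)
```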
31 changes: 31 additions & 0 deletions covigator/precomputations/load_cooccurrences.py
@@ -0,0 +1,31 @@
from sqlalchemy.orm import Session
from covigator.database.model import DataSource, JobStatus
from covigator.database.queries import Queries
from covigator.pipeline.cooccurrence_matrix import CooccurrenceMatrix
from logzero import logger


class CooccurrenceMatrixLoader:

def __init__(self, session: Session):
self.session = session
self.queries = Queries(session=self.session)
self.cooccurrence_matrix = CooccurrenceMatrix()

def load(self, data_source: str, maximum_length: int):

# deletes any previously computed cooccurrence entries before loading
self.session.query(self.queries.get_variant_cooccurrence_klass(data_source)).delete()

# iterates over every sample in FINISHED status and computes the cooccurrence matrix
sample_klass = self.queries.get_sample_klass(data_source)
count_samples = self.queries.count_samples(source=data_source, cache=False)
computed = 0
query = self.session.query(sample_klass).filter(sample_klass.status == JobStatus.FINISHED)
for sample in self.queries.windowed_query(query=query, column=sample_klass.run_accession, windowsize=1000):
self.cooccurrence_matrix.compute(sample.run_accession, data_source, self.session,
maximum_length=maximum_length)
computed += 1
if computed % 1000 == 0:
logger.info('Processed cooccurrence over {}/{} ({}) samples'.format(
computed, count_samples, round(float(computed) / count_samples, 3)))
8 changes: 0 additions & 8 deletions covigator/processor/ena_processor.py
@@ -49,7 +49,6 @@ def run_all(sample: SampleEna, queries: Queries, config: Configuration) -> Sampl
sample = EnaProcessor.download(sample=sample, queries=queries, config=config)
sample = EnaProcessor.run_pipeline(sample=sample, queries=queries, config=config)
sample = EnaProcessor.load(sample=sample, queries=queries, config=config)
sample = EnaProcessor.compute_cooccurrence(sample=sample, queries=queries, config=config)
return sample

@staticmethod
@@ -178,10 +177,3 @@ def load(sample: SampleEna, queries: Queries, config: Configuration) -> SampleEn
vcf_file=sample.lofreq_vcf_path, run_accession=sample.run_accession, source=DataSource.ENA, session=queries.session)
sample.loaded_at = datetime.now()
return sample

@staticmethod
def compute_cooccurrence(sample: SampleEna, queries: Queries, config: Configuration) -> SampleEna:
CooccurrenceMatrix().compute(run_accession=sample.run_accession, source=DataSource.ENA, session=queries.session)
sample.cooccurrence_at = datetime.now()

return sample
4 changes: 3 additions & 1 deletion covigator/tests/unit_tests/abstract_test.py
@@ -6,7 +6,8 @@
SampleGisaid, SubclonalVariantObservation, Variant, PrecomputedVariantAbundanceHistogram, PrecomputedTableCounts, \
PrecomputedSynonymousNonSynonymousCounts, PrecomputedOccurrence, PrecomputedAnnotation, PrecomputedIndelLength, \
PrecomputedSubstitutionsCounts, PrecomputedVariantsPerSample, Log, GisaidVariant, GisaidVariantObservation, \
LowFrequencyVariantObservation, LowFrequencyVariant, SubclonalVariant, PrecomputedVariantsPerLineage
LowFrequencyVariantObservation, LowFrequencyVariant, SubclonalVariant, PrecomputedVariantsPerLineage, \
GisaidVariantCooccurrence
from covigator.database.queries import Queries
from covigator.tests.unit_tests.faked_objects import FakeConfiguration

@@ -47,6 +48,7 @@ def _clean_test_database(self):
self._clean_table(SampleEna)
self._clean_table(SampleGisaid)
self._clean_table(VariantCooccurrence)
self._clean_table(GisaidVariantCooccurrence)
self._clean_table(Variant)
self._clean_table(SubclonalVariant)
self._clean_table(LowFrequencyVariant)
12 changes: 8 additions & 4 deletions covigator/tests/unit_tests/mocked.py
@@ -34,12 +34,14 @@ def get_mocked_variant(faker: Faker, chromosome=None, gene_name=None, source=Dat
end = 30000

klass = Queries.get_variant_klass(source)
reference = faker.random_choices(list(IUPACData.unambiguous_dna_letters), length=1)[0]
alternate = faker.random_choices(list(IUPACData.unambiguous_dna_letters), length=1)[0]
variant = klass(
chromosome=chromosome if chromosome else faker.bothify(text="chr##"),
position=faker.random_int(min=start, max=end),
reference=faker.random_choices(list(IUPACData.unambiguous_dna_letters), length=1)[0],
reference=reference,
# TODO: reference and alternate could be equal!
alternate=faker.random_choices(list(IUPACData.unambiguous_dna_letters), length=1)[0],
alternate=alternate,
variant_type=VariantType.SNV,
gene_name=gene_name,
hgvs_p="p.{}{}{}".format(
@@ -49,7 +51,8 @@
),
annotation=annotation,
annotation_highest_impact=annotation,
pfam_name=domain_name
pfam_name=domain_name,
length=len(reference) - len(alternate)
)
variant.variant_id = variant.get_variant_id()
return variant
@@ -73,7 +76,8 @@ def get_mocked_variant_observation(
gene_name=variant.gene_name,
pfam_name=variant.pfam_name,
date=faker.date_time(),
hgvs_p=variant.hgvs_p
hgvs_p=variant.hgvs_p,
length=len(variant.reference) - len(variant.alternate)
)


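The TODO above notes that the mocked `reference` and `alternate` can coincide. A minimal sketch of one way to rule that out, assuming the same Faker and Biopython `IUPACData` imports used in `mocked.py` (not part of this commit):

```python
# Hypothetical fix for the TODO: draw the alternate from the bases other than
# the reference so that reference == alternate cannot happen.
from Bio.Data import IUPACData
from faker import Faker

faker = Faker()
bases = list(IUPACData.unambiguous_dna_letters)
reference = faker.random_choices(bases, length=1)[0]
alternate = faker.random_choices([b for b in bases if b != reference], length=1)[0]
assert reference != alternate
```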
6 changes: 3 additions & 3 deletions covigator/tests/unit_tests/test_cooccurrence_matrix.py
@@ -49,7 +49,7 @@ def test_missing_session(self):
def test_non_existing_sample_does_not_add_new_entries(self):
count = self.session.query(VariantCooccurrence).count()
self.assertEqual(count, 0)
CooccurrenceMatrix().compute(run_accession="12345", source=DataSource.ENA, session=self.session)
CooccurrenceMatrix().compute(run_accession="12345", source=DataSource.ENA.name, session=self.session)
self.session.commit()
count = self.session.query(VariantCooccurrence).count()
self.assertEqual(count, 0)
@@ -58,7 +58,7 @@ def test_one_existing_sample(self):
count = self.session.query(VariantCooccurrence).count()
self.assertEqual(count, 0)
CooccurrenceMatrix().compute(
run_accession=self.samples[0].run_accession, source=DataSource.ENA, session=self.session)
run_accession=self.samples[0].run_accession, source=DataSource.ENA.name, session=self.session)
self.session.commit()
count = self.session.query(VariantCooccurrence).count()
# size of matrix = n*(n-1)/2 + n (one half of matrix + diagonal)
@@ -72,7 +72,7 @@ def test_all_samples(self):
count = self.session.query(VariantCooccurrence).count()
self.assertEqual(count, 0)
for s in self.samples:
CooccurrenceMatrix().compute(run_accession=s.run_accession, source=DataSource.ENA, session=self.session)
CooccurrenceMatrix().compute(run_accession=s.run_accession, source=DataSource.ENA.name, session=self.session)
self.session.commit()
count = self.session.query(VariantCooccurrence).count()
self.assertLess(
(2 of the 12 changed files are not shown above)