Skip to content

Commit

Permalink
Merge pull request #39 from TRON-Bioinformatics/develop
Browse files Browse the repository at this point in the history
Make co-occurrence matrix computation more efficient
  • Loading branch information
priesgo authored Jun 6, 2022
2 parents 2b2b318 + 2ae8305 commit b897404
Show file tree
Hide file tree
Showing 10 changed files with 107 additions and 184 deletions.
2 changes: 1 addition & 1 deletion covigator/__init__.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
VERSION = "v0.7.4"
VERSION = "v0.7.5"
ANALYSIS_PIPELINE_VERSION = "v0.9.3"

MISSENSE_VARIANT = "missense_variant"
Expand Down
4 changes: 2 additions & 2 deletions covigator/command_line.py
Original file line number Diff line number Diff line change
Expand Up @@ -191,7 +191,7 @@ def cooccurrence():
args = parser.parse_args()

database = Database(initialize=True, config=Configuration())
loader = CooccurrenceMatrixLoader(session=database.get_database_session())
loader = CooccurrenceMatrixLoader(session=database.get_database_session(), source=args.data_source)
logger.info("Starting precomputation...")
loader.load(data_source=args.data_source, maximum_length=int(args.maximum_length))
loader.load(maximum_length=int(args.maximum_length))
logger.info("Done precomputing")
1 change: 1 addition & 0 deletions covigator/dashboard/figures/recurrent_mutations.py
Original file line number Diff line number Diff line change
Expand Up @@ -680,6 +680,7 @@ def get_variants_clustering(self, sparse_matrix, min_cooccurrence, min_samples):
tables.append(html.Button("Download CSV", id="btn_csv"))
tables.append(dcc.Download(id="download-dataframe-csv"))
tables.append(dcc.Store(id="memory", data=data.to_dict('records')))
tables.append(html.Br()),
tables.append(dcc.Markdown("""
***Co-occurrence clustering*** *shows the resulting clusters from the
co-occurrence matrix with the Jaccard index corrected with the Cohen's kappa coefficient.
Expand Down
7 changes: 5 additions & 2 deletions covigator/dashboard/tabs/recurrent_mutations.py
Original file line number Diff line number Diff line change
Expand Up @@ -332,5 +332,8 @@ def update_cooccurrence_heatmap(
prevent_initial_call=True,
)
def func(n_clicks, df):
return dcc.send_data_frame(pd.DataFrame.from_dict(df).to_csv, "covigator_clustering_results_{}.csv".format(
datetime.now().strftime("%Y%m%d%H%M%S)")))
return dcc.send_data_frame(
pd.DataFrame.from_dict(df).to_csv,
"covigator_clustering_results_{}.csv".format(datetime.now().strftime("%Y%m%d%H%M%S")),
index=False
)
36 changes: 0 additions & 36 deletions covigator/database/queries.py
Original file line number Diff line number Diff line change
Expand Up @@ -334,42 +334,6 @@ def get_non_synonymous_variants_by_region(self, start, end, source) -> pd.DataFr
subquery = query.group_by(klass.position, klass.annotation_highest_impact, klass.hgvs_p).subquery()
return pd.read_sql(
self.session.query(subquery).filter(subquery.c.count_occurrences > 1).statement, self.session.bind)

def get_variant_ids_by_sample(self, sample_id, source: str, maximum_length: int) -> List[str]:
    """
    Returns the variant ids of all mutations in a given sample after filtering out:
    mutations not overlapping any gene, synonymous mutations, long indels according to maximum_length parameter

    The ordering by position, reference and alternate makes the result deterministic,
    which downstream co-occurrence code relies on to fill only half the matrix.
    """
    klass = self.get_variant_observation_klass(source=source)
    return self.session.query(klass.variant_id) \
        .filter(and_(klass.sample == sample_id,
                     klass.gene_name != None,  # noqa: E711 -- SQLAlchemy renders this as IS NOT NULL
                     klass.annotation_highest_impact != SYNONYMOUS_VARIANT,
                     klass.length < maximum_length,
                     klass.length > -maximum_length)) \
        .order_by(klass.position, klass.reference, klass.alternate) \
        .all()

def increment_variant_cooccurrence(
        self, variant_id_one: str, variant_id_two: str, source: str) -> \
        Union[VariantCooccurrence, GisaidVariantCooccurrence, None]:
    """
    Increment by one the co-occurrence count of the given pair of variant ids.

    Returns the freshly created entry (count=1) when the pair was not yet in the
    DB, so the caller is responsible for adding it to the session; returns None
    when an existing entry was incremented in place (the ORM session already
    tracks that mutation, nothing needs to be added).
    """

    # NOTE: this method does not commit to DB due to performance reasons
    klazz = self.get_variant_cooccurrence_klass(source=source)

    variant_cooccurrence = self.session.query(klazz) \
        .filter(and_(klazz.variant_id_one == variant_id_one,
                     klazz.variant_id_two == variant_id_two)) \
        .first()
    if variant_cooccurrence is None:
        # pair never seen before: build a new entry for the caller to persist
        variant_cooccurrence = klazz(
            variant_id_one=variant_id_one,
            variant_id_two=variant_id_two,
            count=1)
        return variant_cooccurrence
    else:
        variant_cooccurrence.count = variant_cooccurrence.count + 1
        return None

def count_samples(self, source: str, cache=True) -> int:
self._assert_data_source(source)
Expand Down
27 changes: 0 additions & 27 deletions covigator/pipeline/cooccurrence_matrix.py

This file was deleted.

111 changes: 96 additions & 15 deletions covigator/precomputations/load_cooccurrences.py
Original file line number Diff line number Diff line change
@@ -1,38 +1,119 @@
from itertools import combinations
from typing import Union, List

from sqlalchemy import and_
from sqlalchemy.orm import Session
from covigator.database.model import DataSource, JobStatus

from covigator import SYNONYMOUS_VARIANT
from covigator.database.model import JobStatus, VariantCooccurrence, GisaidVariantCooccurrence
from covigator.database.queries import Queries
from covigator.pipeline.cooccurrence_matrix import CooccurrenceMatrix
from logzero import logger

BATCH_SIZE = 1000


class CooccurrenceMatrixLoader:

def __init__(self, session: Session):
def __init__(self, session: Session, source: str):
self.session = session
self.queries = Queries(session=self.session)
self.cooccurrence_matrix = CooccurrenceMatrix()
self.source = source
self.cache = {}
self.variant_klazz = self.queries.get_variant_cooccurrence_klass(source=source)
self.variant_cooccurrence_klazz = self.queries.get_variant_cooccurrence_klass(source=self.source)

def load(self, data_source: str, maximum_length: int):
def load(self, maximum_length: int):

# deletes the database before loading
self.session.query(self.queries.get_variant_cooccurrence_klass(data_source)).delete()
self.session.query(self.variant_cooccurrence_klazz).delete()

# iterates over every sample in FINISHED status and computes the cooccurrence matrix
sample_klass = self.queries.get_sample_klass(data_source)
count_samples = self.queries.count_samples(source=data_source, cache=False)
sample_klass = self.queries.get_sample_klass(self.source)
count_samples = self.queries.count_samples(source=self.source, cache=False)
computed = 0
query = self.session.query(sample_klass).filter(sample_klass.status == JobStatus.FINISHED)
for sample in self.queries.windowed_query(query=query, column=sample_klass.run_accession, windowsize=1000):
self.cooccurrence_matrix.compute(sample.run_accession, data_source, self.session,
maximum_length=maximum_length)
for sample in self.queries.windowed_query(query=query, column=sample_klass.run_accession, windowsize=BATCH_SIZE):
self._compute_sample(sample.run_accession, maximum_length=maximum_length)
computed += 1
if computed % 1000 == 0:
if computed % BATCH_SIZE == 0:
# commits batches of 1000 samples
self._commit_cache()
logger.info('Processed cooccurrence over {}/{} ({} %) samples'.format(
computed, count_samples, round(float(computed) / count_samples * 100, 3)))

# commits the last batch
self._commit_cache()

# once finished deletes the unique observations
variant_cooccurrence_klazz = self.queries.get_variant_cooccurrence_klass(source=data_source)
self.session.query(self.queries.get_variant_cooccurrence_klass(data_source)) \
.filter(variant_cooccurrence_klazz.count == 1) \
self.session.query(self.queries.get_variant_cooccurrence_klass(self.source)) \
.filter(self.variant_cooccurrence_klazz.count == 1) \
.delete()
self.session.commit()

def _get_variant_ids_by_sample(self, sample_id, source: str, maximum_length: int) -> List[str]:
    """
    Fetch the variant ids observed in one sample, excluding mutations that do
    not overlap any gene, synonymous mutations, and indels whose length exceeds
    maximum_length in either direction.

    The ordering by position, reference and alternate is deterministic so that
    the co-occurrence computation always stores the same half of the matrix.
    """
    observation_klass = self.queries.get_variant_observation_klass(source=source)
    filters = and_(
        observation_klass.sample == sample_id,
        observation_klass.gene_name != None,  # noqa: E711 -- SQLAlchemy needs != None for IS NOT NULL
        observation_klass.annotation_highest_impact != SYNONYMOUS_VARIANT,
        observation_klass.length < maximum_length,
        observation_klass.length > -maximum_length)
    query = self.session.query(observation_klass.variant_id).filter(filters)
    ordered = query.order_by(
        observation_klass.position, observation_klass.reference, observation_klass.alternate)
    return ordered.all()

def _unique_id(self, variant_id_one: str, variant_id_two: str):
    """Build the cache key for an ordered pair of variant ids."""
    return f"{variant_id_one}-{variant_id_two}"

def _get_from_cache(self, variant_id_one: str, variant_id_two: str):
    """Return the cached co-occurrence entry for this pair, or None when absent."""
    key = self._unique_id(variant_id_one, variant_id_two)
    return self.cache.get(key)

def _store_in_cache(self, variant_id_one: str, variant_id_two: str,
                    entry: Union[VariantCooccurrence, GisaidVariantCooccurrence]):
    """Store the given co-occurrence entry in the cache under the pair's unique key."""
    key = self._unique_id(variant_id_one, variant_id_two)
    self.cache[key] = entry

def _commit_cache(self):
    """Persist every cached entry in a single commit, then start a fresh cache."""
    pending_entries = list(self.cache.values())
    self.session.add_all(pending_entries)
    self.session.commit()
    # rebind (rather than clear) so previously handed-out references stay intact
    self.cache = {}

def _increment_variant_cooccurrence(self, variant_id_one: str, variant_id_two: str):
    """
    Increment by one the co-occurrence count of the ordered pair of variant ids.

    Resolution order: in-memory cache first, then the database, then a brand
    new entry with count=1. The resulting entry is always written back to the
    cache; the actual commit is batched elsewhere (see _commit_cache).
    """

    # NOTE: this method does not commit to DB due to performance reasons

    # first looks in the cache
    variant_cooccurrence = self._get_from_cache(variant_id_one=variant_id_one, variant_id_two=variant_id_two)
    if variant_cooccurrence is None:
        # if not in the cache looks in the DB
        variant_cooccurrence = self.session.query(self.variant_klazz) \
            .filter(and_(self.variant_klazz.variant_id_one == variant_id_one,
                         self.variant_klazz.variant_id_two == variant_id_two)) \
            .first()
    if variant_cooccurrence is None:
        # if not in the cache and not in the DB creates a new one
        variant_cooccurrence = self.variant_klazz(
            variant_id_one=variant_id_one,
            variant_id_two=variant_id_two,
            count=1)
    else:
        # found in cache or DB: increment in place
        variant_cooccurrence.count = variant_cooccurrence.count + 1

    # stores the changes in the cache
    self._store_in_cache(variant_id_one=variant_id_one, variant_id_two=variant_id_two, entry=variant_cooccurrence)

def _compute_sample(self, run_accession: str, maximum_length: int = 10):
    """
    Accumulate the co-occurrence counts contributed by a single sample.

    Every unordered pair of distinct variants in the sample increments its
    co-occurrence count once, and every variant increments its own diagonal
    entry, so the diagonal ends up holding per-variant occurrence counts.
    Counts are accumulated through _increment_variant_cooccurrence and
    committed later in batches.

    :param run_accession: sample identifier; must be a non-empty string
    :param maximum_length: indels longer than this (in either direction) are excluded
    """

    # FIX: the previous condition (`is not None or == ""`) was always True for any
    # non-None value, so an empty-string identifier slipped through the assert
    assert run_accession is not None and run_accession != "", "Missing sample identifier"

    sample_id = run_accession

    # the order by position is important to ensure we store only half the matrix and the same half of the matrix
    variant_ids = self._get_variant_ids_by_sample(sample_id, source=self.source, maximum_length=maximum_length)

    # process all pairwise combinations without repetitions including the diagonal
    for (variant_id_one, variant_id_two) in list(combinations(variant_ids, 2)) + list(
            zip(variant_ids, variant_ids)):
        self._increment_variant_cooccurrence(variant_id_one, variant_id_two)
4 changes: 1 addition & 3 deletions covigator/processor/ena_processor.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,4 @@
import json
import numpy as np
from datetime import datetime

import pandas as pd
Expand All @@ -8,15 +7,14 @@
from covigator.configuration import Configuration
from covigator.database.queries import Queries
from covigator.exceptions import CovigatorErrorProcessingCoverageResults, CovigatorExcludedSampleBadQualityReads, \
CovigatorExcludedSampleNarrowCoverage, CovigatorErrorProcessingPangolinResults, \
CovigatorExcludedSampleNarrowCoverage, \
CovigatorErrorProcessingDeduplicationResults
from covigator.misc import backoff_retrier
from covigator.database.model import JobStatus, DataSource, SampleEna
from covigator.database.database import Database
from logzero import logger
from dask.distributed import Client
from covigator.processor.abstract_processor import AbstractProcessor
from covigator.pipeline.cooccurrence_matrix import CooccurrenceMatrix
from covigator.pipeline.downloader import Downloader
from covigator.pipeline.ena_pipeline import Pipeline
from covigator.pipeline.vcf_loader import VcfLoader
Expand Down
97 changes: 0 additions & 97 deletions covigator/tests/unit_tests/test_cooccurrence_matrix.py

This file was deleted.

2 changes: 1 addition & 1 deletion covigator/tests/unit_tests/test_precomputer.py
Original file line number Diff line number Diff line change
Expand Up @@ -185,7 +185,7 @@ def test_load_cooccurrence_matrix(self, source):
variant_cooccurrence_klass = self.queries.get_variant_cooccurrence_klass(source)

self.assertEqual(self.session.query(variant_cooccurrence_klass).count(), 0)
CooccurrenceMatrixLoader(self.session).load(data_source=source, maximum_length=10)
CooccurrenceMatrixLoader(self.session, source=source).load(maximum_length=10)
self.assertGreater(self.session.query(variant_cooccurrence_klass).count(), 0)
found_greater_one = False
for p in self.session.query(variant_cooccurrence_klass).all():
Expand Down

0 comments on commit b897404

Please sign in to comment.