feat: add studyType to StudyLocus and Colocalisation (and `StudyLocusOverlap`) (#782)

* feat: add studyType to StudyLocus schema

* feat: add annotate_study_type function to add studyType to StudyLocus

* fix: remove lines for retrieving studyType as StudyLocus now contains studyType

* fix: add studyType to test input data as StudyLocus now contains studyType

* feat: add leftStudyType and rightStudyType to Colocalisation and StudyLocusOverlap schemas

* feat: update _convert_to_square_matrix and its test with leftStudyType and rightStudyType

* feat: update test_find_overlaps_semantic inputs with leftStudyType and rightStudyType

* feat: update tests in test_colocalisation_method.py with leftStudyType and rightStudyType

* feat: add leftStudyType and rightStudyType when creating StudyLocusOverlap

* feat: add leftStudyType and rightStudyType to Colocalisation results

* fix: remove redundant study_index parameter from filter_by_study_type function def and calls

* fix: remove redundant study_index parameter from find_overlaps function def and calls

* fix: remove leftStudyType from Colocalisation (not needed as always gwas)

* fix: remove leftStudyType from StudyLocusOverlap (not needed as always gwas)

* fix: missing comma

* feat: update tests in test_locus_to_gene.py with studyType and rightStudyType

* feat: update tests (colocalisation, l2g, l2g feature matrix) with rightStudyType

* fix: remove studyType from metadata_cols in append_study_metadata function call
vivienho authored Sep 24, 2024
1 parent a29222e commit dcacaf7
Showing 17 changed files with 83 additions and 83 deletions.
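
In outline, the commit annotates `studyType` onto `StudyLocus` once (during study-locus validation) so that overlap detection and colocalisation no longer need a `StudyIndex` to resolve study types. A minimal sketch of the updated flow, assembled from the diffs below — the session setup and parquet paths are placeholders, not part of this commit:

```python
from gentropy.common.session import Session
from gentropy.dataset.study_index import StudyIndex
from gentropy.dataset.study_locus import CredibleInterval, StudyLocus
from gentropy.method.colocalisation import Coloc

session = Session()  # placeholder session; configure Spark as appropriate
study_index = StudyIndex.from_parquet(session, "path/to/study_index")    # placeholder path
credible_set = StudyLocus.from_parquet(session, "path/to/credible_set")  # placeholder path

# studyType is now carried on StudyLocus itself...
credible_set = credible_set.annotate_study_type(study_index)

# ...so find_overlaps() and colocalise() no longer take a StudyIndex,
# and the resulting Colocalisation rows carry rightStudyType.
overlaps = credible_set.filter_credible_set(CredibleInterval.IS95).find_overlaps()
colocalisation_results = Coloc.colocalise(overlaps)
```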
6 changes: 6 additions & 0 deletions src/gentropy/assets/schemas/colocalisation.json
@@ -13,6 +13,12 @@
"type": "long",
"metadata": {}
},
{
"name": "rightStudyType",
"nullable": false,
"type": "string",
"metadata": {}
},
{
"name": "chromosome",
"nullable": false,
6 changes: 6 additions & 0 deletions src/gentropy/assets/schemas/study_locus.json
@@ -6,6 +6,12 @@
"nullable": false,
"type": "long"
},
{
"metadata": {},
"name": "studyType",
"nullable": true,
"type": "string"
},
{
"metadata": {},
"name": "variantId",
6 changes: 6 additions & 0 deletions src/gentropy/assets/schemas/study_locus_overlap.json
@@ -12,6 +12,12 @@
"nullable": false,
"type": "long"
},
{
"metadata": {},
"name": "rightStudyType",
"nullable": false,
"type": "string"
},
{
"metadata": {},
"name": "chromosome",
8 changes: 1 addition & 7 deletions src/gentropy/colocalisation.py
@@ -8,7 +8,6 @@
from pyspark.sql.functions import col

from gentropy.common.session import Session
from gentropy.dataset.study_index import StudyIndex
from gentropy.dataset.study_locus import CredibleInterval, StudyLocus
from gentropy.method.colocalisation import Coloc

@@ -23,7 +22,6 @@ def __init__(
self,
session: Session,
credible_set_path: str,
study_index_path: str,
coloc_path: str,
colocalisation_method: str,
) -> None:
@@ -32,7 +30,6 @@
Args:
session (Session): Session object.
credible_set_path (str): Input credible sets path.
study_index_path (str): Input study index path.
coloc_path (str): Output Colocalisation path.
colocalisation_method (str): Colocalisation method.
"""
@@ -47,14 +44,11 @@ def __init__(
session, credible_set_path, recursiveFileLookup=True
)
)
si = StudyIndex.from_parquet(
session, study_index_path, recursiveFileLookup=True
)

# Transform
overlaps = credible_set.filter_credible_set(
CredibleInterval.IS95
).find_overlaps(si)
).find_overlaps()
colocalisation_results = colocalisation_class.colocalise(overlaps) # type: ignore

# Load
2 changes: 1 addition & 1 deletion src/gentropy/dataset/colocalisation.py
@@ -91,7 +91,7 @@ def extract_maximum_coloc_probability_per_region_and_gene(
self.append_study_metadata(
study_locus,
study_index,
metadata_cols=["studyType", "geneId"],
metadata_cols=["geneId"],
colocalisation_side="right",
)
# it also filters based on method and qtl type
31 changes: 25 additions & 6 deletions src/gentropy/dataset/study_locus.py
@@ -157,6 +157,24 @@ def validate_study(self: StudyLocus, study_index: StudyIndex) -> StudyLocus:
_schema=self.get_schema(),
)

def annotate_study_type(self: StudyLocus, study_index: StudyIndex) -> StudyLocus:
"""Gets study type from study index and adds it to study locus.
Args:
study_index (StudyIndex): Study index to get study type.
Returns:
StudyLocus: Updated study locus with study type.
"""
return StudyLocus(
_df=(
self.df
.drop("studyType")
.join(study_index.study_type_lut(), on="studyId", how="left")
),
_schema=self.get_schema(),
)

def validate_variant_identifiers(
self: StudyLocus, variant_index: VariantIndex
) -> StudyLocus:
@@ -394,6 +412,7 @@ def _align_overlapping_tags(
f.col("chromosome"),
f.col("tagVariantId"),
f.col("studyLocusId").alias("rightStudyLocusId"),
f.col("studyType").alias("rightStudyType"),
*[f.col(col).alias(f"right_{col}") for col in stats_cols],
).join(peak_overlaps, on=["chromosome", "rightStudyLocusId"], how="inner")

@@ -410,6 +429,7 @@
).select(
"leftStudyLocusId",
"rightStudyLocusId",
"rightStudyType",
"chromosome",
"tagVariantId",
f.struct(
@@ -505,13 +525,12 @@ def get_QC_mappings(cls: type[StudyLocus]) -> dict[str, str]:
return {member.name: member.value for member in StudyLocusQualityCheck}

def filter_by_study_type(
self: StudyLocus, study_type: str, study_index: StudyIndex
self: StudyLocus, study_type: str
) -> StudyLocus:
"""Creates a new StudyLocus dataset filtered by study type.
Args:
study_type (str): Study type to filter for. Can be one of `gwas`, `eqtl`, `pqtl`, `sqtl`.
study_index (StudyIndex): Study index to resolve study types.
Returns:
StudyLocus: Filtered study-locus dataset.
@@ -524,7 +543,7 @@ def filter_by_study_type(
f"Study type {study_type} not supported. Supported types are: gwas, eqtl, pqtl, sqtl."
)
new_df = (
self.df.join(study_index.study_type_lut(), on="studyId", how="inner")
self.df
.filter(f.col("studyType") == study_type)
.drop("studyType")
)
@@ -576,22 +595,22 @@ def filter_ld_set(ld_set: Column, r2_threshold: float) -> Column:
)

def find_overlaps(
self: StudyLocus, study_index: StudyIndex, intra_study_overlap: bool = False
self: StudyLocus, intra_study_overlap: bool = False
) -> StudyLocusOverlap:
"""Calculate overlapping study-locus.
Find overlapping study-locus that share at least one tagging variant. All GWAS-GWAS and all GWAS-Molecular traits are computed with the Molecular traits always
appearing on the right side.
Args:
study_index (StudyIndex): Study index to resolve study types.
intra_study_overlap (bool): If True, finds intra-study overlaps for credible set deduplication. Default is False.
Returns:
StudyLocusOverlap: Pairs of overlapping study-locus with aligned tags.
"""
loci_to_overlap = (
self.df.join(study_index.study_type_lut(), on="studyId", how="inner")
self.df
.filter(f.col("studyType").isNotNull())
.withColumn("locus", f.explode("locus"))
.select(
"studyLocusId",
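
Conceptually, the new `annotate_study_type` is a left join against the study index's `studyId → studyType` lookup, after which `filter_by_study_type` and `find_overlaps` only need the local `studyType` column. A toy PySpark sketch of the same pattern (illustrative data only, not the gentropy API):

```python
import pyspark.sql.functions as f
from pyspark.sql import SparkSession

spark = SparkSession.builder.getOrCreate()

study_locus = spark.createDataFrame(
    [(1, "S1"), (2, "S2"), (3, "S3")], "studyLocusId LONG, studyId STRING"
)
study_type_lut = spark.createDataFrame(
    [("S1", "gwas"), ("S2", "eqtl")], "studyId STRING, studyType STRING"
)

# annotate_study_type: a left join keeps loci whose study is unknown (studyType stays null)
annotated = study_locus.join(study_type_lut, on="studyId", how="left")

# filter_by_study_type / find_overlaps then work off the local column
gwas_only = annotated.filter(f.col("studyType") == "gwas")
overlappable = annotated.filter(f.col("studyType").isNotNull())
```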
7 changes: 3 additions & 4 deletions src/gentropy/dataset/study_locus_overlap.py
@@ -10,7 +10,6 @@
if TYPE_CHECKING:
from pyspark.sql.types import StructType

from gentropy.dataset.study_index import StudyIndex
from gentropy.dataset.study_locus import StudyLocus


@@ -36,18 +35,17 @@ def get_schema(cls: type[StudyLocusOverlap]) -> StructType:

@classmethod
def from_associations(
cls: type[StudyLocusOverlap], study_locus: StudyLocus, study_index: StudyIndex
cls: type[StudyLocusOverlap], study_locus: StudyLocus
) -> StudyLocusOverlap:
"""Find the overlapping signals in a particular set of associations (StudyLocus dataset).
Args:
study_locus (StudyLocus): Study-locus associations to find the overlapping signals
study_index (StudyIndex): Study index to find the overlapping signals
Returns:
StudyLocusOverlap: Study-locus overlap dataset
"""
return study_locus.find_overlaps(study_index)
return study_locus.find_overlaps()

def _convert_to_square_matrix(self: StudyLocusOverlap) -> StudyLocusOverlap:
"""Convert the dataset to a square matrix.
@@ -60,6 +58,7 @@ def _convert_to_square_matrix(self: StudyLocusOverlap) -> StudyLocusOverlap:
self.df.selectExpr(
"leftStudyLocusId as rightStudyLocusId",
"rightStudyLocusId as leftStudyLocusId",
"rightStudyType",
"tagVariantId",
)
).distinct(),
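
The square-matrix conversion mirrors each overlap pair by swapping the left/right study-locus IDs while carrying `rightStudyType` and `tagVariantId` through unchanged, then deduplicates. A compact sketch of the mirroring idea on toy data (column names taken from the diff; the surrounding union step is assumed from context):

```python
from pyspark.sql import SparkSession

spark = SparkSession.builder.getOrCreate()

overlaps = spark.createDataFrame(
    [(1, 2, "eqtl", "variantA")],
    "leftStudyLocusId LONG, rightStudyLocusId LONG, rightStudyType STRING, tagVariantId STRING",
)

# Mirror each pair (IDs swapped, other columns kept), then union with the original and deduplicate
mirrored = overlaps.selectExpr(
    "leftStudyLocusId as rightStudyLocusId",
    "rightStudyLocusId as leftStudyLocusId",
    "rightStudyType",
    "tagVariantId",
)
square = overlaps.unionByName(mirrored).distinct()
```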
4 changes: 2 additions & 2 deletions src/gentropy/l2g.py
@@ -204,7 +204,7 @@ def _generate_feature_matrix(self, write_feature_matrix: bool) -> L2GFeatureMatr
ValueError: If write_feature_matrix is set to True but a path is not provided.
ValueError: If dependencies to build features are not set.
"""
if self.gs_curation and self.interactions and self.v2g and self.studies:
if self.gs_curation and self.interactions and self.v2g:
study_locus_overlap = StudyLocus(
_df=self.credible_set.df.join(
f.broadcast(
@@ -225,7 +225,7 @@
"inner",
),
_schema=StudyLocus.get_schema(),
).find_overlaps(self.studies)
).find_overlaps()

gold_standards = L2GGoldStandard.from_otg_curation(
gold_standard_curation=self.gs_curation,
4 changes: 2 additions & 2 deletions src/gentropy/method/colocalisation.py
@@ -79,7 +79,7 @@ def colocalise(
f.col("statistics.right_posteriorProbability"),
),
)
.groupBy("leftStudyLocusId", "rightStudyLocusId", "chromosome")
.groupBy("leftStudyLocusId", "rightStudyLocusId", "rightStudyType", "chromosome")
.agg(
f.count("*").alias("numberColocalisingVariants"),
f.sum(f.col("clpp")).alias("clpp"),
@@ -168,7 +168,7 @@ def colocalise(
f.col("left_logBF") + f.col("right_logBF"),
)
# Group by overlapping peak and generating dense vectors of log_BF:
.groupBy("chromosome", "leftStudyLocusId", "rightStudyLocusId")
.groupBy("chromosome", "leftStudyLocusId", "rightStudyLocusId", "rightStudyType")
.agg(
f.count("*").alias("numberColocalisingVariants"),
fml.array_to_vector(f.collect_list(f.col("left_logBF"))).alias(
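
Adding `rightStudyType` to the `groupBy` keys in both colocalisation methods does not change the grouping: the value is constant within each (left, right) study-locus pair, so it simply survives the aggregation into the output rows. A toy sketch, assuming the same column names as the diff:

```python
import pyspark.sql.functions as f
from pyspark.sql import SparkSession

spark = SparkSession.builder.getOrCreate()

pairs = spark.createDataFrame(
    [(1, 2, "eqtl", 0.4), (1, 2, "eqtl", 0.5)],
    "leftStudyLocusId LONG, rightStudyLocusId LONG, rightStudyType STRING, clpp DOUBLE",
)

# One output row per (left, right) pair; rightStudyType rides along as a grouping key
per_pair = pairs.groupBy("leftStudyLocusId", "rightStudyLocusId", "rightStudyType").agg(
    f.count("*").alias("numberColocalisingVariants"),
    f.sum("clpp").alias("clpp"),
)
```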
1 change: 1 addition & 0 deletions src/gentropy/study_locus_validation.py
@@ -46,6 +46,7 @@ def __init__(
# Add flag for MHC region
.qc_MHC_region()
.validate_study(study_index) # Flagging studies not in study index
.annotate_study_type(study_index) # Add study type to study locus
.qc_redundant_top_hits_from_PICS() # Flagging top hits from studies with PICS summary statistics
.validate_unique_study_locus_id() # Flagging duplicated study locus ids
).persist() # we will need this for 2 types of outputs
3 changes: 2 additions & 1 deletion tests/gentropy/dataset/test_colocalisation.py
@@ -100,10 +100,11 @@ def _setup(self: TestAppendStudyMetadata, spark: SparkSession) -> None:
)
self.sample_colocalisation = Colocalisation(
_df=spark.createDataFrame(
[(1, 2, "X", "COLOC", 1, 0.9)],
[(1, 2, "eqtl", "X", "COLOC", 1, 0.9)],
[
"leftStudyLocusId",
"rightStudyLocusId",
"rightStudyType",
"chromosome",
"colocalisationMethod",
"numberColocalisingVariants",
4 changes: 2 additions & 2 deletions tests/gentropy/dataset/test_l2g.py
@@ -70,8 +70,8 @@ def test_filter_unique_associations(spark: SparkSession) -> None:
)

mock_sl_overlap_df = spark.createDataFrame(
[(1, 2, "variant2"), (1, 4, "variant4")],
"leftStudyLocusId LONG, rightStudyLocusId LONG, tagVariantId STRING",
[(1, 2, "eqtl", "variant2"), (1, 4, "eqtl", "variant4")],
"leftStudyLocusId LONG, rightStudyLocusId LONG, rightStudyType STRING, tagVariantId STRING",
)

expected_df = spark.createDataFrame(
3 changes: 2 additions & 1 deletion tests/gentropy/dataset/test_l2g_feature_matrix.py
@@ -136,10 +136,11 @@ def _setup(self: TestFromFeaturesList, spark: SparkSession) -> None:
)
self.sample_colocalisation = Colocalisation(
_df=spark.createDataFrame(
[(1, 2, "X", "COLOC", 1, 0.9)],
[(1, 2, "eqtl", "X", "COLOC", 1, 0.9)],
[
"leftStudyLocusId",
"rightStudyLocusId",
"rightStudyType",
"chromosome",
"colocalisationMethod",
"numberColocalisingVariants",