Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

refactor(vep_parser): store consequence to impact score as a project config #811

Merged
merged 9 commits into from
Oct 3, 2024
46 changes: 0 additions & 46 deletions src/gentropy/assets/data/variant_consequence_to_score.tsv

This file was deleted.

4 changes: 2 additions & 2 deletions src/gentropy/common/spark_helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -447,14 +447,14 @@ def order_array_of_structs_by_two_fields(
)


def map_column_by_dictionary(col: Column, mapping_dict: dict[str, str]) -> Column:
def map_column_by_dictionary(col: Column, mapping_dict: dict[str, Any]) -> Column:
"""Map column values to dictionary values by key.

Missing consequence label will be converted to None, unmapped consequences will be mapped as None.

Args:
col (Column): Column containing labels to map.
mapping_dict (dict[str, str]): Dictionary with mapping key/value pairs.
mapping_dict (dict[str, Any]): Dictionary with mapping key/value pairs.

Returns:
Column: Column with mapped values.
Expand Down
64 changes: 63 additions & 1 deletion src/gentropy/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@

import os
from dataclasses import dataclass, field
from typing import Any, List
from typing import Any, ClassVar, List, TypedDict

from hail import __file__ as hail_location
from hydra.core.config_store import ConfigStore
Expand Down Expand Up @@ -348,11 +348,73 @@ class GnomadVariantConfig(StepConfig):
class VariantIndexConfig(StepConfig):
"""Variant index step configuration."""

class _ConsequenceToPathogenicityScoreMap(TypedDict):
"""Typing definition for CONSEQUENCE_TO_PATHOGENICITY_SCORE."""

id: str
label: str
score: float

session: SessionConfig = SessionConfig()
vep_output_json_path: str = MISSING
variant_index_path: str = MISSING
gnomad_variant_annotations_path: str | None = None
hash_threshold: int = 300
consequence_to_pathogenicity_score: ClassVar[
list[_ConsequenceToPathogenicityScoreMap]
] = [
{"id": "SO_0001575", "label": "splice_donor_variant", "score": 1.0},
{"id": "SO_0001589", "label": "frameshift_variant", "score": 1.0},
{"id": "SO_0001574", "label": "splice_acceptor_variant", "score": 1.0},
{"id": "SO_0001587", "label": "stop_gained", "score": 1.0},
{"id": "SO_0002012", "label": "start_lost", "score": 1.0},
{"id": "SO_0001578", "label": "stop_lost", "score": 1.0},
{"id": "SO_0001893", "label": "transcript_ablation", "score": 1.0},
{"id": "SO_0001822", "label": "inframe_deletion", "score": 0.66},
{
"id": "SO_0001818",
"label": "protein_altering_variant",
"score": 0.66,
},
{"id": "SO_0001821", "label": "inframe_insertion", "score": 0.66},
{
"id": "SO_0001787",
"label": "splice_donor_5th_base_variant",
"score": 0.66,
},
{"id": "SO_0001583", "label": "missense_variant", "score": 0.66},
{"id": "SO_0001567", "label": "stop_retained_variant", "score": 0.33},
{"id": "SO_0001630", "label": "splice_region_variant", "score": 0.33},
{"id": "SO_0002019", "label": "start_retained_variant", "score": 0.33},
{
"id": "SO_0002169",
"label": "splice_polypyrimidine_tract_variant",
"score": 0.33,
},
{"id": "SO_0001819", "label": "synonymous_variant", "score": 0.33},
{
"id": "SO_0002170",
"label": "splice_donor_region_variant",
"score": 0.33,
},
{"id": "SO_0001624", "label": "3_prime_UTR_variant", "score": 0.1},
{"id": "SO_0001623", "label": "5_prime_UTR_variant", "score": 0.1},
{"id": "SO_0001627", "label": "intron_variant", "score": 0.1},
{
"id": "SO_0001619",
"label": "non_coding_transcript_variant",
"score": 0.0,
},
{"id": "SO_0001580", "label": "coding_sequence_variant", "score": 0.0},
{"id": "SO_0001632", "label": "downstream_gene_variant", "score": 0.0},
{"id": "SO_0001631", "label": "upstream_gene_variant", "score": 0.0},
{
"id": "SO_0001792",
"label": "non_coding_transcript_exon_variant",
"score": 0.0,
},
{"id": "SO_0001620", "label": "mature_miRNA_variant", "score": 0.0},
]

_target_: str = "gentropy.variant_index.VariantIndexStep"

Expand Down
56 changes: 6 additions & 50 deletions src/gentropy/dataset/variant_index.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,6 @@
from gentropy.common.schemas import parse_spark_schema
from gentropy.common.spark_helpers import (
get_nested_struct_schema,
get_record_with_maximum_value,
rename_all_columns,
safe_array_union,
)
Expand All @@ -22,7 +21,6 @@
from pyspark.sql.types import StructType



@dataclass
class VariantIndex(Dataset):
"""Dataset for representing variants and methods applied on them."""
Expand Down Expand Up @@ -130,7 +128,6 @@ def add_annotation(
# Prefix for renaming columns:
prefix = "annotation_"


# Generate select expressions that to merge and import columns from annotation:
select_expressions = []

Expand All @@ -146,9 +143,13 @@ def add_annotation(
if isinstance(field.dataType.elementType, t.StructType):
# Extract the schema of the array to get the order of the fields:
array_schema = [
field for field in VariantIndex.get_schema().fields if field.name == column
field
for field in VariantIndex.get_schema().fields
if field.name == column
][0].dataType
fields_order = get_nested_struct_schema(array_schema).fieldNames()
fields_order = get_nested_struct_schema(
array_schema
).fieldNames()
select_expressions.append(
safe_array_union(
f.col(column), f.col(f"{prefix}{column}"), fields_order
Expand Down Expand Up @@ -286,48 +287,3 @@ def get_loftee(self: VariantIndex) -> DataFrame:
"isHighQualityPlof",
)
)

def get_most_severe_gene_consequence(
    self: VariantIndex,
    *,
    vep_consequences: DataFrame,
) -> DataFrame:
    """Returns a dataframe with the most severe consequence for a variant/gene pair.

    Transcript consequences are exploded to one row per
    (variantId, targetId, variantFunctionalConsequenceId), scored by joining
    against the provided consequence-to-score lookup, and the highest-scoring
    consequence is kept for every variant/gene pair.

    Args:
        vep_consequences (DataFrame): A dataframe of VEP consequences with
            columns `variantFunctionalConsequenceId` and `score`.

    Returns:
        DataFrame: A dataframe with the most severe consequence (plus a severity
            score) for a variant/gene pair, in columns `variantId`, `targetId`,
            `mostSevereVariantFunctionalConsequenceId` and `severityScore`.
    """
    return (
        self.df.select("variantId", f.explode("transcriptConsequences").alias("tc"))
        .select(
            "variantId",
            f.col("tc.targetId"),
            f.explode(f.col("tc.variantFunctionalConsequenceIds")).alias(
                "variantFunctionalConsequenceId"
            ),
        )
        .join(
            # TODO: make this table a project config
            # Broadcast join: the consequence lookup is tiny relative to the index.
            f.broadcast(
                vep_consequences.selectExpr(
                    "variantFunctionalConsequenceId", "score as severityScore"
                )
            ),
            on="variantFunctionalConsequenceId",
            how="inner",
        )
        # BUG FIX: the original `.isNull()` filter kept ONLY unscored rows and then
        # asked for the maximum `severityScore`, discarding every scored consequence.
        # Keep only rows that actually carry a severity score.
        .filter(f.col("severityScore").isNotNull())
        .transform(
            # A variant can have multiple predicted consequences on a transcript,
            # the most severe one is selected
            lambda df: get_record_with_maximum_value(
                df, ["variantId", "targetId"], "severityScore"
            )
        )
        .withColumnRenamed(
            "variantFunctionalConsequenceId",
            "mostSevereVariantFunctionalConsequenceId",
        )
    )
30 changes: 12 additions & 18 deletions src/gentropy/datasource/ensembl/vep_parser.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,15 +2,12 @@

from __future__ import annotations

import importlib.resources as pkg_resources
from typing import TYPE_CHECKING

import pandas as pd
from pyspark.sql import SparkSession
from pyspark.sql import functions as f
from pyspark.sql import types as t

from gentropy.assets import data
from gentropy.common.schemas import parse_spark_schema
from gentropy.common.spark_helpers import (
enforce_schema,
Expand All @@ -24,9 +21,12 @@
if TYPE_CHECKING:
from pyspark.sql import Column, DataFrame

from gentropy.config import VariantIndexConfig


class VariantEffectPredictorParser:
"""Collection of methods to parse VEP output in json format."""

# NOTE: Due to the fact that the comparison of the xrefs is done om the base of rsids
# if the field `colocalised_variants` have multiple rsids, this extracting xrefs will result in
# an array of xref structs, rather then the struct itself.
Expand Down Expand Up @@ -568,22 +568,16 @@ def process_vep_output(
Returns:
DataFrame: processed data in the right shape.
"""
so_df = pd.read_csv(
pkg_resources.open_text(
data, "variant_consequence_to_score.tsv", encoding="utf-8"
),
sep="\t",
)

# Reading consequence to sequence ontology map:
# Consequence to sequence ontology map:
sequence_ontology_map = {
row["label"]: row["variantFunctionalConsequenceId"]
for _, row in so_df.iterrows()
item["label"]: item["id"]
for item in VariantIndexConfig.consequence_to_pathogenicity_score
}
# Sequence ontology to score map:
label_to_score_map = {
item["label"]: item["score"]
for item in VariantIndexConfig.consequence_to_pathogenicity_score
}

# Reading score dictionary:
score_dictionary = {row["label"]: row["score"] for _, row in so_df.iterrows()}

# Processing VEP output:
return (
vep_output
Expand Down Expand Up @@ -694,7 +688,7 @@ def process_vep_output(
f.transform(
transcript.consequence_terms,
lambda term: map_column_by_dictionary(
term, score_dictionary
term, label_to_score_map
),
)
)
Expand Down
20 changes: 1 addition & 19 deletions tests/gentropy/dataset/test_variant_index.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@
from gentropy.dataset.variant_index import VariantIndex

if TYPE_CHECKING:
from pyspark.sql import DataFrame, SparkSession
from pyspark.sql import SparkSession


def test_variant_index_creation(mock_variant_index: VariantIndex) -> None:
Expand Down Expand Up @@ -144,24 +144,6 @@ def test_get_distance_to_gene(
for col in expected_cols:
assert col in observed.columns, f"Column {col} not in {observed.columns}"

def test_get_most_severe_gene_consequence(
    self: TestVariantIndex,
    mock_variant_index: VariantIndex,
    mock_variant_consequence_to_score: DataFrame,
) -> None:
    """Assert that the function returns a df with the requested columns."""
    observed = mock_variant_index.get_most_severe_gene_consequence(
        vep_consequences=mock_variant_consequence_to_score
    )
    # Every expected output column must be present in the returned dataframe.
    for col in (
        "variantId",
        "targetId",
        "mostSevereVariantFunctionalConsequenceId",
        "severityScore",
    ):
        assert col in observed.columns, f"Column {col} not in {observed.columns}"

def test_get_loftee(
self: TestVariantIndex, mock_variant_index: VariantIndex
) -> None:
Expand Down
Loading