diff --git a/README.md b/README.md
index 43f9f1b..d19c186 100644
--- a/README.md
+++ b/README.md
@@ -20,12 +20,6 @@ of `combine_scorefile` to produce scoring files for plink 2
 $ pip install pgscatalog-utils
 ```
 
-Or clone the repo:
-
-```
-$ git clone https://github.com/PGScatalog/pgscatalog_utils.git
-```
-
 ## Quickstart
 
 ```
@@ -33,3 +27,43 @@ $ download_scorefiles -i PGS000922 PGS001229 -o . -b GRCh37
 $ combine_scorefiles -s PGS*.txt.gz -o combined.txt
 $ match_variants -s combined.txt -t --min_overlap 0.75 --outdir .
 ```
+
+More details are available using the `--help` parameter.
+
+## Install from source
+
+Requirements:
+
+- python 3.10
+- [poetry](https://python-poetry.org)
+
+```
+$ git clone https://github.com/PGScatalog/pgscatalog_utils.git
+$ cd pgscatalog_utils
+$ poetry install
+$ poetry build
+$ pip install --user dist/*.whl
+```
+
+## Credits
+
+The `pgscatalog_utils` package is developed as part of the **Polygenic Score (PGS) Catalog**
+([www.PGSCatalog.org](https://www.PGSCatalog.org)) project, a collaboration between the
+University of Cambridge’s Department of Public Health and Primary Care (Michael Inouye, Samuel Lambert, Laurent Gil)
+and the European Bioinformatics Institute (Helen Parkinson, Aoife McMahon, Ben Wingfield, Laura Harris).
+
+A manuscript describing the tool and larger PGS Catalog Calculator pipeline
+[(`PGSCatalog/pgsc_calc`)](https://github.com/PGScatalog/pgsc_calc) is in preparation. In the meantime
+if you use these tools we ask you to cite the repo(s) and the paper describing the PGS Catalog resource:
+
+- >PGS Catalog utilities _(in development)_. PGS Catalog
+  Team. [https://github.com/PGScatalog/pgscatalog_utils](https://github.com/PGScatalog/pgscatalog_utils)
+- >PGS Catalog Calculator _(in development)_. PGS Catalog
+  Team. [https://github.com/PGScatalog/pgsc_calc](https://github.com/PGScatalog/pgsc_calc)
+- >Lambert _et al._ (2021) The Polygenic Score Catalog as an open database for
+reproducibility and systematic evaluation. Nature Genetics. 53:420–425
+doi:[10.1038/s41588-021-00783-5](https://doi.org/10.1038/s41588-021-00783-5).
+
+This work has received funding from EMBL-EBI core funds, the Baker Institute, the University of Cambridge,
+Health Data Research UK (HDRUK), and the European Union's Horizon 2020 research and innovation programme
+under grant agreement No 101016775 INTERVENE.
\ No newline at end of file diff --git a/conftest.py b/conftest.py index 08b0f9b..a30f2cd 100644 --- a/conftest.py +++ b/conftest.py @@ -6,6 +6,7 @@ from pgscatalog_utils.scorefile.combine_scorefiles import combine_scorefiles from pysqlar import SQLiteArchive import pandas as pd +import glob @pytest.fixture(scope="session") @@ -21,11 +22,7 @@ def scorefiles(tmp_path_factory, pgs_accessions): with patch('sys.argv', args): download_scorefile() - paths: list[str] = [os.path.join(fn.resolve(), x + '.txt.gz') for x in pgs_accessions] - - assert all([os.path.exists(x) for x in paths]) - - return paths + return glob.glob(os.path.join(fn.resolve(), "*.txt.gz")) @pytest.fixture(scope="session") @@ -117,7 +114,7 @@ def chain_files(db, tmp_path_factory): def lifted_scorefiles(scorefiles, chain_files, tmp_path_factory): out_path = tmp_path_factory.mktemp("scores") / "lifted.txt" args: list[str] = ['combine_scorefiles', '-s'] + scorefiles + ['--liftover', '-c', chain_files, '-t', 'GRCh38', - '-m', '0.95'] + ['-o', str(out_path.resolve())] + '-m', '0.8'] + ['-o', str(out_path.resolve())] with patch('sys.argv', args): combine_scorefiles() diff --git a/pgscatalog_utils/__init__.py b/pgscatalog_utils/__init__.py index b794fd4..df9144c 100644 --- a/pgscatalog_utils/__init__.py +++ b/pgscatalog_utils/__init__.py @@ -1 +1 @@ -__version__ = '0.1.0' +__version__ = '0.1.1' diff --git a/pgscatalog_utils/download/api.py b/pgscatalog_utils/download/api.py deleted file mode 100644 index 8fdb1fe..0000000 --- a/pgscatalog_utils/download/api.py +++ /dev/null @@ -1,43 +0,0 @@ -import requests -import jq -import logging -import sys - -logger = logging.getLogger(__name__) - - -def pgscatalog_result(pgs: list[str], build: str) -> dict[str, str]: - result = _parse_json_query(_api_query(pgs), build) - - try: - if len(pgs) > len(result): - missing_pgs: set[str] = set(pgs).difference(set(result.keys())) - logger.warning(f"Some queries missing in PGS Catalog response: {missing_pgs}") - except TypeError: - logger.error(f"Bad response from PGS Catalog API. Is {pgs} a valid ID?") - sys.exit(1) - - return result - - -def _api_query(pgs_id: list[str]) -> dict: - pgs: str = ','.join(pgs_id) - api: str = f'https://www.pgscatalog.org/rest/score/search?pgs_ids={pgs}' - r: requests.models.Response = requests.get(api) - return r.json() - - -def _parse_json_query(json: dict, build: str) -> dict[str, str]: - result = jq.compile(".results").input(json).first() - if not result: - logger.warning("No results in response from PS Catalog API. 
Please check the PGS IDs.") - else: - return _extract_ftp_url(json, build) - - -def _extract_ftp_url(json: list[dict], build: str) -> dict[str, str]: - id: list[str] = jq.compile('[.results][][].id').input(json).all() - result: list[str] = jq.compile(f'[.results][][].ftp_harmonized_scoring_files.{build}.positions').input(json).all() - return dict(zip(id, [x.replace('https', 'ftp') for x in result])) - - diff --git a/pgscatalog_utils/download/download_scorefile.py b/pgscatalog_utils/download/download_scorefile.py index 74a79e1..30f8ac8 100644 --- a/pgscatalog_utils/download/download_scorefile.py +++ b/pgscatalog_utils/download/download_scorefile.py @@ -1,44 +1,58 @@ -import logging import argparse +import logging import os import shutil +import textwrap from contextlib import closing +from functools import reduce from urllib import request as request -from pgscatalog_utils.download.api import pgscatalog_result + +from pgscatalog_utils.download.publication import query_publication +from pgscatalog_utils.download.score import get_url +from pgscatalog_utils.download.trait import query_trait from pgscatalog_utils.log_config import set_logging_level logger = logging.getLogger(__name__) -def parse_args(args=None) -> argparse.Namespace: - parser: argparse.ArgumentParser = argparse.ArgumentParser(description='Download scoring files') - parser.add_argument('-i', '--id', nargs='+', dest='pgs', - help=' PGS Catalog ID', required=True) - parser.add_argument('-b', '--build', dest='build', required=True, - help=' Genome build: GRCh37 or GRCh38') - parser.add_argument('-o', '--outdir', dest='outdir', required=True, - default='scores/', - help=' Output directory to store downloaded files') - parser.add_argument('-v', '--verbose', dest='verbose', action='store_true', - help=' Extra logging information') - return parser.parse_args(args) +def download_scorefile() -> None: + args = _parse_args() + set_logging_level(args.verbose) + _check_args(args) + _mkdir(args.outdir) + if args.build is None: + logger.critical(f'Downloading scoring file(s) in the author-reported genome build') + elif args.build in ['GRCh37', 'GRCh38']: + logger.critical(f'Downloading harmonized scoring file(s) in build: {args.build}.') + else: + logger.critical(f'Invalid genome build specified: {args.build}. Only -b GRCh37 and -b GRCh38 are supported') + raise Exception -def download_scorefile() -> None: - args = parse_args() + pgs_lst: list[list[str]] = [] - set_logging_level(args.verbose) + if args.efo: + logger.debug("--trait set, querying traits") + pgs_lst = pgs_lst + [query_trait(x) for x in args.efo] - _mkdir(args.outdir) + if args.pgp: + logger.debug("--pgp set, querying publications") + pgs_lst = pgs_lst + [query_publication(x) for x in args.pgp] - if args.build not in ['GRCh37', 'GRCh38']: - raise Exception(f'Invalid genome build specified: {args.build}. 
Only -b GRCh37 and -b GRCh38 are supported') + if args.pgs: + logger.debug("--id set, querying scores") + pgs_lst.append(args.pgs) # pgs_lst: a list containing up to three flat lists - urls: dict[str, str] = pgscatalog_result(args.pgs, args.build) + pgs_id: list[str] = list(set(reduce(lambda x, y: x + y, pgs_lst))) + + urls: dict[str, str] = get_url(pgs_id, args.build) for pgsid, url in urls.items(): logger.debug(f"Downloading {pgsid} from {url}") - path: str = os.path.join(args.outdir, pgsid + '.txt.gz') + if args.build is None: + path: str = os.path.join(args.outdir, pgsid + '.txt.gz') + else: + path: str = os.path.join(args.outdir, pgsid + f'_hmPOS_{args.build}.txt.gz') _download_ftp(url, path) @@ -58,5 +72,59 @@ def _download_ftp(url: str, path: str) -> None: shutil.copyfileobj(r, f) +def _check_args(args): + if not args.efo: + if not args.pgp: + if not args.pgs: + logger.critical("One of --trait, --pgp, or --id is required to download scorefiles") + raise Exception + + +def _description_text() -> str: + return textwrap.dedent('''\ + Download a set of scoring files from the PGS Catalog using PGS + Scoring IDs, traits, or publication IDs. + + The PGS Catalog API is queried to get a list of scoring file + URLs. Scoring files are downloaded via FTP to a specified + directory. PGS Catalog scoring files are staged with the name: + + {PGS_ID}.txt.gz + + If a valid build is specified harmonized files are downloaded as: + + {PGS_ID}_hmPOS_{genome_build}.txt.gz + + These harmonised scoring files contain genomic coordinates, + remapped from author-submitted information such as rsids. + ''') + + +def _epilog_text() -> str: + return textwrap.dedent('''\ + download_scorefiles will skip downloading a scoring file if it + already exists in the download directory. This can be useful if + the download process is interrupted and needs to be restarted + later. You can track download progress with the verbose flag. + ''') + + +def _parse_args(args=None) -> argparse.Namespace: + parser = argparse.ArgumentParser(description=_description_text(), epilog=_epilog_text(), + formatter_class=argparse.RawDescriptionHelpFormatter) + parser.add_argument('-i', '--pgs', nargs='+', dest='pgs', help='PGS Catalog ID(s) (e.g. PGS000001)') + parser.add_argument('-t', '--efo', dest='efo', nargs='+', + help='Traits described by an EFO term(s) (e.g. EFO_0004611)') + parser.add_argument('-p', '--pgp', dest='pgp', help='PGP publication ID(s) (e.g. 
PGP000007)', nargs='+') + parser.add_argument('-b', '--build', dest='build', + help='Download Harmonized Scores with Positions in Genome build: GRCh37 or GRCh38') + parser.add_argument('-o', '--outdir', dest='outdir', required=True, + default='scores/', + help=' Output directory to store downloaded files') + parser.add_argument('-v', '--verbose', dest='verbose', action='store_true', + help=' Extra logging information') + return parser.parse_args(args) + + if __name__ == "__main__": download_scorefile() diff --git a/pgscatalog_utils/download/publication.py b/pgscatalog_utils/download/publication.py new file mode 100644 index 0000000..b5e90fa --- /dev/null +++ b/pgscatalog_utils/download/publication.py @@ -0,0 +1,22 @@ +import requests +import logging +from functools import reduce + +logger = logging.getLogger(__name__) + + +def query_publication(pgp: str) -> list[str]: + api: str = f'https://www.pgscatalog.org/rest/publication/{pgp}' + logger.debug("Querying PGS Catalog with publication PGP ID") + r: requests.models.Response = requests.get(api) + + if r.json() == {}: + logger.critical(f"Bad response from PGS Catalog for EFO term: {pgp}") + raise Exception + + pgs: dict[str, list[str]] = r.json().get('associated_pgs_ids') + logger.debug(f"Valid response from PGS Catalog for PGP ID: {pgp}") + return list(reduce(lambda x, y: set(x).union(set(y)), pgs.values())) + + + diff --git a/pgscatalog_utils/download/score.py b/pgscatalog_utils/download/score.py new file mode 100644 index 0000000..61a0154 --- /dev/null +++ b/pgscatalog_utils/download/score.py @@ -0,0 +1,57 @@ +import requests +import logging +import jq +import sys + +logger = logging.getLogger(__name__) + + +def get_url(pgs: list[str], build: str) -> dict[str, str]: + pgs_result: list[str] = [] + url_result: list[str] = [] + + for chunk in _chunker(pgs): + try: + response = _parse_json_query(query_score(chunk), build) + pgs_result = pgs_result + list(response.keys()) + url_result = url_result + list(response.values()) + except TypeError: + logger.error(f"Bad response from PGS Catalog API. Is {pgs} a valid ID?") + sys.exit(1) + + missing_pgs = set(pgs).difference(set(pgs_result)) + + if missing_pgs: + logger.warning(f"Some queries missing in PGS Catalog response: {missing_pgs}") + + return dict(zip(pgs_result, url_result)) + + +def query_score(pgs_id: list[str]) -> dict: + pgs: str = ','.join(pgs_id) + api: str = f'https://www.pgscatalog.org/rest/score/search?pgs_ids={pgs}' + r: requests.models.Response = requests.get(api) + return r.json() + + +def _chunker(pgs: list[str]): + size = 50 # /rest/score/{pgs_id} limit when searching multiple IDs + return(pgs[pos: pos + size] for pos in range(0, len(pgs), size)) + + +def _parse_json_query(json: dict, build: str | None) -> dict[str, str]: + result = jq.compile(".results").input(json).first() + if not result: + logger.warning("No results in response from PGS Catalog API. 
Please check the PGS IDs.") + else: + return _extract_ftp_url(json, build) + + +def _extract_ftp_url(json: list[dict], build: str | None) -> dict[str, str]: + id: list[str] = jq.compile('[.results][][].id').input(json).all() + if build is None: + result: list[str] = jq.compile(f'[.results][][].ftp_scoring_file').input( + json).all() + else: + result: list[str] = jq.compile(f'[.results][][].ftp_harmonized_scoring_files.{build}.positions').input(json).all() + return dict(zip(id, [x.replace('https', 'ftp') for x in result])) diff --git a/pgscatalog_utils/download/trait.py b/pgscatalog_utils/download/trait.py new file mode 100644 index 0000000..981b40d --- /dev/null +++ b/pgscatalog_utils/download/trait.py @@ -0,0 +1,23 @@ +import requests +import logging +from functools import reduce + +logger = logging.getLogger(__name__) + + +def query_trait(trait: str) -> list[str]: + api: str = f'https://www.pgscatalog.org/rest/trait/{trait}?include_children=1' + logger.debug(f"Querying PGS Catalog with trait {trait}") + r: requests.models.Response = requests.get(api) + + if r.json() == {}: + logger.critical(f"Bad response from PGS Catalog for EFO term: {trait}") + raise Exception + + keys: list[str] = ['associated_pgs_ids', 'child_associated_pgs_ids'] + pgs: list[str] = [] + for key in keys: + pgs.append(r.json().get(key)) + + logger.debug(f"Valid response from PGS Catalog for EFO term: {trait}") + return list(reduce(lambda x, y: set(x).union(set(y)), pgs)) diff --git a/pgscatalog_utils/match/match.py b/pgscatalog_utils/match/match.py index b0803de..6a3f70c 100644 --- a/pgscatalog_utils/match/match.py +++ b/pgscatalog_utils/match/match.py @@ -1,62 +1,70 @@ -import polars as pl import logging +import polars as pl + +from pgscatalog_utils.match.postprocess import postprocess_matches from pgscatalog_utils.match.write import write_log logger = logging.getLogger(__name__) -def get_all_matches(scorefile: pl.DataFrame, target: pl.DataFrame) -> pl.DataFrame: +def get_all_matches(scorefile: pl.DataFrame, target: pl.DataFrame, remove_ambiguous: bool, + skip_flip: bool) -> pl.DataFrame: scorefile_cat, target_cat = _cast_categorical(scorefile, target) scorefile_oa = scorefile_cat.filter(pl.col("other_allele") != None) scorefile_no_oa = scorefile_cat.filter(pl.col("other_allele") == None) matches: list[pl.DataFrame] = [] + col_order = ['chr_name', 'chr_position', 'effect_allele', 'other_allele', 'effect_weight', 'effect_type', + 'accession', 'effect_allele_FLIP', 'other_allele_FLIP', + 'ID', 'REF', 'ALT', 'is_multiallelic', 'matched_effect_allele', 'match_type'] if scorefile_oa: logger.debug("Getting matches for scores with effect allele and other allele") - matches.append(_match_variants(scorefile_cat, target_cat, effect_allele='REF', other_allele='ALT', - match_type="refalt")) - matches.append(_match_variants(scorefile_cat, target_cat, effect_allele='ALT', other_allele='REF', - match_type="altref")) - matches.append(_match_variants(scorefile_cat, target_cat, effect_allele='REF_FLIP', - other_allele='ALT_FLIP', - match_type="refalt_flip")) - matches.append(_match_variants(scorefile_cat, target_cat, effect_allele='ALT_FLIP', - other_allele='REF_FLIP', - match_type="altref_flip")) + matches.append(_match_variants(scorefile_cat, target_cat, match_type="refalt").select(col_order)) + matches.append(_match_variants(scorefile_cat, target_cat, match_type="altref").select(col_order)) + if skip_flip is False: + matches.append(_match_variants(scorefile_cat, target_cat, match_type="refalt_flip").select(col_order)) + 
matches.append(_match_variants(scorefile_cat, target_cat, match_type="altref_flip").select(col_order)) if scorefile_no_oa: logger.debug("Getting matches for scores with effect allele only") - matches.append(_match_variants(scorefile_no_oa, target_cat, effect_allele='REF', other_allele=None, - match_type="no_oa_ref")) - matches.append(_match_variants(scorefile_no_oa, target_cat, effect_allele='ALT', other_allele=None, - match_type="no_oa_alt")) - matches.append(_match_variants(scorefile_no_oa, target_cat, effect_allele='REF_FLIP', - other_allele=None, match_type="no_oa_ref_flip")) - matches.append(_match_variants(scorefile_no_oa, target_cat, effect_allele='ALT_FLIP', - other_allele=None, match_type="no_oa_alt_flip")) + matches.append(_match_variants(scorefile_no_oa, target_cat, match_type="no_oa_ref").select(col_order)) + matches.append(_match_variants(scorefile_no_oa, target_cat, match_type="no_oa_alt").select(col_order)) + if skip_flip is False: + matches.append(_match_variants(scorefile_no_oa, target_cat, match_type="no_oa_ref_flip").select(col_order)) + matches.append(_match_variants(scorefile_no_oa, target_cat, match_type="no_oa_alt_flip").select(col_order)) - return pl.concat(matches) + return pl.concat(matches).pipe(postprocess_matches, remove_ambiguous) -def check_match_rate(scorefile: pl.DataFrame, matches: pl.DataFrame, min_overlap: float, dataset: str) -> None: +def check_match_rate(scorefile: pl.DataFrame, matches: pl.DataFrame, min_overlap: float, dataset: str) -> pl.DataFrame: scorefile: pl.DataFrame = scorefile.with_columns([ pl.col('effect_type').cast(pl.Categorical), pl.col('accession').cast(pl.Categorical)]) # same dtypes for join match_log: pl.DataFrame = _join_matches(matches, scorefile, dataset) - write_log(match_log, dataset) fail_rates: pl.DataFrame = (match_log.groupby('accession') .agg([pl.count(), (pl.col('match_type') == None).sum().alias('no_match')]) .with_column((pl.col('no_match') / pl.col('count')).alias('fail_rate')) ) - + pass_df: pl.DataFrame = pl.DataFrame() for accession, rate in zip(fail_rates['accession'].to_list(), fail_rates['fail_rate'].to_list()): if rate < (1 - min_overlap): - logger.debug(f"Score {accession} passes minimum matching threshold ({1-rate:.2%} variants match)") + df = pl.DataFrame({'accession': [accession], 'match_pass': [True], 'match_rate': [1 - rate]}) + pass_df = pl.concat([pass_df, df]) + logger.debug(f"Score {accession} passes minimum matching threshold ({1 - rate:.2%} variants match)") else: - logger.error(f"Score {accession} fails minimum matching threshold ({1-rate:.2%} variants match)") - raise Exception + df = pl.DataFrame({'accession': [accession], 'match_pass': [False], 'match_rate': [1 - rate]}) + pass_df = pl.concat([pass_df, df]) + logger.error(f"Score {accession} fails minimum matching threshold ({1 - rate:.2%} variants match)") + + # add match statistics to log and matches + write_log((match_log.with_column(pl.col('accession').cast(str)) + .join(pass_df, on='accession', how='left')), dataset) + + return (matches.with_column(pl.col('accession').cast(str)) + .join(pass_df, on='accession', how='left')) + def _match_keys(): @@ -68,32 +76,57 @@ def _join_matches(matches: pl.DataFrame, scorefile: pl.DataFrame, dataset: str): return scorefile.join(matches, on=_match_keys(), how='left').with_column(pl.lit(dataset).alias('dataset')) -def _match_variants(scorefile: pl.DataFrame, - target: pl.DataFrame, - effect_allele: str, - other_allele: str | None, - match_type: str) -> pl.DataFrame: +def _match_variants(scorefile: 
pl.DataFrame, target: pl.DataFrame, match_type: str) -> pl.DataFrame: logger.debug(f"Matching strategy: {match_type}") - return (scorefile.join(target, - left_on=_scorefile_keys(other_allele), - right_on=_target_keys(effect_allele, other_allele), - how='inner')).pipe(_post_match, effect_allele, other_allele, match_type) - - -def _post_match(df: pl.DataFrame, - effect_allele: str, - other_allele: str, - match_type: str) -> pl.DataFrame: - """ Annotate matches with parameters """ - if other_allele is None: - logger.debug("Dropping missing other_allele during annotation") - other_allele = 'dummy' # prevent trying to alias a column to None + match match_type: + case 'refalt': + score_keys = ["chr_name", "chr_position", "effect_allele", "other_allele"] + target_keys = ["#CHROM", "POS", "REF", "ALT"] + effect_allele_column = "effect_allele" + case 'altref': + score_keys = ["chr_name", "chr_position", "effect_allele", "other_allele"] + target_keys = ["#CHROM", "POS", "ALT", "REF"] + effect_allele_column = "effect_allele" + case 'refalt_flip': + score_keys = ["chr_name", "chr_position", "effect_allele_FLIP", "other_allele_FLIP"] + target_keys = ["#CHROM", "POS", "REF", "ALT"] + effect_allele_column = "effect_allele_FLIP" + case 'altref_flip': + score_keys = ["chr_name", "chr_position", "effect_allele_FLIP", "other_allele_FLIP"] + target_keys = ["#CHROM", "POS", "ALT", "REF"] + effect_allele_column = "effect_allele_FLIP" + case 'no_oa_ref': + score_keys = ["chr_name", "chr_position", "effect_allele"] + target_keys = ["#CHROM", "POS", "REF"] + effect_allele_column = "effect_allele" + case 'no_oa_alt': + score_keys = ["chr_name", "chr_position", "effect_allele"] + target_keys = ["#CHROM", "POS", "ALT"] + effect_allele_column = "effect_allele" + case 'no_oa_ref_flip': + score_keys = ["chr_name", "chr_position", "effect_allele_FLIP"] + target_keys = ["#CHROM", "POS", "REF"] + effect_allele_column = "effect_allele_FLIP" + case 'no_oa_alt_flip': + score_keys = ["chr_name", "chr_position", "effect_allele_FLIP"] + target_keys = ["#CHROM", "POS", "ALT"] + effect_allele_column = "effect_allele_FLIP" + case _: + logger.critical(f"Invalid match strategy: {match_type}") + raise Exception - return df.with_columns([pl.col("*"), - pl.col("effect_allele").alias(effect_allele), - pl.col("other_allele").alias(other_allele), - pl.lit(match_type).alias("match_type") - ])[_matched_colnames()] + missing_cols = ['REF', 'ALT'] + if match_type.startswith('no_oa'): + if match_type.startswith('no_oa_ref'): + missing_cols = ['REF'] + else: + missing_cols = ['ALT'] + join_cols = ['ID'] + missing_cols + return (scorefile.join(target, score_keys, target_keys, how='inner') + .with_columns([pl.col("*"), + pl.col(effect_allele_column).alias("matched_effect_allele"), + pl.lit(match_type).alias("match_type")]) + .join(target.select(join_cols), on="ID", how="inner")) # get REF / ALT back after first join def _cast_categorical(scorefile, target) -> tuple[pl.DataFrame, pl.DataFrame]: @@ -103,33 +136,15 @@ def _cast_categorical(scorefile, target) -> tuple[pl.DataFrame, pl.DataFrame]: pl.col("effect_allele").cast(pl.Categorical), pl.col("other_allele").cast(pl.Categorical), pl.col("effect_type").cast(pl.Categorical), + pl.col("effect_allele_FLIP").cast(pl.Categorical), + pl.col("other_allele_FLIP").cast(pl.Categorical), pl.col("accession").cast(pl.Categorical) ]) if target: target = target.with_columns([ + pl.col("ID").cast(pl.Categorical), pl.col("REF").cast(pl.Categorical), - pl.col("ALT").cast(pl.Categorical), - 
pl.col("ALT_FLIP").cast(pl.Categorical), - pl.col("REF_FLIP").cast(pl.Categorical) + pl.col("ALT").cast(pl.Categorical) ]) return scorefile, target - - -def _scorefile_keys(other_allele: str) -> list[str]: - if other_allele: - return ['chr_name', 'chr_position', 'effect_allele', 'other_allele'] - else: - return ['chr_name', 'chr_position', 'effect_allele'] - - -def _target_keys(effect_allele: str, other_allele: str) -> list[str]: - if other_allele: - return ['#CHROM', 'POS', effect_allele, other_allele] - else: - return ['#CHROM', 'POS', effect_allele] - - -def _matched_colnames() -> list[str]: - return ['chr_name', 'chr_position', 'effect_allele', 'other_allele', 'effect_weight', 'effect_type', 'accession', - 'ID', 'REF', 'ALT', 'REF_FLIP', 'ALT_FLIP', 'match_type'] diff --git a/pgscatalog_utils/match/match_variants.py b/pgscatalog_utils/match/match_variants.py index 2d2a632..0d31da6 100644 --- a/pgscatalog_utils/match/match_variants.py +++ b/pgscatalog_utils/match/match_variants.py @@ -1,60 +1,186 @@ import argparse import logging +import textwrap +from glob import glob + import polars as pl -from pgscatalog_utils.match.postprocess import postprocess_matches from pgscatalog_utils.log_config import set_logging_level from pgscatalog_utils.match.match import get_all_matches, check_match_rate from pgscatalog_utils.match.read import read_target, read_scorefile from pgscatalog_utils.match.write import write_out +logger = logging.getLogger(__name__) + def match_variants(): args = _parse_args() - logger = logging.getLogger(__name__) set_logging_level(args.verbose) + logger.debug(f"polars n_threads: {pl.threadpool_size()}") scorefile: pl.DataFrame = read_scorefile(path=args.scorefile) - target: pl.DataFrame = read_target(path=args.target, n_threads=args.n_threads, - remove_multiallelic=args.remove_multiallelic) - - dataset = args.dataset.replace('_', '-') # underscores are delimiters in pgs catalog calculator with pl.StringCache(): - matches: pl.DataFrame = get_all_matches(scorefile, target).pipe(postprocess_matches, args.remove_ambiguous) - check_match_rate(scorefile, matches, args.min_overlap, dataset) + n_target_files = len(glob(args.target)) + matches: pl.DataFrame + + if n_target_files == 1 and not args.fast: + match_mode: str = 'single' + elif n_target_files > 1 and not args.fast: + match_mode: str = 'multi' + elif args.fast: + match_mode: str = 'fast' + + match match_mode: + case "single": + logger.debug(f"Match mode: {match_mode}") + matches = _match_single_target(args.target, scorefile, args.remove_multiallelic, args.remove_ambiguous, args.skip_flip) + case "multi": + logger.debug(f"Match mode: {match_mode}") + matches = _match_multiple_targets(args.target, scorefile, args.remove_multiallelic, + args.remove_ambiguous, args.skip_flip) + case "fast": + logger.debug(f"Match mode: {match_mode}") + matches = _fast_match(args.target, scorefile, args.remove_multiallelic, + args.remove_ambiguous, args.skip_flip) + case _: + logger.critical(f"Invalid match mode: {match_mode}") + raise Exception + + dataset = args.dataset.replace('_', '-') # underscores are delimiters in pgs catalog calculator + valid_matches: pl.DataFrame = (check_match_rate(scorefile, matches, args.min_overlap, dataset) + .filter(pl.col('match_pass') == True)) - if matches.shape[0] == 0: # this can happen if args.min_overlap = 0 + if valid_matches.is_empty(): # this can happen if args.min_overlap = 0 logger.error("Error: no target variants match any variants in scoring files") raise Exception - write_out(matches, args.split, 
args.outdir, dataset) + write_out(valid_matches, args.split, args.outdir, dataset) + + +def _check_target_chroms(target) -> None: + chroms: list[str] = target['#CHROM'].unique().to_list() + if len(chroms) > 1: + logger.critical(f"Multiple chromosomes detected: {chroms}. Check input data.") + raise Exception + else: + logger.debug("Split target genome contains one chromosome (good)") + + +def _fast_match(target_path: str, scorefile: pl.DataFrame, remove_multiallelic: bool, + remove_ambiguous: bool, skip_filp: bool) -> pl.DataFrame: + # fast match is fast because: + # 1) all target files are read into memory + # 2) matching occurs without iterating through chromosomes + target: pl.DataFrame = read_target(path=target_path, + remove_multiallelic=remove_multiallelic) + logger.debug("Split target chromosomes not checked with fast match mode") + return get_all_matches(scorefile, target, remove_ambiguous, skip_filp) + + +def _match_multiple_targets(target_path: str, scorefile: pl.DataFrame, remove_multiallelic: bool, + remove_ambiguous: bool, skip_filp: bool) -> pl.DataFrame: + matches = [] + for i, loc_target_current in enumerate(glob(target_path)): + logger.debug(f'Matching scorefile(s) against target: {loc_target_current}') + target: pl.DataFrame = read_target(path=loc_target_current, + remove_multiallelic=remove_multiallelic) # + _check_target_chroms(target) + matches.append(get_all_matches(scorefile, target, remove_ambiguous, skip_filp)) + return pl.concat(matches) + + +def _match_single_target(target_path: str, scorefile: pl.DataFrame, remove_multiallelic: bool, + remove_ambiguous: bool, skip_filp: bool) -> pl.DataFrame: + matches = [] + for chrom in scorefile['chr_name'].unique().to_list(): + target = read_target(target_path, remove_multiallelic=remove_multiallelic, + single_file=True, chrom=chrom) # scans and filters + if target: + logger.debug(f"Matching chromosome {chrom}") + matches.append(get_all_matches(scorefile, target, remove_ambiguous, skip_filp)) + + return pl.concat(matches) + + +def _description_text() -> str: + return textwrap.dedent('''\ + Match variants from a combined scoring file against a set of + target genomes from the same fileset, and output scoring files + compatible with the plink2 --score function. + + A combined scoring file is the output of the combine_scorefiles + script. It has the following structure: + + | chr_name | chr_position | ... | accession | + | -------- | ------------ | --- | --------- | + | 1 | 1 | ... | PGS000802 | + + The combined scoring file is in long format, with one row per + variant for each scoring file (accession). This structure is + different to the PGS Catalog standard, because the long format + makes matching faster and simpler. + + Target genomes can be in plink1 bim format or plink2 pvar + format. Variant IDs should be unique so that they can be specified + in the scoring file as: variant_id|effect_allele|[effect_weight column(s)...] + + Only one set of target genomes should be matched at a time. Don't + try to match target genomes from different plink filesets. Matching + against a set of chromosomes from the same fileset is OK (see --split). + ''') + + +def _epilog_text() -> str: + return textwrap.dedent('''\ + match_variants will output at least one scoring file in a + format compatible with the plink2 --score function. This + output might be split across different files to ensure each + variant ID, effect allele, and effect type appears only once + in each file. 
Output files have the pattern: + + {dataset}_{chromosome}_{effect_type}_{n}.scorefile. + + If multiple chromosomes are combined into a single file (i.e. not + --split), then {chromosome} is replaced with 'ALL'. Once the + scorefiles are used to calculate a score with plink2, the .sscore + files will need to be aggregated to calculate a single polygenic + score for each dataset, sample, and accession (scoring file). The + PGS Catalog Calculator does this automatically. + ''') def _parse_args(args=None): - parser = argparse.ArgumentParser(description='Read and format scoring files') + parser = argparse.ArgumentParser(description=_description_text(), epilog=_epilog_text(), + formatter_class=argparse.RawDescriptionHelpFormatter) parser.add_argument('-d', '--dataset', dest='dataset', required=True, - help=' Label for target genomic dataset (e.g. "-d thousand_genomes")') + help=' Label for target genomic dataset') parser.add_argument('-s', '--scorefiles', dest='scorefile', required=True, help=' Combined scorefile path (output of read_scorefiles.py)') parser.add_argument('-t', '--target', dest='target', required=True, help=' A table of target genomic variants (.bim format)') + parser.add_argument('-f', '--fast', dest='fast', action='store_true', + help=' Enable faster matching at the cost of increased RAM usage') parser.add_argument('--split', dest='split', default=False, action='store_true', help=' Split scorefile per chromosome?') parser.add_argument('--outdir', dest='outdir', required=True, help=' Output directory') - parser.add_argument('-n', '--n_threads', dest='n_threads', default=1, type=int, - help=' Number of threads used to match (default = 1)') parser.add_argument('-m', '--min_overlap', dest='min_overlap', required=True, type=float, help=' Minimum proportion of variants to match before error') parser.add_argument('--keep_ambiguous', dest='remove_ambiguous', action='store_false', - help='Flag to force the program to keep variants with ambiguous alleles, (e.g. A/T and G/C ' - 'SNPs), which are normally excluded (default: false). In this case the program proceeds ' - 'assuming that the genotype data is on the same strand as the GWAS whose summary ' - 'statistics were used to construct the score.'), + help=''' Flag to force the program to keep variants with + ambiguous alleles, (e.g. A/T and G/C SNPs), which are normally + excluded (default: false). In this case the program proceeds + assuming that the genotype data is on the same strand as the + GWAS whose summary statistics were used to construct the score. + ''') parser.add_argument('--keep_multiallelic', dest='remove_multiallelic', action='store_false', - help='Flag to allow matching to multiallelic variants (default: false).') + help=' Flag to allow matching to multiallelic variants (default: false).') + parser.add_argument('--ignore_strand_flips', dest='skip_flip', action='store_true', + help=''' Flag to not consider matched variants that may be reported + on the opposite strand. 
Default behaviour is to flip/complement unmatched variants and check if + they match.''') parser.add_argument('-v', '--verbose', dest='verbose', action='store_true', help=' Extra logging information') return parser.parse_args(args) @@ -62,8 +188,3 @@ def _parse_args(args=None): if __name__ == "__main__": match_variants() - - -# join matches and scorefile with keys depending on liftover -# count match type column -# matches.groupby('accession').agg([pl.count(), (pl.col('match_type') == None).sum().alias('no_match')]) diff --git a/pgscatalog_utils/match/postprocess.py b/pgscatalog_utils/match/postprocess.py index adb4932..33a0220 100644 --- a/pgscatalog_utils/match/postprocess.py +++ b/pgscatalog_utils/match/postprocess.py @@ -1,6 +1,9 @@ +from functools import reduce import polars as pl import logging +from pgscatalog_utils.match.preprocess import complement_valid_alleles + logger = logging.getLogger(__name__) @@ -11,43 +14,73 @@ def postprocess_matches(df: pl.DataFrame, remove_ambiguous: bool) -> pl.DataFram return df.filter(pl.col("ambiguous") == False) else: logger.debug("Keeping best possible match from ambiguous matches") - # pick the best possible match from the ambiguous matches - # EA = REF and OA = ALT or EA = REF and OA = None ambiguous: pl.DataFrame = df.filter((pl.col("ambiguous") == True) & \ - (pl.col("match_type") == "refalt") | - (pl.col("ambiguous") == True) & \ - (pl.col("match_type") == "no_oa_ref")) + (pl.col("match_type").str.contains('flip').is_not())) unambiguous: pl.DataFrame = df.filter(pl.col("ambiguous") == False) return pl.concat([ambiguous, unambiguous]) def _label_biallelic_ambiguous(df: pl.DataFrame) -> pl.DataFrame: - # A / T or C / G may match multiple times + logger.debug("Labelling ambiguous variants") df = df.with_columns([ - pl.col(["effect_allele", "other_allele", "REF", "ALT", "REF_FLIP", "ALT_FLIP"]).cast(str), + pl.col(["effect_allele", "other_allele", "REF", "ALT", "effect_allele_FLIP", "other_allele_FLIP"]).cast(str), pl.lit(True).alias("ambiguous") - ]) + ]).pipe(complement_valid_alleles, ["REF"]) return (df.with_column( - pl.when((pl.col("effect_allele") == pl.col("ALT_FLIP")) | (pl.col("effect_allele") == pl.col("REF_FLIP"))) + pl.when(pl.col("REF_FLIP") == pl.col("ALT")) .then(pl.col("ambiguous")) .otherwise(False))).pipe(_get_distinct_weights) def _get_distinct_weights(df: pl.DataFrame) -> pl.DataFrame: - """ Get a single effect weight for each matched variant per accession """ + """ Select single matched variant in target for each variant in the scoring file (e.g. per accession) """ count: pl.DataFrame = df.groupby(['accession', 'chr_name', 'chr_position', 'effect_allele']).count() singletons: pl.DataFrame = (count.filter(pl.col('count') == 1)[:, "accession":"effect_allele"] .join(df, on=['accession', 'chr_name', 'chr_position', 'effect_allele'], how='left')) - # TODO: something more complex than .unique()? 
- # TODO: prioritise unambiguous -> ref -> alt -> ref_flip -> alt_flip dups: pl.DataFrame = (count.filter(pl.col('count') > 1)[:, "accession":"effect_allele"] - .join(df, on=['accession', 'chr_name', 'chr_position', 'effect_allele'], how='left') - .distinct(subset=['accession', 'chr_name', 'chr_position', 'effect_allele'])) - distinct: pl.DataFrame = pl.concat([singletons, dups]) + .join(df, on=['accession', 'chr_name', 'chr_position', 'effect_allele'], how='left')) + + if dups: + distinct: pl.DataFrame = pl.concat([singletons, _prioritise_match_type(dups)]) + else: + distinct: pl.DataFrame = singletons - assert all((distinct.groupby(['accession', 'chr_name', 'chr_position', 'effect_allele']).count()['count']) == 1), \ - "Duplicate effect weights for a variant" + assert all(distinct.groupby(['accession', 'ID']).count()['count'] == 1), "Duplicate effect weights for a variant" return distinct + + +def _prioritise_match_type(duplicates: pl.DataFrame) -> pl.DataFrame: + dup_oa: pl.DataFrame = duplicates.filter(pl.col("other_allele") != None) + dup_no_oa: pl.DataFrame = duplicates.filter(pl.col("other_allele") == None) + best_matches: list[pl.DataFrame] = [] + + if dup_oa: + match_priority: list[str] = ['refalt', 'altref', 'refalt_flip', 'altref_flip'] + logger.debug(f"Prioritising matches in order {match_priority}") + best_matches.append(_get_best_match(dup_oa, match_priority)) + + if dup_no_oa: + match_priority: list[str] = ['no_oa_ref', 'no_oa_alt', 'no_oa_ref_flip', 'no_oa_alt_flip'] + logger.debug(f"Prioritising matches in order {match_priority}") + best_matches.append(_get_best_match(dup_no_oa, match_priority)) + + return pl.concat(best_matches) + + +def _get_best_match(df: pl.DataFrame, match_priority: list[str]) -> pl.DataFrame: + match: list[pl.DataFrame] = [] + for match_type in match_priority: + match.append(df.filter(pl.col("match_type") == match_type)) + logger.debug("Filtering best match types") + return reduce(lambda x, y: _join_best_match(x, y), match) + + +def _join_best_match(x: pl.DataFrame, y: pl.DataFrame) -> pl.DataFrame: + # variants in dataframe x have a higher priority than dataframe y + # when concatenating the two dataframes, use an anti join to first remove variants in y that are in x + not_in: pl.DataFrame = y.join(x, how='anti', + on=['accession', 'chr_name', 'chr_position', 'effect_allele', 'other_allele']) + return pl.concat([x, not_in]) diff --git a/pgscatalog_utils/match/preprocess.py b/pgscatalog_utils/match/preprocess.py index be64d7e..3cc66f7 100644 --- a/pgscatalog_utils/match/preprocess.py +++ b/pgscatalog_utils/match/preprocess.py @@ -4,40 +4,44 @@ logger = logging.getLogger(__name__) -def ugly_complement(df: pl.DataFrame) -> pl.DataFrame: - """ Complementing alleles with a pile of regexes seems weird, but polars string functions are currently limited - (i.e. no str.translate). This is fast, and I stole the regex idea from Scott. +def complement_valid_alleles(df: pl.DataFrame, flip_cols: list[str]) -> pl.DataFrame: + """ Improved function to complement alleles. Will only complement sequences that are valid DNA. 
""" - logger.debug("Complementing target alleles") - return df.with_columns([ - (pl.col("REF").str.replace_all("A", "V") - .str.replace_all("T", "X") - .str.replace_all("C", "Y") - .str.replace_all("G", "Z") - .str.replace_all("V", "T") - .str.replace_all("X", "A") - .str.replace_all("Y", "G") - .str.replace_all("Z", "C")) - .alias("REF_FLIP"), - (pl.col("ALT").str.replace_all("A", "V") - .str.replace_all("T", "X") - .str.replace_all("C", "Y") - .str.replace_all("G", "Z") - .str.replace_all("V", "T") - .str.replace_all("X", "A") - .str.replace_all("Y", "G") - .str.replace_all("Z", "C")) - .alias("ALT_FLIP") - ]) - - -def handle_multiallelic(df: pl.DataFrame, remove_multiallelic: bool) -> pl.DataFrame: - is_ma: pl.Series = df['ALT'].str.contains(',') # plink2 pvar multi-alleles are comma-separated - if is_ma.sum() > 0: + for col in flip_cols: + logger.debug(f"Complementing column {col}") + new_col = col + '_FLIP' + df = df.with_column( + pl.when(pl.col(col).str.contains('^[ACGT]+$')) + .then(pl.col(col).str.replace_all("A", "V") + .str.replace_all("T", "X") + .str.replace_all("C", "Y") + .str.replace_all("G", "Z") + .str.replace_all("V", "T") + .str.replace_all("X", "A") + .str.replace_all("Y", "G") + .str.replace_all("Z", "C")) + .otherwise(pl.col(col)) + .alias(new_col) + ) + return df + + +def handle_multiallelic(df: pl.DataFrame, remove_multiallelic: bool, pvar: bool) -> pl.DataFrame: + # plink2 pvar multi-alleles are comma-separated + df: pl.DataFrame = (df.with_column( + pl.when(pl.col("ALT").str.contains(',')) + .then(pl.lit(True)) + .otherwise(pl.lit(False)) + .alias('is_multiallelic'))) + + if df['is_multiallelic'].sum() > 0: logger.debug("Multiallelic variants detected") if remove_multiallelic: + if not pvar: + logger.warning("--remove_multiallelic requested for bim format, which already contains biallelic " + "variant representations only") logger.debug('Dropping multiallelic variants') - return df[~is_ma] + return df[~df['is_multiallelic']] else: logger.debug("Exploding dataframe to handle multiallelic variants") df.replace('ALT', df['ALT'].str.split(by=',')) # turn ALT to list of variants @@ -48,8 +52,12 @@ def handle_multiallelic(df: pl.DataFrame, remove_multiallelic: bool) -> pl.DataF def check_weights(df: pl.DataFrame) -> None: - weight_count = df.groupby(['accession', 'chr_name', 'chr_position', 'effect_allele']).count()['count'] - - if any(weight_count > 1): - logger.error("Multiple effect weights per variant per accession detected") + """ Checks weights for scoring file variants that could be matched (e.g. 
have a chr & pos) """ + weight_count = df.filter(pl.col('chr_name').is_not_null() & pl.col('chr_position').is_not_null()).groupby(['accession', 'chr_name', 'chr_position', 'effect_allele']).count() + if any(weight_count['count'] > 1): + logger.error("Multiple effect weights per variant per accession detected in files: {}".format(list(weight_count.filter(pl.col('count') > 1)['accession'].unique()))) raise Exception + + +def _annotate_multiallelic(df: pl.DataFrame) -> pl.DataFrame: + df.with_column(pl.when(pl.col("ALT").str.contains(',')).then(pl.lit(True)).otherwise(pl.lit(False)).alias('is_multiallelic')) \ No newline at end of file diff --git a/pgscatalog_utils/match/read.py b/pgscatalog_utils/match/read.py index 138855d..edb69b5 100644 --- a/pgscatalog_utils/match/read.py +++ b/pgscatalog_utils/match/read.py @@ -1,26 +1,43 @@ -import polars as pl -import logging import glob +import logging from typing import NamedTuple -from pgscatalog_utils.match.preprocess import ugly_complement, handle_multiallelic, check_weights + +import polars as pl + +from pgscatalog_utils.match.preprocess import handle_multiallelic, check_weights, complement_valid_alleles logger = logging.getLogger(__name__) -def read_target(path: str, n_threads: int, remove_multiallelic: bool) -> pl.DataFrame: +def read_target(path: str, remove_multiallelic: bool, single_file: bool = False, + chrom: str = "") -> pl.DataFrame: target: Target = _detect_target_format(path) d = {'column_1': str} # column_1 is always CHROM. CHROM must always be a string - df: pl.DataFrame = pl.read_csv(path, sep='\t', has_header=False, comment_char='#', dtype=d, n_threads=n_threads) + + if single_file: + logger.debug(f"Scanning target genome for chromosome {chrom}") + # scan target and filter to reduce memory usage on big files + df: pl.DataFrame = ( + pl.scan_csv(path, sep='\t', has_header=False, comment_char='#', dtype=d) + .filter(pl.col('column_1') == chrom) + .collect()) + + if df.is_empty(): + logger.warning(f"Chromosome missing from target genome: {chrom}") + return df + else: + logger.debug(f"Reading target {path}") + df: pl.DataFrame = pl.read_csv(path, sep='\t', has_header=False, comment_char='#', dtype=d) + df.columns = target.header match target.file_format: case 'bim': return (df[_default_cols()] - .pipe(ugly_complement)) + .pipe(handle_multiallelic, remove_multiallelic=remove_multiallelic, pvar=False)) case 'pvar': return (df[_default_cols()] - .pipe(handle_multiallelic, remove_multiallelic=remove_multiallelic) - .pipe(ugly_complement)) + .pipe(handle_multiallelic, remove_multiallelic=remove_multiallelic, pvar=True)) case _: logger.error("Invalid file format detected") raise Exception @@ -28,7 +45,8 @@ def read_target(path: str, n_threads: int, remove_multiallelic: bool) -> pl.Data def read_scorefile(path: str) -> pl.DataFrame: logger.debug("Reading scorefile") - scorefile: pl.DataFrame = pl.read_csv(path, sep='\t', dtype={'chr_name': str}) + scorefile: pl.DataFrame = (pl.read_csv(path, sep='\t', dtype={'chr_name': str}) + .pipe(complement_valid_alleles, flip_cols=['effect_allele', 'other_allele'])) check_weights(scorefile) return scorefile @@ -44,7 +62,7 @@ def _detect_target_format(path: str) -> Target: header: list[str] if "*" in path: - logger.debug("Wildcard detected in target path, guessing format from first match") + logger.debug("Detecting target file format") path = glob.glob(path)[0] # guess format from first file in directory with open(path, 'rt') as f: @@ -79,5 +97,3 @@ def _pvar_header(path: str) -> list[str]: def 
_bim_header() -> list[str]: return ['#CHROM', 'ID', 'CM', 'POS', 'REF', 'ALT'] - - diff --git a/pgscatalog_utils/match/write.py b/pgscatalog_utils/match/write.py index 45b74dc..110e308 100644 --- a/pgscatalog_utils/match/write.py +++ b/pgscatalog_utils/match/write.py @@ -6,12 +6,17 @@ def write_out(df: pl.DataFrame, split: bool, outdir: str, dataset: str) -> None: + if not os.path.isdir(outdir): + os.mkdir(outdir) + logger.debug("Splitting by effect type") effect_types: dict[str, pl.DataFrame] = _split_effect_type(df) + logger.debug("Deduplicating variants") - deduplicated: dict[str, pl.DataFrame] = {k: _deduplicate_variants(v) for k, v in effect_types.items()} - ea_dict: dict[str, str] = {'is_dominant': 'dominant', 'is_recessive': 'recessive', 'additive': 'additive'} + deduplicated: dict[str, pl.DataFrame] = {k: _deduplicate_variants(k, v) for k, v in effect_types.items()} + logger.debug("Writing out scorefiles") + ea_dict: dict[str, str] = {'is_dominant': 'dominant', 'is_recessive': 'recessive', 'additive': 'additive'} [_write_scorefile(ea_dict.get(k), v, split, outdir, dataset) for k, v in deduplicated.items()] @@ -21,8 +26,6 @@ def write_log(df: pl.DataFrame, dataset: str) -> None: def _write_scorefile(effect_type: str, scorefiles: pl.DataFrame, split: bool, outdir: str, dataset: str) -> None: """ Write a list of scorefiles with the same effect type """ - fout: str = '{dataset}_{chr}_{et}_{split}.scorefile' - # each list element contains a dataframe of variants # lists are split to ensure variants have unique ID - effect alleles for i, scorefile in enumerate(scorefiles): @@ -47,13 +50,15 @@ def _format_scorefile(df: pl.DataFrame, split: bool) -> dict[str, pl.DataFrame]: logger.debug("Split output requested") chroms: list[int] = df["chr_name"].unique().to_list() return {x: (df.filter(pl.col("chr_name") == x) - .pivot(index=["ID", "effect_allele"], values="effect_weight", columns="accession") - .pipe(_fill_null)) + .pivot(index=["ID", "matched_effect_allele"], values="effect_weight", columns="accession") + .rename({"matched_effect_allele": "effect_allele"}) + .fill_null(strategy="zero")) for x in chroms} else: logger.debug("Split output not requested") - formatted: pl.DataFrame = (df.pivot(index=["ID", "effect_allele"], values="effect_weight", columns="accession") - .pipe(_fill_null)) + formatted: pl.DataFrame = (df.pivot(index=["ID", "matched_effect_allele"], values="effect_weight", columns="accession") + .rename({"matched_effect_allele": "effect_allele"}) + .fill_null(strategy="zero")) return {'false': formatted} @@ -63,7 +68,7 @@ def _split_effect_type(df: pl.DataFrame) -> dict[str, pl.DataFrame]: return {x: df.filter(pl.col("effect_type") == x) for x in effect_types} -def _deduplicate_variants(df: pl.DataFrame) -> list[pl.DataFrame]: +def _deduplicate_variants(effect_type: str, df: pl.DataFrame) -> list[pl.DataFrame]: """ Find variant matches that have duplicate identifiers When merging a lot of scoring files, sometimes a variant might be duplicated this can happen when the effect allele differs at the same position, e.g.: @@ -101,20 +106,10 @@ def _deduplicate_variants(df: pl.DataFrame) -> list[pl.DataFrame]: df_lst.append(x) if len(df_lst) > 1: - logger.debug("Duplicate variant identifiers split") + logger.debug(f"Duplicate variant identifiers split for effect type {effect_type}") else: - logger.debug("No duplicate variant identifiers found") + logger.debug(f"No duplicate variant identifiers found for effect type {effect_type}") assert n_var == df.shape[0] return df_lst - - -def 
_fill_null(df): - # nulls are created when pivoting wider - if any(df.null_count() > 0): - logger.debug("Filling null weights with zero after pivoting wide") - return df.fill_null(0) - else: - logger.debug("No null weights detected") - return df diff --git a/pgscatalog_utils/scorefile/combine_scorefiles.py b/pgscatalog_utils/scorefile/combine_scorefiles.py index 5ec21d4..35d9b85 100644 --- a/pgscatalog_utils/scorefile/combine_scorefiles.py +++ b/pgscatalog_utils/scorefile/combine_scorefiles.py @@ -1,6 +1,8 @@ import argparse -import sys import logging +import sys +import textwrap + import pandas as pd from pgscatalog_utils.log_config import set_logging_level @@ -11,36 +13,15 @@ from pgscatalog_utils.scorefile.write import write_scorefile -def parse_args(args=None) -> argparse.Namespace: - parser: argparse.ArgumentParser = argparse.ArgumentParser(description='Combine multiple scoring files') - parser.add_argument('-s', '--scorefiles', dest='scorefiles', nargs='+', - help=' Scorefile path (wildcard * is OK)', required=True) - parser.add_argument('--liftover', dest='liftover', - help=' Convert scoring file variants to target genome build?', action='store_true') - parser.add_argument('-t', '--target_build', dest='target_build', help='Build of target genome ', - required='--liftover' in sys.argv) - parser.add_argument('-c', '--chain_dir', dest='chain_dir', help='Path to directory containing chain files', - required="--liftover" in sys.argv) - parser.add_argument('-m', '--min_lift', dest='min_lift', - help='If liftover, minimum proportion of variants lifted over', - required="--liftover" in sys.argv, default=0.95, type=float) - parser.add_argument('-o', '--outfile', dest='outfile', required=True, - default='combined.txt', - help=' Output path to combined long scorefile') - parser.add_argument('-v', '--verbose', dest='verbose', action='store_true', - help=' Extra logging information') - return parser.parse_args(args) - - def combine_scorefiles(): - args = parse_args() + args = _parse_args() logger = logging.getLogger(__name__) set_logging_level(args.verbose) paths: list[str] = list(set(args.scorefiles)) # unique paths only logger.debug(f"Input scorefiles: {paths}") - scorefiles: pd.DataFrame = pd.concat([_read_and_melt(x) for x in paths]) + scorefiles: pd.DataFrame = pd.concat([_read_and_melt(x, drop_missing=args.drop_missing) for x in paths]) if args.liftover: logger.debug("Annotating scorefiles with liftover parameters") @@ -49,12 +30,55 @@ def combine_scorefiles(): write_scorefile(scorefiles, args.outfile) -def _read_and_melt(path): +def _read_and_melt(path, drop_missing: bool = False): """ Load a scorefile, melt it, and set the effect types""" - return (load_scorefile(path) + return (load_scorefile(path, drop_missing=drop_missing) .pipe(melt_effect_weights) .pipe(set_effect_type)) if __name__ == "__main__": combine_scorefiles() + + +def _description_text() -> str: + return textwrap.dedent('''\ + Combine multiple scoring files in PGS Catalog format (see + https://www.pgscatalog.org/downloads/ for details) to a 'long' + table, and optionally liftover genomic coordinates to GRCh37 or + GRCh38. Custom scorefiles in PGS Catalog format can be combined + with PGS Catalog scoring files. The program can accept a mix of + unharmonised and harmonised PGS Catalog data. + ''') + + +def _epilog_text() -> str: + return textwrap.dedent('''\ + The long table is used to simplify intersecting variants in target + genomes and the scoring files with the match_variants program. 
+ ''') + + +def _parse_args(args=None) -> argparse.Namespace: + parser = argparse.ArgumentParser(description=_description_text(), epilog=_epilog_text(), + formatter_class=argparse.RawDescriptionHelpFormatter) + parser.add_argument('-s', '--scorefiles', dest='scorefiles', nargs='+', + help=' Scorefile path (wildcard * is OK)', required=True) + parser.add_argument('--liftover', dest='liftover', + help=' Convert scoring file variants to target genome build?', action='store_true') + parser.add_argument('-t', '--target_build', dest='target_build', help='Build of target genome ', + required='--liftover' in sys.argv) + parser.add_argument('-c', '--chain_dir', dest='chain_dir', help='Path to directory containing chain files', + required="--liftover" in sys.argv) + parser.add_argument('-m', '--min_lift', dest='min_lift', + help='If liftover, minimum proportion of variants lifted over', + required="--liftover" in sys.argv, default=0.95, type=float) + parser.add_argument('--drop_missing', dest='drop_missing', action='store_true', + help='Drop variants with missing information (chr/pos) and ' + 'non-standard alleles from the output file.') + parser.add_argument('-o', '--outfile', dest='outfile', required=True, + default='combined.txt', + help=' Output path to combined long scorefile') + parser.add_argument('-v', '--verbose', dest='verbose', action='store_true', + help=' Extra logging information') + return parser.parse_args(args) diff --git a/pgscatalog_utils/scorefile/liftover.py b/pgscatalog_utils/scorefile/liftover.py index 6390afb..0d3008c 100644 --- a/pgscatalog_utils/scorefile/liftover.py +++ b/pgscatalog_utils/scorefile/liftover.py @@ -62,12 +62,14 @@ def _check_min_liftover(mapped: pd.DataFrame, unmapped: pd.DataFrame, min_lift: def _convert_coordinates(df: pd.Series, lo_dict: dict[str, pyliftover.LiftOver]) -> pd.Series: """ Convert genomic coordinates to different build """ - - lo = lo_dict[df['genome_build'] + df['target_build']] # extract lo object from dict - chrom: str = 'chr' + str(df['chr_name']) - pos: int = int(df['chr_position']) - 1 # liftOver is 0 indexed, VCF is 1 indexed - # converted example: [('chr22', 15460378, '+', 3320966530)] or None - converted: list[tuple[str, int, str, int] | None] = lo.convert_coordinate(chrom, pos) + if df[['chr_name', 'chr_position']].isnull().values.any(): + converted = None + else: + lo = lo_dict[df['genome_build'] + df['target_build']] # extract lo object from dict + chrom: str = 'chr' + str(df['chr_name']) + pos: int = int(df['chr_position']) - 1 # liftOver is 0 indexed, VCF is 1 indexed + # converted example: [('chr22', 15460378, '+', 3320966530)] or None + converted: list[tuple[str, int, str, int] | None] = lo.convert_coordinate(chrom, pos) if converted: lifted_chrom: str = _parse_lifted_chrom(converted[0][0][3:]) # return first matching liftover diff --git a/pgscatalog_utils/scorefile/qc.py b/pgscatalog_utils/scorefile/qc.py index fe2b725..4316f1e 100644 --- a/pgscatalog_utils/scorefile/qc.py +++ b/pgscatalog_utils/scorefile/qc.py @@ -4,15 +4,19 @@ logger = logging.getLogger(__name__) -def quality_control(df: pd.DataFrame) -> pd.DataFrame: +def quality_control(df: pd.DataFrame, drop_missing: bool) -> pd.DataFrame: """ Do quality control checks on a scorefile """ _check_shape(df) _check_columns(df) logger.debug("Quality control: checking for bad variants") - return (df.pipe(_drop_hla) - .pipe(_drop_missing_variants) - .pipe(_check_duplicate_identifiers) - .pipe(_drop_multiple_oa)) + if drop_missing is True: + return (df.pipe(_drop_hla) + 
.pipe(_drop_missing_variants) + .pipe(_check_duplicate_identifiers) + .pipe(_drop_multiple_oa)) + else: + return (df.pipe(_check_duplicate_identifiers) + .pipe(_drop_multiple_oa)) def _drop_multiple_oa(df: pd.DataFrame) -> pd.DataFrame: diff --git a/pgscatalog_utils/scorefile/read.py b/pgscatalog_utils/scorefile/read.py index 43fe176..7674c7c 100644 --- a/pgscatalog_utils/scorefile/read.py +++ b/pgscatalog_utils/scorefile/read.py @@ -7,13 +7,13 @@ logger = logging.getLogger(__name__) -def load_scorefile(path: str, use_harmonised: bool = True) -> pd.DataFrame: +def load_scorefile(path: str, use_harmonised: bool = True, drop_missing: bool = False) -> pd.DataFrame: logger.debug(f'Reading scorefile {path}') return (pd.read_table(path, dtype=_scorefile_dtypes(), comment='#', na_values=['None'], low_memory=False) .pipe(remap_harmonised, use_harmonised=use_harmonised) .assign(filename_prefix=_get_basename(path), filename=path) - .pipe(quality_control)) + .pipe(quality_control, drop_missing=drop_missing)) def _scorefile_dtypes() -> dict[str]: diff --git a/pgscatalog_utils/scorefile/write.py b/pgscatalog_utils/scorefile/write.py index a2f04a8..9204096 100644 --- a/pgscatalog_utils/scorefile/write.py +++ b/pgscatalog_utils/scorefile/write.py @@ -22,7 +22,6 @@ def write_scorefile(df: pd.DataFrame, path: str) -> None: logger.warning("No other allele information detected, writing out as missing data") out_df['other_allele'] = None - _write_log(out_df) out_df[cols].to_csv(path, index=False, sep="\t") @@ -33,20 +32,3 @@ def _filter_failed_liftover(df: pd.DataFrame) -> pd.DataFrame: else: return df - -def _write_log(df: pd.DataFrame) -> None: - logger.debug("Writing log to local database") - conn: sqlite3.Connection = sqlite3.connect('scorefiles.db') - - if 'liftover' not in df: - df = df.assign(liftover=None, lifted_chr=None, lifted_pos=None) - - cols: list[str] = ['chr_name', 'chr_position', 'effect_allele', 'other_allele', 'effect_weight', 'effect_type', - 'accession', 'liftover', 'lifted_chr', 'lifted_pos'] - - # change some column types for sqlite - # nullable_ints: list[str] = ['liftover', 'lifted_chr', 'lifted_pos'] - # df[nullable_ints] = df[nullable_ints].astype(pd.Int64Dtype()) - df['other_allele'] = df['other_allele'].astype(str) - df[cols].to_sql('scorefile', conn, if_exists='replace') - conn.close() diff --git a/poetry.lock b/poetry.lock index 28f0366..e920a73 100644 --- a/poetry.lock +++ b/poetry.lock @@ -49,7 +49,7 @@ python-versions = ">=2.7, !=3.0.*, !=3.1.*, !=3.2.*, !=3.3.*, !=3.4.*" [[package]] name = "coverage" -version = "6.4.2" +version = "6.4.3" description = "Code coverage measurement for Python" category = "dev" optional = false @@ -134,19 +134,19 @@ testing = ["pytest", "pytest-benchmark"] [[package]] name = "polars" -version = "0.13.59" +version = "0.13.62" description = "Blazingly fast DataFrame library" category = "main" optional = false python-versions = ">=3.7" [package.extras] -connectorx = ["connectorx"] -numpy = ["numpy (>=1.16.0)"] fsspec = ["fsspec"] -pandas = ["pyarrow (>=4.0)", "pandas"] xlsx2csv = ["xlsx2csv (>=0.8.0)"] +connectorx = ["connectorx"] +pandas = ["pyarrow (>=4.0)", "pandas"] pyarrow = ["pyarrow (>=4.0)"] +numpy = ["numpy (>=1.16.0)"] [[package]] name = "py" diff --git a/pyproject.toml b/pyproject.toml index b3968f7..44ef233 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,9 +1,9 @@ [tool.poetry] name = "pgscatalog_utils" version = "0.1.1" -description = "Some useful utilities for working with PGS Catalog data" +description = "Utilities for working 
with PGS Catalog API and scoring files" homepage = "https://github.com/PGScatalog/pgscatalog_utils" -authors = ["Benjamin Wingfield "] +authors = ["Benjamin Wingfield ", "Samuel Lambert "] license = "Apache-2.0" readme = "README.md" diff --git a/tests/test_combine.py b/tests/test_combine.py index 86fb824..6243cef 100644 --- a/tests/test_combine.py +++ b/tests/test_combine.py @@ -1,14 +1,24 @@ import pandas as pd +import pytest +import jq +from pgscatalog_utils.download.score import query_score -def test_combine_scorefiles(combined_scorefile): + +def test_combine_scorefiles(combined_scorefile, _n_variants): df = pd.read_table(combined_scorefile) cols = {'chr_name', 'chr_position', 'effect_allele', 'other_allele', 'effect_weight', 'effect_type', 'accession'} assert set(df.columns).issubset(cols) - assert df.shape[0] == 51215 # combined number of variants + assert df.shape[0] == _n_variants def test_liftover(lifted_scorefiles): df = pd.read_table(lifted_scorefiles) assert df.shape[0] > 50000 # approx size + +@pytest.fixture +def _n_variants(pgs_accessions): + json = query_score(pgs_accessions) + n: list[int] = jq.compile("[.results][][].variants_number").input(json).all() + return sum(n) diff --git a/tests/test_download.py b/tests/test_download.py index 78f1c83..611740e 100644 --- a/tests/test_download.py +++ b/tests/test_download.py @@ -1,13 +1,16 @@ import os import pytest from unittest.mock import patch -from pgscatalog_utils.download.api import pgscatalog_result + +from pgscatalog_utils.download.trait import query_trait +from pgscatalog_utils.download.publication import query_publication +from pgscatalog_utils.download.score import get_url from pgscatalog_utils.download.download_scorefile import download_scorefile @pytest.fixture(params=[["PGS000001"], ["PGS000001", "PGS000802"]]) def pgscatalog_api(request): - return pgscatalog_result(request.param, "GRCh37") + return get_url(request.param, "GRCh37") def test_pgscatalog_result(pgscatalog_api): @@ -21,11 +24,28 @@ def test_pgscatalog_result(pgscatalog_api): assert v.endswith(".txt.gz") -def test_download_scorefile(tmp_path): +def test_download_scorefile_author(tmp_path): out_dir = str(tmp_path.resolve()) - args: list[str] = ['download_scorefiles', '-i', 'PGS000001', '-b', 'GRCh38', '-o', out_dir] + args: list[str] = ['download_scorefiles', '-i', 'PGS000001', '-o', out_dir] with patch('sys.argv', args): download_scorefile() assert os.listdir(out_dir) == ['PGS000001.txt.gz'] +def test_download_scorefile_hmPOS(tmp_path): + out_dir = str(tmp_path.resolve()) + args: list[str] = ['download_scorefiles', '-i', 'PGS000001', '-b', 'GRCh38', '-o', out_dir] + + with patch('sys.argv', args): + download_scorefile() + assert os.listdir(out_dir) == ['PGS000001_hmPOS_GRCh38.txt.gz'] + + +def test_query_publication(): + # publications are relatively static + assert not set(query_publication("PGP000001")).difference(['PGS000001', 'PGS000002', 'PGS000003']) + + +def test_query_trait(): + # new scores may be added to traits in the future + assert {'PGS001901', 'PGS002115'}.issubset(set(query_trait("EFO_0004329"))) diff --git a/tests/test_match.py b/tests/test_match.py index 596461a..6f3394d 100644 --- a/tests/test_match.py +++ b/tests/test_match.py @@ -1,9 +1,11 @@ import os from unittest.mock import patch -import pandas as pd +import polars as pl import pytest +from pgscatalog_utils.match.match import get_all_matches, _cast_categorical from pgscatalog_utils.match.match_variants import match_variants +from pgscatalog_utils.match.preprocess import 
complement_valid_alleles def test_match_fail(combined_scorefile, target_path, tmp_path): @@ -11,7 +13,7 @@ def test_match_fail(combined_scorefile, target_path, tmp_path): args: list[str] = ['match_variants', '-s', combined_scorefile, '-t', target_path, - '-m', 0, + '-m', 1, '-d', 'test', '--outdir', out_dir, '--keep_ambiguous', '--keep_multiallelic'] @@ -34,3 +36,91 @@ def test_match_pass(mini_scorefile, target_path, tmp_path): with patch('sys.argv', args): match_variants() + +def _cast_cat(scorefile, target): + with pl.StringCache(): + return _cast_categorical(scorefile, target) + + +def test_match_strategies(small_scorefile, small_target): + scorefile, target = _cast_cat(small_scorefile, small_target) + + # check unambiguous matches + df = get_all_matches(scorefile, target, remove_ambiguous=True, skip_flip=True) + assert set(df['ID'].to_list()).issubset({'3:3:T:G', '1:1:A:C'}) + assert set(df['match_type'].to_list()).issubset(['altref', 'refalt']) + + # when keeping ambiguous and flipping alleles: + # 2:2:T:A is ambiguous, and matches 'altref' and 'refalt_flip' + # flipped matches should be dropped for ambiguous matches + flip = (get_all_matches(scorefile, target, remove_ambiguous=False, skip_flip=False)\ + .filter(pl.col('ambiguous') == True)) + assert set(flip['ID'].to_list()).issubset({'2:2:T:A'}) + assert set(flip['match_type'].to_list()).issubset({'altref'}) + + +def test_no_oa_match(small_scorefile_no_oa, small_target): + scorefile, target = _cast_cat(small_scorefile_no_oa, small_target) + + df = get_all_matches(scorefile, target, remove_ambiguous=True,skip_flip=True) + assert set(df['ID'].to_list()).issubset(['3:3:T:G', '1:1:A:C']) + assert set(df['match_type'].to_list()).issubset(['no_oa_alt', 'no_oa_ref']) + + # one of the matches is ambiguous + flip = (get_all_matches(scorefile, target, remove_ambiguous=False, skip_flip=False) + .filter(pl.col('ambiguous') == True)) + assert set(flip['ID'].to_list()).issubset({'2:2:T:A'}) + assert set(flip['match_type'].to_list()).issubset({'no_oa_alt'}) + + +def test_flip_match(small_flipped_scorefile, small_target): + scorefile, target = _cast_cat(small_flipped_scorefile, small_target) + + df = get_all_matches(scorefile, target, remove_ambiguous=True, skip_flip=True) + assert df.is_empty() + + flip = get_all_matches(scorefile, target, remove_ambiguous=True, skip_flip=False) + assert flip['match_type'].str.contains('flip').all() + assert set(flip['ID'].to_list()).issubset(['3:3:T:G', '1:1:A:C']) + + flip_ambig = (get_all_matches(scorefile, target, remove_ambiguous=False, skip_flip=False) + .filter(pl.col('ambiguous') == True)) + assert not flip_ambig['match_type'].str.contains('flip').any() # no flip matches for ambiguous + + +@pytest.fixture +def small_scorefile(): + df = pl.DataFrame({"accession": ["test", "test", "test"], + "chr_name": [1, 2, 3], + "chr_position": [1, 2, 3], + "effect_allele": ["A", "A", "G"], + "other_allele": ["C", "T", "T"], + "effect_weight": [1, 2, 3], + "effect_type": ["additive", "additive", "additive"]}) + + return complement_valid_alleles(df, ["effect_allele", "other_allele"]) + + +@pytest.fixture +def small_scorefile_no_oa(small_scorefile): + return small_scorefile.with_column(pl.lit(None).alias('other_allele')) + + +@pytest.fixture +def small_flipped_scorefile(small_scorefile): + # simulate a scorefile on the wrong strand + return (complement_valid_alleles(small_scorefile, ['effect_allele', 'other_allele']) + .drop(['effect_allele', 'other_allele']) + .rename({'effect_allele_FLIP': 'effect_allele', 
'other_allele_FLIP': 'other_allele'}) + .pipe(complement_valid_alleles, ['effect_allele', 'other_allele'])) + + +@pytest.fixture +def small_target(): + return pl.DataFrame({"#CHROM": [1, 2, 3], + "POS": [1, 2, 3], + "REF": ["A", "T", "T"], + "ALT": ["C", "A", "G"], + "ID": ["1:1:A:C", "2:2:T:A", "3:3:T:G"], + "is_multiallelic": [False, False, False]}) +
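
The refactored download entry points replace `pgscatalog_result` with accession-specific helpers, and the updated `tests/test_download.py` cases double as usage documentation. Below is a minimal sketch (not part of the diff above); the import paths and call signatures are taken from the tests, while the assumption that `get_url` returns a mapping of PGS IDs to scoring-file URLs is mine (the tests only check that each value ends in `.txt.gz`):

```
from pgscatalog_utils.download.score import get_url
from pgscatalog_utils.download.publication import query_publication
from pgscatalog_utils.download.trait import query_trait

# Resolve scoring file URLs for one or more PGS IDs and a genome build
# (return type assumed to be {PGS ID: URL}; the tests only inspect the values)
urls = get_url(["PGS000001", "PGS000802"], "GRCh37")
assert all(url.endswith(".txt.gz") for url in urls.values())

# Expand a publication or trait accession into its associated PGS IDs
print(query_publication("PGP000001"))  # e.g. ['PGS000001', 'PGS000002', 'PGS000003']
print(query_trait("EFO_0004329"))      # includes 'PGS001901' and 'PGS002115'
```

These helpers sit behind the `download_scorefiles` CLI exercised by `test_download_scorefile_author` and `test_download_scorefile_hmPOS`: without `-b` the author-reported file (`PGS000001.txt.gz`) is fetched, and with `-b GRCh38` the harmonised file (`PGS000001_hmPOS_GRCh38.txt.gz`) is fetched instead.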
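
The new `tests/test_match.py` cases are also a compact reference for the matching API. The sketch below (again, not part of the diff) assumes only the module paths and signatures shown above (`get_all_matches`, `_cast_categorical`, `complement_valid_alleles`); the toy data mirrors the `small_scorefile` and `small_target` fixtures:

```
import polars as pl

from pgscatalog_utils.match.match import get_all_matches, _cast_categorical
from pgscatalog_utils.match.preprocess import complement_valid_alleles

# Toy scorefile and target variants, mirroring the small_scorefile / small_target fixtures
scorefile = pl.DataFrame({"accession": ["test", "test", "test"],
                          "chr_name": [1, 2, 3],
                          "chr_position": [1, 2, 3],
                          "effect_allele": ["A", "A", "G"],
                          "other_allele": ["C", "T", "T"],
                          "effect_weight": [1, 2, 3],
                          "effect_type": ["additive", "additive", "additive"]})
target = pl.DataFrame({"#CHROM": [1, 2, 3],
                       "POS": [1, 2, 3],
                       "REF": ["A", "T", "T"],
                       "ALT": ["C", "A", "G"],
                       "ID": ["1:1:A:C", "2:2:T:A", "3:3:T:G"],
                       "is_multiallelic": [False, False, False]})

# Add complemented allele columns so variants reported on the wrong strand can still match
scorefile = complement_valid_alleles(scorefile, ["effect_allele", "other_allele"])

# Categorical casts must share a string cache, as in the _cast_cat test helper
with pl.StringCache():
    scorefile, target = _cast_categorical(scorefile, target)

strict = get_all_matches(scorefile, target, remove_ambiguous=True, skip_flip=True)
lenient = get_all_matches(scorefile, target, remove_ambiguous=False, skip_flip=False)

print(strict["ID"].to_list())           # unambiguous matches only, e.g. ['1:1:A:C', '3:3:T:G']
print(lenient["match_type"].to_list())  # also keeps the ambiguous palindromic variant 2:2:T:A
```

The `remove_ambiguous` and `skip_flip` flags control whether strand-ambiguous (palindromic) variants such as `2:2:T:A` are kept and whether complemented alleles are considered, which is exactly what `test_match_strategies`, `test_no_oa_match`, and `test_flip_match` assert.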