diff --git a/Dockerfile b/Dockerfile index 8c19690..0d42228 100644 --- a/Dockerfile +++ b/Dockerfile @@ -11,8 +11,8 @@ FROM python:3.10 WORKDIR /opt/ -COPY --from=builder /app/dist/pgscatalog_utils-0.1.2-py3-none-any.whl . +COPY --from=builder /app/dist/pgscatalog_utils-0.2.0-py3-none-any.whl . -RUN pip install pgscatalog_utils-0.1.2-py3-none-any.whl +RUN pip install pgscatalog_utils-0.2.0-py3-none-any.whl RUN apt-get update && apt-get install -y sqlite3 \ No newline at end of file diff --git a/README.md b/README.md index d19c186..086bd75 100644 --- a/README.md +++ b/README.md @@ -2,9 +2,8 @@ [![CI](https://github.com/PGScatalog/pgscatalog_utils/actions/workflows/main.yml/badge.svg)](https://github.com/PGScatalog/pgscatalog_utils/actions/workflows/main.yml) -This repository is a collection of useful tools for working with data from the -PGS Catalog. This is mostly used internally by the PGS Catalog calculator, but -other users might find some of these tools helpful. +This repository is a collection of useful tools for downloading and working with scoring files from the +PGS Catalog. This is mostly used internally by the PGS Catalog Calculator ([`PGScatalog/pgsc_calc`](https://github.com/PGScatalog/pgsc_calc)); however, other users may find some of these tools helpful. ## Overview @@ -13,6 +12,7 @@ other users might find some of these tools helpful. in 'long' format * `match_variants`: Match target variants (bim or pvar files) against the output of `combine_scorefile` to produce scoring files for plink 2 +* `validate_scorefiles`: Check/validate that the scoring files and harmonized scoring files match the PGS Catalog scoring file formats. ## Installation @@ -26,6 +26,7 @@ $ pip install pgscatalog-utils $ download_scorefiles -i PGS000922 PGS001229 -o . -b GRCh37 $ combine_scorefiles -s PGS*.txt.gz -o combined.txt $ match_variants -s combined.txt -t --min_overlap 0.75 --outdir . +$ validate_scorefiles -t formatted --dir --log_dir ``` More details are available using the `--help` parameter. @@ -66,4 +67,4 @@ doi:[10.1038/s41588-021-00783-5](https://doi.org/10.1038/s41588-021-00783-5). This work has received funding from EMBL-EBI core funds, the Baker Institute, the University of Cambridge, Health Data Research UK (HDRUK), and the European Union's Horizon 2020 research and innovation programme -under grant agreement No 101016775 INTERVENE. \ No newline at end of file +under grant agreement No 101016775 INTERVENE. 
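The Dockerfile change above only swaps in the 0.2.0 wheel, so the containerised tools are invoked exactly like the pip-installed console scripts shown in the README. A minimal usage sketch, assuming you build and tag the image yourself (the tag `pgscatalog_utils:0.2.0` and the `/opt/out` mount point are arbitrary choices, not defined by this repository):

```
$ docker build -t pgscatalog_utils:0.2.0 .
$ docker run --rm -v "$PWD":/opt/out pgscatalog_utils:0.2.0 \
    download_scorefiles -i PGS000922 PGS001229 -b GRCh37 -o /opt/out
```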
diff --git a/pgscatalog_utils/__init__.py b/pgscatalog_utils/__init__.py index 10939f0..7fd229a 100644 --- a/pgscatalog_utils/__init__.py +++ b/pgscatalog_utils/__init__.py @@ -1 +1 @@ -__version__ = '0.1.2' +__version__ = '0.2.0' diff --git a/pgscatalog_utils/aggregate/__init__.py b/pgscatalog_utils/aggregate/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/pgscatalog_utils/aggregate/aggregate_scores.py b/pgscatalog_utils/aggregate/aggregate_scores.py new file mode 100644 index 0000000..653a81d --- /dev/null +++ b/pgscatalog_utils/aggregate/aggregate_scores.py @@ -0,0 +1,92 @@ +import argparse +import textwrap + +import pandas as pd + +from pgscatalog_utils.config import set_logging_level +import glob +import logging + +logger = logging.getLogger(__name__) + + +def aggregate_scores(): + args = _parse_args() + set_logging_level(args.verbose) + df = aggregate(list(set(args.scores))) + logger.debug("Compressing and writing combined scores") + df.to_csv('aggregated_scores.txt.gz', sep='\t', compression='gzip') + + +def aggregate(scorefiles: list[str]): + combined = pd.DataFrame() + aggcols = set() + + for i, path in enumerate(scorefiles): + logger.debug(f"Reading {path}") + # pandas can automatically detect zst compression, neat! + df = (pd.read_table(path) + .assign(sampleset=path.split('_')[0]) + .set_index(['sampleset', '#IID'])) + + df.index.names = ['sampleset', 'IID'] + + # Subset to aggregatable columns + df = df[_select_agg_cols(df.columns)] + aggcols.update(set(df.columns)) + + # Combine DFs + if i == 0: + logger.debug('Initialising combined DF') + combined = df.copy() + else: + logger.debug('Adding to combined DF') + combined = combined.add(df, fill_value=0) + + assert all([x in combined.columns for x in aggcols]), "All Aggregatable Columns are present in the final DF" + + return combined.pipe(_calculate_average) + + +def _calculate_average(combined: pd.DataFrame): + logger.debug("Averaging data") + avgs = combined.loc[:, combined.columns.str.endswith('_SUM')].divide(combined['DENOM'], axis=0) + avgs.columns = avgs.columns.str.replace('_SUM', '_AVG') + return pd.concat([combined, avgs], axis=1) + + +def _select_agg_cols(cols): + keep_cols = ['DENOM'] + return [x for x in cols if (x.endswith('_SUM') and (x != 'NAMED_ALLELE_DOSAGE_SUM')) or (x in keep_cols)] + + +def _description_text() -> str: + return textwrap.dedent(''' + Aggregate plink .sscore files into a combined TSV table. + + This aggregation sums scores that were calculated from plink + .scorefiles. Scorefiles may be split to calculate scores over different + chromosomes or effect types. The PGS Catalog calculator automatically splits + scorefiles where appropriate, and uses this script to combine them. + + Input .sscore files can be optionally compressed with zstd or gzip. + + The aggregated output scores are compressed with gzip. + ''') + + +def _parse_args(args=None) -> argparse.Namespace: + parser = argparse.ArgumentParser(description=_description_text(), + formatter_class=argparse.RawDescriptionHelpFormatter) + parser.add_argument('-s', '--scores', dest='scores', required=True, nargs='+', + help=' List of scorefile paths. 
Use a wildcard (*) to select multiple files.') + parser.add_argument('-o', '--outdir', dest='outdir', required=True, + default='scores/', help=' Output directory to store downloaded files') + parser.add_argument('-v', '--verbose', dest='verbose', action='store_true', + help=' Extra logging information') + return parser.parse_args(args) + + +if __name__ == "__main__": + aggregate_scores() + diff --git a/pgscatalog_utils/log_config.py b/pgscatalog_utils/config.py similarity index 86% rename from pgscatalog_utils/log_config.py rename to pgscatalog_utils/config.py index dcd9cbe..7a6b8eb 100644 --- a/pgscatalog_utils/log_config.py +++ b/pgscatalog_utils/config.py @@ -1,5 +1,7 @@ import logging +POLARS_MAX_THREADS = 1 # dummy value, is reset by args.n_threads (default: 1) + def set_logging_level(verbose: bool): log_fmt = "%(name)s: %(asctime)s %(levelname)-8s %(message)s" diff --git a/pgscatalog_utils/download/download_scorefile.py b/pgscatalog_utils/download/download_scorefile.py index fc35529..6abd365 100644 --- a/pgscatalog_utils/download/download_scorefile.py +++ b/pgscatalog_utils/download/download_scorefile.py @@ -3,14 +3,16 @@ import os import shutil import textwrap +import time from contextlib import closing from functools import reduce from urllib import request as request +from urllib.error import HTTPError, URLError from pgscatalog_utils.download.publication import query_publication from pgscatalog_utils.download.score import get_url from pgscatalog_utils.download.trait import query_trait -from pgscatalog_utils.log_config import set_logging_level +from pgscatalog_utils.config import set_logging_level logger = logging.getLogger(__name__) @@ -31,13 +33,17 @@ def download_scorefile() -> None: pgs_lst: list[list[str]] = [] + pgsc_calc_info = None + if args.pgsc_calc: + pgsc_calc_info = args.pgsc_calc + if args.efo: logger.debug("--trait set, querying traits") - pgs_lst = pgs_lst + [query_trait(x) for x in args.efo] + pgs_lst = pgs_lst + [query_trait(x, pgsc_calc_info) for x in args.efo] if args.pgp: logger.debug("--pgp set, querying publications") - pgs_lst = pgs_lst + [query_publication(x) for x in args.pgp] + pgs_lst = pgs_lst + [query_publication(x, pgsc_calc_info) for x in args.pgp] if args.pgs: logger.debug("--id set, querying scores") @@ -45,7 +51,7 @@ def download_scorefile() -> None: pgs_id: list[str] = list(set(reduce(lambda x, y: x + y, pgs_lst))) - urls: dict[str, str] = get_url(pgs_id, args.build) + urls: dict[str, str] = get_url(pgs_id, args.build, pgsc_calc_info) for pgsid, url in urls.items(): logger.debug(f"Downloading {pgsid} from {url}") @@ -62,14 +68,26 @@ def _mkdir(outdir: str) -> None: os.makedirs(outdir) -def _download_ftp(url: str, path: str) -> None: +def _download_ftp(url: str, path: str, retry:int = 0) -> None: if os.path.exists(path): logger.warning(f"File already exists at {path}, skipping download") return else: - with closing(request.urlopen(url)) as r: - with open(path, 'wb') as f: - shutil.copyfileobj(r, f) + try: + with closing(request.urlopen(url)) as r: + with open(path, 'wb') as f: + shutil.copyfileobj(r, f) + except (HTTPError, URLError) as error: + max_retries = 5 + print(f'Download failed: {error.reason}') + # Retry to download the file if the server is busy + if '421' in error.reason and retry < max_retries: + print(f'> Retry to download the file ... 
attempt {retry+1} out of {max_retries}.') + retry += 1 + time.sleep(10) + _download_ftp(url,path,retry) + else: + raise RuntimeError("Failed to download '{}'.\nError message: '{}'".format(url, error.reason)) def _check_args(args): @@ -121,6 +139,8 @@ def _parse_args(args=None) -> argparse.Namespace: parser.add_argument('-o', '--outdir', dest='outdir', required=True, default='scores/', help=' Output directory to store downloaded files') + parser.add_argument('-c', '--pgsc_calc', dest='pgsc_calc', + help=' Provide information about downloading scoring files via pgsc_calc') parser.add_argument('-v', '--verbose', dest='verbose', action='store_true', help=' Extra logging information') return parser.parse_args(args) diff --git a/pgscatalog_utils/download/publication.py b/pgscatalog_utils/download/publication.py index 843b8a2..675b263 100644 --- a/pgscatalog_utils/download/publication.py +++ b/pgscatalog_utils/download/publication.py @@ -1,20 +1,20 @@ import logging from functools import reduce -import requests +from pgscatalog_utils.download.score import query_api logger = logging.getLogger(__name__) -def query_publication(pgp: str) -> list[str]: - api: str = f'https://www.pgscatalog.org/rest/publication/{pgp}' +def query_publication(pgp: str, user_agent:str = None) -> list[str]: logger.debug("Querying PGS Catalog with publication PGP ID") - r: requests.models.Response = requests.get(api) + api: str = f'/publication/{pgp}' + results_json = query_api(api, user_agent) - if r.json() == {}: + if results_json == {} or results_json == None: logger.critical(f"Bad response from PGS Catalog for EFO term: {pgp}") raise Exception - pgs: dict[str, list[str]] = r.json().get('associated_pgs_ids') + pgs: dict[str, list[str]] = results_json.get('associated_pgs_ids') logger.debug(f"Valid response from PGS Catalog for PGP ID: {pgp}") return list(reduce(lambda x, y: set(x).union(set(y)), pgs.values())) diff --git a/pgscatalog_utils/download/score.py b/pgscatalog_utils/download/score.py index a38dc0c..3c2bf29 100644 --- a/pgscatalog_utils/download/score.py +++ b/pgscatalog_utils/download/score.py @@ -3,17 +3,19 @@ import jq import requests +import time +from pgscatalog_utils import __version__ as pgscatalog_utils_version logger = logging.getLogger(__name__) -def get_url(pgs: list[str], build: str) -> dict[str, str]: +def get_url(pgs: list[str], build: str, user_agent:str = None) -> dict[str, str]: pgs_result: list[str] = [] url_result: list[str] = [] for chunk in _chunker(pgs): try: - response = _parse_json_query(query_score(chunk), build) + response = _parse_json_query(query_score(chunk,user_agent), build) pgs_result = pgs_result + list(response.keys()) url_result = url_result + list(response.values()) except TypeError: @@ -28,11 +30,40 @@ def get_url(pgs: list[str], build: str) -> dict[str, str]: return dict(zip(pgs_result, url_result)) -def query_score(pgs_id: list[str]) -> dict: +def query_api(api: str, user_agent:str = None, retry:int = 0) -> dict: + max_retries = 5 + wait = 60 + results_json = None + rest_url_root = 'https://www.pgscatalog.org/rest' + # Set pgscatalog_utils user agent if none provided + if not user_agent: + user_agent = 'pgscatalog_utils/'+pgscatalog_utils_version + try: + headers = {'User-Agent': user_agent} + r: requests.models.Response = requests.get(rest_url_root+api, headers=headers) + r.raise_for_status() + results_json = r.json() + except requests.exceptions.HTTPError as e: + print(f'HTTP Error: {e}') + if r.status_code in [421,429] and retry < 5: + retry +=1 + print(f'> Retry to query 
the PGS Catalog REST API in {wait}s ... attempt {retry} out of {max_retries}.') + time.sleep(wait) + results_json = query_api(api, user_agent, retry) + except requests.exceptions.ConnectionError as e: + print(f'Error Connecting: {e}') + except requests.exceptions.Timeout as e: + print(f'Timeout Error: {e}') + except requests.exceptions.RequestException as e: + print(f'Request Error: {e}') + return results_json + + +def query_score(pgs_id: list[str], user_agent:str = None) -> dict: pgs: str = ','.join(pgs_id) - api: str = f'https://www.pgscatalog.org/rest/score/search?pgs_ids={pgs}' - r: requests.models.Response = requests.get(api) - return r.json() + api: str = f'/score/search?pgs_ids={pgs}' + results_json = query_api(api, user_agent) + return results_json def _chunker(pgs: list[str]): diff --git a/pgscatalog_utils/download/trait.py b/pgscatalog_utils/download/trait.py index c2db495..609e3e1 100644 --- a/pgscatalog_utils/download/trait.py +++ b/pgscatalog_utils/download/trait.py @@ -1,24 +1,24 @@ import logging from functools import reduce -import requests +from pgscatalog_utils.download.score import query_api logger = logging.getLogger(__name__) -def query_trait(trait: str) -> list[str]: - api: str = f'https://www.pgscatalog.org/rest/trait/{trait}?include_children=1' +def query_trait(trait: str, user_agent:str = None) -> list[str]: logger.debug(f"Querying PGS Catalog with trait {trait}") - r: requests.models.Response = requests.get(api) + api: str = f'/trait/{trait}?include_children=1' + results_json = query_api(api, user_agent) - if r.json() == {}: + if results_json == {} or results_json == None: logger.critical(f"Bad response from PGS Catalog for EFO term: {trait}") raise Exception keys: list[str] = ['associated_pgs_ids', 'child_associated_pgs_ids'] pgs: list[str] = [] for key in keys: - pgs.append(r.json().get(key)) + pgs.append(results_json.get(key)) logger.debug(f"Valid response from PGS Catalog for EFO term: {trait}") return list(reduce(lambda x, y: set(x).union(set(y)), pgs)) diff --git a/pgscatalog_utils/match/filter.py b/pgscatalog_utils/match/filter.py index c47a449..c2d0364 100644 --- a/pgscatalog_utils/match/filter.py +++ b/pgscatalog_utils/match/filter.py @@ -5,14 +5,14 @@ logger = logging.getLogger(__name__) -def filter_scores(scorefile: pl.DataFrame, matches: pl.DataFrame, min_overlap: float, - dataset: str) -> tuple[pl.DataFrame, pl.DataFrame]: +def filter_scores(scorefile: pl.LazyFrame, matches: pl.LazyFrame, min_overlap: float, + dataset: str) -> tuple[pl.LazyFrame, pl.LazyFrame]: """ Check overlap between filtered matches and scorefile, remove scores that don't match well and report stats """ - filtered_matches: pl.DataFrame = _filter_matches(matches) - match_log: pl.DataFrame = (_join_filtered_matches(filtered_matches, scorefile, dataset) + filtered_matches: pl.LazyFrame = _filter_matches(matches) + match_log: pl.LazyFrame = (_join_filtered_matches(filtered_matches, scorefile, dataset) .with_columns(pl.col('best_match').fill_null(False))) - fail_rates: pl.DataFrame = _calculate_match_rate(match_log) + fail_rates: pl.DataFrame = _calculate_match_rate(match_log).collect() # collect for iteration scores: list[pl.DataFrame] = [] for accession, rate in zip(fail_rates['accession'].to_list(), fail_rates['fail_rate'].to_list()): @@ -25,7 +25,7 @@ def filter_scores(scorefile: pl.DataFrame, matches: pl.DataFrame, min_overlap: f logger.error(f"Score {accession} fails minimum matching threshold ({1 - rate:.2%} variants match)") 
scores.append(df.with_column(pl.col('accession').cast(pl.Categorical))) - score_summary: pl.DataFrame = pl.concat(scores) + score_summary: pl.LazyFrame = pl.concat(scores).lazy() filtered_scores: pl.DataFrame = (filtered_matches.join(score_summary, on='accession', how='left') .filter(pl.col('score_pass') == True)) @@ -39,12 +39,12 @@ def _calculate_match_rate(df: pl.DataFrame) -> pl.DataFrame: .with_column((pl.col('no_match') / pl.col('count')).alias('fail_rate'))) -def _filter_matches(df: pl.DataFrame) -> pl.DataFrame: +def _filter_matches(df: pl.LazyFrame) -> pl.LazyFrame: logger.debug("Filtering variants with exclude flag") return df.filter((pl.col('best_match') == True) & (pl.col('exclude') == False)) -def _join_filtered_matches(matches: pl.DataFrame, scorefile: pl.DataFrame, dataset: str) -> pl.DataFrame: +def _join_filtered_matches(matches: pl.LazyFrame, scorefile: pl.LazyFrame, dataset: str) -> pl.LazyFrame: return (scorefile.join(matches, on=['row_nr', 'accession'], how='left') .with_column(pl.lit(dataset).alias('dataset')) .select(pl.exclude("^.*_right$"))) diff --git a/pgscatalog_utils/match/label.py b/pgscatalog_utils/match/label.py index 0d38ccb..1c55ba3 100644 --- a/pgscatalog_utils/match/label.py +++ b/pgscatalog_utils/match/label.py @@ -7,7 +7,7 @@ logger = logging.getLogger(__name__) -def label_matches(df: pl.DataFrame, remove_ambiguous, keep_first_match) -> pl.DataFrame: +def label_matches(df: pl.LazyFrame, params: dict[str, bool]) -> pl.LazyFrame: """ Label match candidates with additional metadata. Column definitions: - match_candidate: All input variants that were returned from match.get_all_matches() (always True in this function) @@ -15,17 +15,20 @@ def label_matches(df: pl.DataFrame, remove_ambiguous, keep_first_match) -> pl.Da - duplicate: True if more than one best match exists for the same accession and ID - ambiguous: True if ambiguous """ + assert set(params.keys()) == {'keep_first_match', 'remove_ambiguous', 'remove_multiallelic', 'skip_flip'} labelled = (df.with_column(pl.lit(False).alias('exclude')) # set up dummy exclude column for _label_* .pipe(_label_best_match) .pipe(_label_duplicate_best_match) - .pipe(_label_duplicate_id, keep_first_match) - .pipe(_label_biallelic_ambiguous, remove_ambiguous) + .pipe(_label_duplicate_id, params['keep_first_match']) + .pipe(_label_biallelic_ambiguous, params['remove_ambiguous']) + .pipe(_label_multiallelic, params['remove_multiallelic']) + .pipe(_label_flips, params['skip_flip']) .with_column(pl.lit(True).alias('match_candidate'))) return _encode_match_priority(labelled) -def _encode_match_priority(df: pl.DataFrame) -> pl.DataFrame: +def _encode_match_priority(df: pl.LazyFrame) -> pl.LazyFrame: """ Encode a new column called match status containing matched, unmatched, excluded, and not_best """ return (df.with_columns([ # set false best match to not_best @@ -39,7 +42,7 @@ def _encode_match_priority(df: pl.DataFrame) -> pl.DataFrame: .cast(pl.Categorical)).drop(["max", "excluded_match_priority", "match_priority"])) -def _label_best_match(df: pl.DataFrame) -> pl.DataFrame: +def _label_best_match(df: pl.LazyFrame) -> pl.LazyFrame: """ Best matches have the lowest match priority type. Find the best matches and label them. 
""" logger.debug("Labelling best match type (refalt > altref > ...)") match_priority = {'refalt': 0, 'altref': 1, 'refalt_flip': 2, 'altref_flip': 3, 'no_oa_ref': 4, 'no_oa_alt': 5, @@ -48,7 +51,7 @@ def _label_best_match(df: pl.DataFrame) -> pl.DataFrame: # use a groupby aggregation to guarantee the number of rows stays the same # rows were being lost using an anti join + reduce approach - prioritised: pl.DataFrame = (df.with_column(pl.col('match_type') + prioritised: pl.LazyFrame = (df.with_column(pl.col('match_type') .apply(lambda x: match_priority[x]) .alias('match_priority')) .with_column(pl.col("match_priority") @@ -60,11 +63,11 @@ def _label_best_match(df: pl.DataFrame) -> pl.DataFrame: .then(pl.lit(True)) .otherwise(pl.lit(False)) .alias('best_match'))) - assert prioritised.shape[0] == df.shape[0] # I'm watching you, Wazowski. Always watching. Always. + return prioritised.drop(['match_priority', 'best_match_type']) -def _label_duplicate_best_match(df: pl.DataFrame) -> pl.DataFrame: +def _label_duplicate_best_match(df: pl.LazyFrame) -> pl.LazyFrame: """ A scoring file row_nr in an accession group can be duplicated if a target position has different REF, e.g.: ┌────────┬────────────────────────┬────────────┬────────────────┬─────┬────────────┐ @@ -80,7 +83,7 @@ def _label_duplicate_best_match(df: pl.DataFrame) -> pl.DataFrame: Label the first row with best_match = true, and duplicate rows with best_match = false """ logger.debug("Labelling duplicated best match: keeping first instance as best_match = True") - labelled: pl.DataFrame = (df.with_column(pl.col('best_match') + labelled: pl.LazyFrame = (df.with_column(pl.col('best_match') .count() .over(['accession', 'row_nr', 'best_match']) .alias('count')) @@ -104,7 +107,7 @@ def _label_duplicate_best_match(df: pl.DataFrame) -> pl.DataFrame: return labelled -def _label_duplicate_id(df: pl.DataFrame, keep_first_match: bool) -> pl.DataFrame: +def _label_duplicate_id(df: pl.LazyFrame, keep_first_match: bool) -> pl.LazyFrame: """ Label best match duplicates made when the scoring file is remapped to a different genome build ┌─────────┬────────────────────────┬─────────────┬────────────────┬─────┬────────────┐ @@ -151,7 +154,7 @@ def _label_duplicate_id(df: pl.DataFrame, keep_first_match: bool) -> pl.DataFram .rename({"max": "exclude"})) -def _label_biallelic_ambiguous(df: pl.DataFrame, remove_ambiguous) -> pl.DataFrame: +def _label_biallelic_ambiguous(df: pl.LazyFrame, remove_ambiguous) -> pl.LazyFrame: logger.debug("Labelling ambiguous variants") ambig = ((df.with_columns([ pl.col(["effect_allele", "other_allele", "REF", "ALT", "effect_allele_FLIP", "other_allele_FLIP"]).cast(str), @@ -177,3 +180,33 @@ def _label_biallelic_ambiguous(df: pl.DataFrame, remove_ambiguous) -> pl.DataFra .rename({"max": "exclude"})) +def _label_multiallelic(df: pl.LazyFrame, remove_multiallelic: bool) -> pl.LazyFrame: + """ Label multiallelic variants with exclude flag + + (Multiallelic variants are already labelled with the "is_multiallelic" column in match.preprocess) + """ + if remove_multiallelic: + logger.debug("Labelling multiallelic matches with exclude flag") + return df.with_column(pl.when(pl.col('is_multiallelic') == True) + .then(True) + .otherwise(pl.col('exclude')) # don't overwrite existing exclude flags + .alias('exclude')) + else: + logger.debug("Not excluding multiallelic variants") + return df + + +def _label_flips(df: pl.LazyFrame, skip_flip: bool) -> pl.LazyFrame: + df = df.with_column(pl.when(pl.col('match_type').str.contains('_flip')) + 
.then(True) + .otherwise(False) + .alias('match_flipped')) + if skip_flip: + logger.debug("Labelling flipped matches with exclude flag") + return df.with_column(pl.when(pl.col('match_flipped') == True) + .then(True) + .otherwise(pl.col('exclude')) # don't overwrite existing exclude flags + .alias('exclude')) + else: + logger.debug("Not excluding flipped matches") + return df \ No newline at end of file diff --git a/pgscatalog_utils/match/log.py b/pgscatalog_utils/match/log.py index 91f3999..5b74517 100644 --- a/pgscatalog_utils/match/log.py +++ b/pgscatalog_utils/match/log.py @@ -5,7 +5,7 @@ logger = logging.getLogger(__name__) -def make_logs(scorefile, match_candidates, filter_summary, dataset): +def make_logs(scorefile: pl.LazyFrame, match_candidates: pl.LazyFrame, filter_summary: pl.LazyFrame, dataset: str): # summary log -> aggregated from best matches (one per scoring file line) # big log -> unaggregated, written to compressed gzip, possibly multiple matches per scoring file line summary_log, big_log = _join_match_candidates(scorefile=scorefile, matches=match_candidates, @@ -13,29 +13,30 @@ def make_logs(scorefile, match_candidates, filter_summary, dataset): dataset=dataset) # make sure the aggregated best log matches the scoring file accession line count - summary_count = (summary_log.groupby(pl.col('accession')) - .agg(pl.sum('count'))) - log_count = (scorefile.groupby("accession") - .count() - .join(summary_count, on='accession')) - - assert (log_count['count'] == log_count['count_right']).all(), "Log doesn't match input scoring file" + summary_count: pl.LazyFrame = (summary_log.groupby(pl.col('accession')) + .agg(pl.sum('count'))) + log_count: pl.DataFrame = (scorefile.groupby("accession") + .agg(pl.count()) + .join(summary_count, on='accession')).collect() + + assert (log_count.get_column('count') == log_count.get_column( + 'count_right')).all(), "Log doesn't match input scoring file" logger.debug("Log matches input scoring file") return _prettify_log(big_log), _prettify_summary(summary_log) -def make_summary_log(best_matches, filter_summary): +def make_summary_log(best_matches: pl.LazyFrame, filter_summary: pl.LazyFrame) -> pl.LazyFrame: """ Make an aggregated table """ logger.debug("Aggregating best match log into a summary table") return (best_matches - .groupby(['dataset', 'accession', 'match_status', 'ambiguous', 'is_multiallelic', 'duplicate_best_match', - 'duplicate_ID']) - .count() + .groupby(['dataset', 'accession', 'match_status', 'ambiguous', 'is_multiallelic', 'match_flipped', + 'duplicate_best_match', 'duplicate_ID']) + .agg(pl.count()) .join(filter_summary, how='left', on='accession')) -def _prettify_summary(df: pl.DataFrame): +def _prettify_summary(df: pl.LazyFrame) -> pl.LazyFrame: keep_cols = ["dataset", "accession", "score_pass", "match_status", "ambiguous", "is_multiallelic", "duplicate_best_match", "duplicate_ID", "count", "percent"] return (df.with_column((pl.col("count") / pl.sum("count") * 100) @@ -44,18 +45,18 @@ def _prettify_summary(df: pl.DataFrame): .select(keep_cols)) -def _prettify_log(df: pl.DataFrame) -> pl.DataFrame: +def _prettify_log(df: pl.LazyFrame) -> pl.LazyFrame: keep_cols = ["row_nr", "accession", "chr_name", "chr_position", "effect_allele", "other_allele", "effect_weight", "effect_type", "ID", "REF", "ALT", "matched_effect_allele", "match_type", "is_multiallelic", "ambiguous", "duplicate_best_match", "duplicate_ID", "match_status", "dataset"] pretty_df = (df.select(keep_cols) .select(pl.exclude("^.*_right")) - .sort(["accession", 
"row_nr", "chr_name", "chr_position"])) + .sort(["accession", "row_nr", "chr_name", "chr_position", "match_status"])) return pretty_df -def _join_match_candidates(scorefile: pl.DataFrame, matches: pl.DataFrame, filter_summary: pl.DataFrame, - dataset: str) -> tuple[pl.DataFrame, pl.DataFrame]: +def _join_match_candidates(scorefile: pl.LazyFrame, matches: pl.LazyFrame, filter_summary: pl.LazyFrame, + dataset: str) -> tuple[pl.LazyFrame, pl.LazyFrame]: """ Join match candidates against the original scoring file """ logger.debug("Making big logs") diff --git a/pgscatalog_utils/match/match.py b/pgscatalog_utils/match/match.py index 677f22a..4363dd5 100644 --- a/pgscatalog_utils/match/match.py +++ b/pgscatalog_utils/match/match.py @@ -1,43 +1,63 @@ +import gc import logging +import os +from tempfile import TemporaryDirectory import polars as pl -from pgscatalog_utils.match.label import label_matches - logger = logging.getLogger(__name__) -def get_all_matches(scorefile: pl.DataFrame, target: pl.DataFrame, skip_flip: bool, remove_ambiguous: bool, - keep_first_match: bool) -> pl.DataFrame: - scorefile_cat, target_cat = _cast_categorical(scorefile, target) - scorefile_oa = scorefile_cat.filter(pl.col("other_allele") != None) - scorefile_no_oa = scorefile_cat.filter(pl.col("other_allele") == None) +# @profile # decorator needed to annotate memory profiles, but will cause NameErrors outside of profiling +def get_all_matches(scorefile: pl.LazyFrame, target: pl.LazyFrame, low_memory: bool = True) -> pl.LazyFrame: + scorefile_oa = scorefile.filter(pl.col("other_allele") != None) + scorefile_no_oa = scorefile.filter(pl.col("other_allele") == None) - matches: list[pl.DataFrame] = [] + matches: list[pl.LazyFrame()] = [] col_order = ['row_nr', 'chr_name', 'chr_position', 'effect_allele', 'other_allele', 'effect_weight', 'effect_type', 'accession', 'effect_allele_FLIP', 'other_allele_FLIP', 'ID', 'REF', 'ALT', 'is_multiallelic', 'matched_effect_allele', 'match_type'] - if scorefile_oa: - logger.debug("Getting matches for scores with effect allele and other allele") - matches.append(_match_variants(scorefile_cat, target_cat, match_type="refalt").select(col_order)) - matches.append(_match_variants(scorefile_cat, target_cat, match_type="altref").select(col_order)) - if skip_flip is False: - matches.append(_match_variants(scorefile_cat, target_cat, match_type="refalt_flip").select(col_order)) - matches.append(_match_variants(scorefile_cat, target_cat, match_type="altref_flip").select(col_order)) + logger.debug("Getting matches for scores with effect allele and other allele") + matches.append(_match_variants(scorefile=scorefile_oa, target=target, match_type="refalt").select(col_order)) + matches.append(_match_variants(scorefile_oa, target, match_type="altref").select(col_order)) + matches.append(_match_variants(scorefile_oa, target, match_type="refalt_flip").select(col_order)) + matches.append(_match_variants(scorefile_oa, target, match_type="altref_flip").select(col_order)) + + logger.debug("Getting matches for scores with effect allele only") + matches.append(_match_variants(scorefile_no_oa, target, match_type="no_oa_ref").select(col_order)) + matches.append(_match_variants(scorefile_no_oa, target, match_type="no_oa_alt").select(col_order)) + matches.append(_match_variants(scorefile_no_oa, target, match_type="no_oa_ref_flip").select(col_order)) + matches.append(_match_variants(scorefile_no_oa, target, match_type="no_oa_alt_flip").select(col_order)) + + if low_memory: + logger.debug("Batch collecting matches 
(low memory mode)") + match_lf = _batch_collect(matches) + else: + logger.debug("Collecting all matches (parallel)") + match_lf = pl.concat(pl.collect_all(matches)) + + return match_lf.lazy() - if scorefile_no_oa: - logger.debug("Getting matches for scores with effect allele only") - matches.append(_match_variants(scorefile_no_oa, target_cat, match_type="no_oa_ref").select(col_order)) - matches.append(_match_variants(scorefile_no_oa, target_cat, match_type="no_oa_alt").select(col_order)) - if skip_flip is False: - matches.append(_match_variants(scorefile_no_oa, target_cat, match_type="no_oa_ref_flip").select(col_order)) - matches.append(_match_variants(scorefile_no_oa, target_cat, match_type="no_oa_alt_flip").select(col_order)) - return pl.concat(matches).pipe(label_matches, remove_ambiguous, keep_first_match) +def _batch_collect(matches: list[pl.LazyFrame]) -> pl.DataFrame: + """ A slower alternative to pl.collect_all(), but this approach will use less peak memory + This batches the .collect() and writes intermediate results to a temporary working directory -def _match_variants(scorefile: pl.DataFrame, target: pl.DataFrame, match_type: str) -> pl.DataFrame: + IPC files are binary and remember column schema. Reading them can be extremely fast. """ + with TemporaryDirectory() as temp_dir: + n_chunks = 0 + for i, match in enumerate(matches): + out_path = os.path.join(temp_dir, str(i) + ".ipc") + match.collect().write_ipc(out_path) + n_chunks += 1 + logger.debug(f"Staged {n_chunks} match chunks to {temp_dir}") + gc.collect() + return pl.read_ipc(os.path.join(temp_dir, "*.ipc")) + + +def _match_variants(scorefile: pl.LazyFrame, target: pl.LazyFrame, match_type: str) -> pl.LazyFrame: logger.debug(f"Matching strategy: {match_type}") match match_type: case 'refalt': @@ -88,24 +108,3 @@ def _match_variants(scorefile: pl.DataFrame, target: pl.DataFrame, match_type: s pl.col(effect_allele_column).alias("matched_effect_allele"), pl.lit(match_type).alias("match_type")]) .join(target.select(join_cols), on="ID", how="inner")) # get REF / ALT back after first join - - -def _cast_categorical(scorefile, target) -> tuple[pl.DataFrame, pl.DataFrame]: - """ Casting important columns to categorical makes polars fast """ - if scorefile: - scorefile = scorefile.with_columns([ - pl.col("effect_allele").cast(pl.Categorical), - pl.col("other_allele").cast(pl.Categorical), - pl.col("effect_type").cast(pl.Categorical), - pl.col("effect_allele_FLIP").cast(pl.Categorical), - pl.col("other_allele_FLIP").cast(pl.Categorical), - pl.col("accession").cast(pl.Categorical) - ]) - if target: - target = target.with_columns([ - pl.col("ID").cast(pl.Categorical), - pl.col("REF").cast(pl.Categorical), - pl.col("ALT").cast(pl.Categorical) - ]) - - return scorefile, target diff --git a/pgscatalog_utils/match/match_variants.py b/pgscatalog_utils/match/match_variants.py index 336d781..380e71c 100644 --- a/pgscatalog_utils/match/match_variants.py +++ b/pgscatalog_utils/match/match_variants.py @@ -1,14 +1,16 @@ import argparse import logging +import os +import sys import textwrap -from glob import glob import polars as pl -from pgscatalog_utils.log_config import set_logging_level +import pgscatalog_utils.config as config +from pgscatalog_utils.match.filter import filter_scores +from pgscatalog_utils.match.label import label_matches from pgscatalog_utils.match.log import make_logs from pgscatalog_utils.match.match import get_all_matches -from pgscatalog_utils.match.filter import filter_scores from pgscatalog_utils.match.read import 
read_target, read_scorefile from pgscatalog_utils.match.write import write_out, write_log @@ -17,37 +19,47 @@ def match_variants(): args = _parse_args() + config.set_logging_level(args.verbose) - set_logging_level(args.verbose) - - logger.debug(f"polars n_threads: {pl.threadpool_size()}") + config.POLARS_MAX_THREADS = args.n_threads + os.environ['POLARS_MAX_THREADS'] = str(config.POLARS_MAX_THREADS) + # now the environment variable, parsed argument args.n_threads, and threadpool should agree + logger.debug(f"Setting POLARS_MAX_THREADS environment variable: {os.getenv('POLARS_MAX_THREADS')}") + logger.debug(f"Using {config.POLARS_MAX_THREADS} threads to read CSVs") + logger.debug(f"polars threadpool size: {pl.threadpool_size()}") with pl.StringCache(): - scorefile: pl.DataFrame = read_scorefile(path=args.scorefile) - - n_target_files = len(glob(args.target)) + scorefile: pl.LazyFrame = read_scorefile(path=args.scorefile) + target_paths = list(set(args.target)) + n_target_files = len(target_paths) matches: pl.DataFrame + if n_target_files == 0: + logger.critical("No target genomes found, check the path") + sys.exit(1) + if n_target_files == 1 and not args.fast: + low_memory: bool = True match_mode: str = 'single' elif n_target_files > 1 and not args.fast: + low_memory: bool = True match_mode: str = 'multi' elif args.fast: + low_memory: bool = False match_mode: str = 'fast' match match_mode: case "single": logger.debug(f"Match mode: {match_mode}") - matches = _match_single_target(args.target, scorefile, args.remove_multiallelic, args.skip_flip, - args.remove_ambiguous, args.keep_first_match) + # _fast_match with low_memory = True reads one target in chunks + matches: pl.LazyFrame = _fast_match(target_paths, scorefile, args, low_memory) case "multi": - logger.debug(f"Match mode: {match_mode}") - matches = _match_multiple_targets(args.target, scorefile, args.remove_multiallelic, args.skip_flip, - args.remove_ambiguous, args.keep_first_match) + logger.debug(f"Match mode: {match_mode}") # iterate over multiple targets, in chunks + matches: pl.LazyFrame = _match_multiple_targets(target_paths, scorefile, args, low_memory) case "fast": logger.debug(f"Match mode: {match_mode}") - matches = _fast_match(args.target, scorefile, args.remove_multiallelic, args.skip_flip, - args.remove_ambiguous, args.keep_first_match) + # _fast_match with low_memory = False just read everything into memory for speed + matches: pl.LazyFrame = _fast_match(target_paths, scorefile, args, low_memory) case _: logger.critical(f"Invalid match mode: {match_mode}") raise Exception @@ -56,19 +68,19 @@ def match_variants(): valid_matches, filter_summary = filter_scores(scorefile=scorefile, matches=matches, dataset=dataset, min_overlap=args.min_overlap) - if valid_matches.is_empty(): # this can happen if args.min_overlap = 0 + if valid_matches.fetch().is_empty(): # this can happen if args.min_overlap = 0 logger.error("Error: no target variants match any variants in scoring files") raise Exception big_log, summary_log = make_logs(scorefile, matches, filter_summary, args.dataset) write_log(big_log, prefix=dataset) - summary_log.write_csv(f"{dataset}_summary.csv") + summary_log.collect().write_csv(f"{dataset}_summary.csv") write_out(valid_matches, args.split, args.outdir, dataset) -def _check_target_chroms(target) -> None: - chroms: list[str] = target['#CHROM'].unique().to_list() +def _check_target_chroms(target: pl.LazyFrame) -> None: + chroms: list[str] = target.select(pl.col("#CHROM").unique()).collect().get_column("#CHROM").to_list() 
if len(chroms) > 1: logger.critical(f"Multiple chromosomes detected: {chroms}. Check input data.") raise Exception @@ -76,40 +88,29 @@ def _check_target_chroms(target) -> None: logger.debug("Split target genome contains one chromosome (good)") -def _fast_match(target_path: str, scorefile: pl.DataFrame, remove_multiallelic: bool, - skip_filp: bool, remove_ambiguous: bool, keep_first_match: bool) -> pl.DataFrame: +def _fast_match(target_paths: list[str], scorefile: pl.LazyFrame, + args: argparse.Namespace, low_memory: bool) -> pl.LazyFrame: # fast match is fast because: - # 1) all target files are read into memory + # 1) all target files are read into memory without batching # 2) matching occurs without iterating through chromosomes - target: pl.DataFrame = read_target(path=target_path, - remove_multiallelic=remove_multiallelic) - logger.debug("Split target chromosomes not checked with fast match mode") - return get_all_matches(scorefile, target, skip_filp, remove_ambiguous, keep_first_match) + # when low memory is true and n_targets = 1, fast match is the same as "single" match mode + params: dict[str, bool] = _make_params_dict(args) + target: pl.LazyFrame = read_target(paths=target_paths, low_memory=low_memory) + return (get_all_matches(scorefile=scorefile, target=target, low_memory=low_memory) + .pipe(label_matches, params=params)) -def _match_multiple_targets(target_path: str, scorefile: pl.DataFrame, remove_multiallelic: bool, - skip_filp: bool, remove_ambiguous: bool, keep_first_match: bool) -> pl.DataFrame: +def _match_multiple_targets(target_paths: list[str], scorefile: pl.LazyFrame, args: argparse.Namespace, + low_memory: bool) -> pl.LazyFrame: matches = [] - for i, loc_target_current in enumerate(glob(target_path)): + params: dict[str, bool] = _make_params_dict(args) + for i, loc_target_current in enumerate(target_paths): logger.debug(f'Matching scorefile(s) against target: {loc_target_current}') - target: pl.DataFrame = read_target(path=loc_target_current, - remove_multiallelic=remove_multiallelic) + target: pl.LazyFrame = read_target(paths=[loc_target_current], low_memory=low_memory) _check_target_chroms(target) - matches.append(get_all_matches(scorefile, target, skip_filp, remove_ambiguous, keep_first_match)) - return pl.concat(matches) - - -def _match_single_target(target_path: str, scorefile: pl.DataFrame, remove_multiallelic: bool, - skip_filp: bool, remove_ambiguous: bool, keep_first_match: bool) -> pl.DataFrame: - matches = [] - for chrom in scorefile['chr_name'].unique().to_list(): - target = read_target(target_path, remove_multiallelic=remove_multiallelic, - single_file=True, chrom=chrom) # scans and filters - if target: - logger.debug(f"Matching chromosome {chrom}") - matches.append(get_all_matches(scorefile, target, skip_filp, remove_ambiguous, keep_first_match)) - - return pl.concat(matches) + matches.append(get_all_matches(scorefile=scorefile, target=target, low_memory=low_memory)) + return (pl.concat(matches) + .pipe(label_matches, params=params)) def _description_text() -> str: @@ -166,10 +167,11 @@ def _parse_args(args=None): help=' Label for target genomic dataset') parser.add_argument('-s', '--scorefiles', dest='scorefile', required=True, help=' Combined scorefile path (output of read_scorefiles.py)') - parser.add_argument('-t', '--target', dest='target', required=True, - help=' A table of target genomic variants (.bim format)') + parser.add_argument('-t', '--target', dest='target', required=True, nargs='+', + help=' A list of paths of target genomic variants 
(.bim format)') parser.add_argument('-f', '--fast', dest='fast', action='store_true', help=' Enable faster matching at the cost of increased RAM usage') + parser.add_argument('-n', dest='n_threads', default=1, help=' n threads for matching', type=int) parser.add_argument('--split', dest='split', default=False, action='store_true', help=' Split scorefile per chromosome?') parser.add_argument('--outdir', dest='outdir', required=True, @@ -197,5 +199,13 @@ def _parse_args(args=None): return parser.parse_args(args) +def _make_params_dict(args) -> dict[str, bool]: + """ Make a dictionary with parameters that control labelling match candidates """ + return {'keep_first_match': args.keep_first_match, + 'remove_ambiguous': args.remove_ambiguous, + 'skip_flip': args.skip_flip, + 'remove_multiallelic': args.remove_multiallelic} + + if __name__ == "__main__": match_variants() diff --git a/pgscatalog_utils/match/preprocess.py b/pgscatalog_utils/match/preprocess.py index 1723f6d..9997176 100644 --- a/pgscatalog_utils/match/preprocess.py +++ b/pgscatalog_utils/match/preprocess.py @@ -5,6 +5,16 @@ logger = logging.getLogger(__name__) +def filter_target(df: pl.DataFrame) -> pl.DataFrame: + """ Remove variants that won't be matched against the scorefile + + Keep chromosomes 1 - 22, X, and Y with an efficient join. Also remove variants with missing identifiers + """ + logger.debug("Filtering target to include chromosomes 1 - 22, X, Y") + chroms = [str(x) for x in list(range(1, 23)) + ['X', 'Y']] + return df.filter((pl.col('#CHROM').is_in(chroms)) & (pl.col('ID') != '.')) + + def complement_valid_alleles(df: pl.DataFrame, flip_cols: list[str]) -> pl.DataFrame: """ Improved function to complement alleles. Will only complement sequences that are valid DNA. """ @@ -27,7 +37,8 @@ def complement_valid_alleles(df: pl.DataFrame, flip_cols: list[str]) -> pl.DataF return df -def handle_multiallelic(df: pl.DataFrame, remove_multiallelic: bool, pvar: bool) -> pl.DataFrame: +def annotate_multiallelic(df: pl.DataFrame) -> pl.DataFrame: + """ Identify variants that are multiallelic with a column flag """ # plink2 pvar multi-alleles are comma-separated df: pl.DataFrame = (df.with_column( pl.when(pl.col("ALT").str.contains(',')) @@ -35,23 +46,10 @@ def handle_multiallelic(df: pl.DataFrame, remove_multiallelic: bool, pvar: bool) .otherwise(pl.lit(False)) .alias('is_multiallelic'))) - if df['is_multiallelic'].sum() > 0: - logger.debug("Multiallelic variants detected") - if remove_multiallelic: - if not pvar: - logger.warning("--remove_multiallelic requested for bim format, which already contains biallelic " - "variant representations only") - logger.debug('Dropping multiallelic variants') - return df.filter(~df['is_multiallelic']) - else: - logger.debug("Exploding dataframe to handle multiallelic variants") - df.replace('ALT', df['ALT'].str.split(by=',')) # turn ALT to list of variants - return df.explode('ALT') # expand the DF to have all the variants in different rows + if (df.get_column('is_multiallelic')).any(): + logger.debug("Exploding dataframe to handle multiallelic variants") + df.replace('ALT', df['ALT'].str.split(by=',')) # turn ALT to list of variants + return df.explode('ALT') # expand the DF to have all the variants in different rows else: logger.debug("No multiallelic variants detected") return df - - -def _annotate_multiallelic(df: pl.DataFrame) -> pl.DataFrame: - df.with_column( - pl.when(pl.col("ALT").str.contains(',')).then(pl.lit(True)).otherwise(pl.lit(False)).alias('is_multiallelic')) diff --git 
a/pgscatalog_utils/match/read.py b/pgscatalog_utils/match/read.py index fd1a4c3..e7417f1 100644 --- a/pgscatalog_utils/match/read.py +++ b/pgscatalog_utils/match/read.py @@ -1,104 +1,37 @@ -import glob + import logging -from typing import NamedTuple import polars as pl - -from pgscatalog_utils.match.preprocess import handle_multiallelic, complement_valid_alleles +import pgscatalog_utils.config as config +from pgscatalog_utils.match.preprocess import annotate_multiallelic, complement_valid_alleles, filter_target +from pgscatalog_utils.target import Target logger = logging.getLogger(__name__) -def read_target(path: str, remove_multiallelic: bool, single_file: bool = False, - chrom: str = "") -> pl.DataFrame: - target: Target = _detect_target_format(path) - d = {'column_1': str} # column_1 is always CHROM. CHROM must always be a string - - if single_file: - logger.debug(f"Scanning target genome for chromosome {chrom}") - # scan target and filter to reduce memory usage on big files - df: pl.DataFrame = ( - pl.scan_csv(path, sep='\t', has_header=False, comment_char='#', dtype=d) - .filter(pl.col('column_1') == chrom) - .collect()) - - if df.is_empty(): - logger.warning(f"Chromosome missing from target genome: {chrom}") - return df - else: - logger.debug(f"Reading target {path}") - df: pl.DataFrame = pl.read_csv(path, sep='\t', has_header=False, comment_char='#', dtype=d) - - df.columns = target.header +def read_target(paths: list[str], low_memory: bool) -> pl.LazyFrame: + targets: list[Target] = [Target.from_path(x, low_memory) for x in paths] - match target.file_format: - case 'bim': - return (df.select(_default_cols()) - .filter(pl.col('ID') != '.') # remove missing IDs - .pipe(handle_multiallelic, remove_multiallelic=remove_multiallelic, pvar=False)) - case 'pvar': - return (df.select(_default_cols()) - .filter(pl.col('ID') != '.') - .pipe(handle_multiallelic, remove_multiallelic=remove_multiallelic, pvar=True)) - case _: - logger.error("Invalid file format detected") - raise Exception + logger.debug("Reading all target data complete") + # handling multiallelic requires str methods, so don't forget to cast back or matching will break + return (pl.concat([x.read() for x in targets]) + .pipe(filter_target) + .pipe(annotate_multiallelic) + .with_column(pl.col('ALT').cast(pl.Categorical))).lazy() -def read_scorefile(path: str) -> pl.DataFrame: +def read_scorefile(path: str) -> pl.LazyFrame: logger.debug("Reading scorefile") - scorefile: pl.DataFrame = (pl.read_csv(path, sep='\t', dtype={'chr_name': str}) - .pipe(complement_valid_alleles, flip_cols=['effect_allele', 'other_allele']) - .with_columns([ - pl.col('accession').cast(pl.Categorical), - pl.col("effect_type").cast(pl.Categorical)])) - - return scorefile - - -class Target(NamedTuple): - """ Important summary information about a target genome. Cheap to compute (just reads the header). 
""" - file_format: str - header: list[str] - - -def _detect_target_format(path: str) -> Target: - file_format: str - header: list[str] - - if "*" in path: - logger.debug("Detecting target file format") - path = glob.glob(path)[0] # guess format from first file in directory - - with open(path, 'rt') as f: - for line in f: - if line.startswith('#'): - logger.debug("pvar format detected") - file_format = 'pvar' - header = _pvar_header(path) - break - else: - logger.debug("bim format detected") - file_format = 'bim' - header = _bim_header() - break - - return Target(file_format, header) - - -def _default_cols() -> list[str]: - return ['#CHROM', 'POS', 'ID', 'REF', 'ALT'] # only columns we want from a target genome - - -def _pvar_header(path: str) -> list[str]: - """ Get the column names from the pvar file (not constrained like bim, especially when converted from VCF) """ - line: str = '#' - with open(path, 'rt') as f: - while line.startswith('#'): - line: str = f.readline() - if line.startswith('#CHROM'): - return line.strip().split('\t') - - -def _bim_header() -> list[str]: - return ['#CHROM', 'ID', 'CM', 'POS', 'REF', 'ALT'] + dtypes = {'chr_name': pl.Categorical, + 'chr_position': pl.UInt64, + 'effect_allele': pl.Utf8, # str functions required to complement + 'other_allele': pl.Utf8, + 'effect_type': pl.Categorical, + 'accession': pl.Categorical} + return (pl.read_csv(path, sep='\t', dtype=dtypes, n_threads=config.POLARS_MAX_THREADS) + .lazy() + .pipe(complement_valid_alleles, flip_cols=['effect_allele', 'other_allele'])).with_columns([ + pl.col("effect_allele").cast(pl.Categorical), + pl.col("other_allele").cast(pl.Categorical), + pl.col("effect_allele_FLIP").cast(pl.Categorical), + pl.col("other_allele_FLIP").cast(pl.Categorical)]) diff --git a/pgscatalog_utils/match/write.py b/pgscatalog_utils/match/write.py index 53eb15f..9d4ba92 100644 --- a/pgscatalog_utils/match/write.py +++ b/pgscatalog_utils/match/write.py @@ -7,18 +7,18 @@ logger = logging.getLogger(__name__) -def write_log(df: pl.DataFrame, prefix: str) -> None: +def write_log(df: pl.LazyFrame, prefix: str) -> None: logger.debug(f"Compressing and writing log: {prefix}_log.csv.gz") with gzip.open(f"{prefix}_log.csv.gz", 'wb') as f: - df.write_csv(f) + df.collect().write_csv(f) -def write_out(df: pl.DataFrame, split: bool, outdir: str, dataset: str) -> None: +def write_out(df: pl.LazyFrame, split: bool, outdir: str, dataset: str) -> None: if not os.path.isdir(outdir): os.mkdir(outdir) logger.debug("Splitting by effect type") - effect_types: dict[str, pl.DataFrame] = _split_effect_type(df) + effect_types: dict[str, pl.DataFrame] = _split_effect_type(df.collect()) logger.debug("Deduplicating variants") deduplicated: dict[str, pl.DataFrame] = {k: _deduplicate_variants(k, v) for k, v in effect_types.items()} @@ -37,9 +37,11 @@ def _write_scorefile(effect_type: str, scorefiles: pl.DataFrame, split: bool, ou for k, v in df_dict.items(): chr = k.replace("false", "ALL") - path: str = os.path.join(outdir, f"{dataset}_{chr}_{effect_type}_{i}.scorefile") + path: str = os.path.join(outdir, f"{dataset}_{chr}_{effect_type}_{i}.scorefile.gz") logger.debug(f"Writing matched scorefile to {path}") - v.write_csv(path, sep="\t") + + with gzip.open(path, 'wb') as f: + v.write_csv(f, sep="\t") def _format_scorefile(df: pl.DataFrame, split: bool) -> dict[str, pl.DataFrame]: diff --git a/pgscatalog_utils/scorefile/combine_scorefiles.py b/pgscatalog_utils/scorefile/combine_scorefiles.py index 5b30fda..bcafa61 100644 --- 
a/pgscatalog_utils/scorefile/combine_scorefiles.py +++ b/pgscatalog_utils/scorefile/combine_scorefiles.py @@ -1,11 +1,10 @@ import argparse import logging +import os import sys import textwrap -import pandas as pd - -from pgscatalog_utils.log_config import set_logging_level +from pgscatalog_utils.config import set_logging_level from pgscatalog_utils.scorefile.effect_type import set_effect_type from pgscatalog_utils.scorefile.effect_weight import melt_effect_weights from pgscatalog_utils.scorefile.genome_build import build2GRC @@ -25,11 +24,18 @@ def combine_scorefiles(): paths: list[str] = list(set(args.scorefiles)) # unique paths only logger.debug(f"Input scorefiles: {paths}") - scorefiles = [] + if os.path.exists(args.outfile): + logger.critical(f"Output file {args.outfile} already exists") + raise Exception + for x in paths: # Read scorefile df and header h, score = load_scorefile(x) + if score.empty: + logger.critical(f"Empty scorefile {x} detected! Please check the input data") + raise Exception + # Check if we should use the harmonized positions use_harmonised = False current_build = None @@ -65,19 +71,15 @@ def combine_scorefiles(): logger.error("Try running with --liftover and specifying the --chain_dir") raise Exception - scorefiles.append(score) - - if len(scorefiles) > 0: - scorefiles: pd.DataFrame = pd.concat(scorefiles) - else: - logger.error("No valid scorefiles could be combined") - raise Exception + if args.liftover: + logger.debug("Annotating scorefile with liftover parameters") + score = liftover(score, args.chain_dir, args.min_lift, args.target_build) - if args.liftover: - logger.debug("Annotating scorefiles with liftover parameters") - scorefiles = liftover(scorefiles, args.chain_dir, args.min_lift, args.target_build) + if score.empty and (args.drop_missing is False): + logger.critical("Empty output score detected, something went wrong while combining") + raise Exception - write_scorefile(scorefiles, args.outfile) + write_scorefile(score, args.outfile) def _description_text() -> str: diff --git a/pgscatalog_utils/scorefile/write.py b/pgscatalog_utils/scorefile/write.py index 0dd7b38..8a3233b 100644 --- a/pgscatalog_utils/scorefile/write.py +++ b/pgscatalog_utils/scorefile/write.py @@ -1,4 +1,5 @@ import logging +import os import pandas as pd @@ -9,23 +10,29 @@ def write_scorefile(df: pd.DataFrame, path: str) -> None: cols: list[str] = ['chr_name', 'chr_position', 'effect_allele', 'other_allele', 'effect_weight', 'effect_type', 'is_duplicated', 'accession', 'row_nr'] - if df.empty: - logger.error("Empty scorefile output! 
Please check the input data") - raise Exception + if os.path.exists(path): + logger.debug("Output file exists: setting write mode to append") + write_mode = 'a' + header = False else: - out_df: pd.DataFrame = (df.drop('accession', axis=1) - .rename({'filename_prefix': 'accession'}, axis=1) - .pipe(_filter_failed_liftover)) - - if 'other_allele' not in out_df: - logger.warning("No other allele information detected, writing out as missing data") - out_df['other_allele'] = None - if path.endswith('.gz'): - logger.debug("Writing out gzip-compressed combined scorefile") - out_df[cols].to_csv(path, index=False, sep="\t", compression='gzip') - else: - logger.debug("Writing out combined scorefile") - out_df[cols].to_csv(path, index=False, sep="\t") + logger.debug("Output file doesn't exist: setting write mode to write (create new file)") + write_mode = 'w' + header = True + + out_df: pd.DataFrame = (df.drop('accession', axis=1) + .rename({'filename_prefix': 'accession'}, axis=1) + .pipe(_filter_failed_liftover)) + + if 'other_allele' not in out_df: + logger.warning("No other allele information detected, writing out as missing data") + out_df['other_allele'] = None + + if path.endswith('.gz'): + logger.debug("Writing out gzip-compressed combined scorefile") + out_df[cols].to_csv(path, index=False, sep="\t", compression='gzip', mode=write_mode, header=header) + else: + logger.debug("Writing out combined scorefile") + out_df[cols].to_csv(path, index=False, sep="\t", mode=write_mode, header=header) def _filter_failed_liftover(df: pd.DataFrame) -> pd.DataFrame: diff --git a/pgscatalog_utils/target.py b/pgscatalog_utils/target.py new file mode 100644 index 0000000..3573ee6 --- /dev/null +++ b/pgscatalog_utils/target.py @@ -0,0 +1,221 @@ +import gc +import io +import logging +import os +from dataclasses import dataclass +from itertools import islice +from tempfile import TemporaryDirectory + +import polars as pl +import zstandard + +import pgscatalog_utils.config as config + +logger = logging.getLogger(__name__) + + +@dataclass +class Target: + """ Class to detect and read a plink1/plink2 variant information file """ + file_format: str = None + path: str = None + compressed: bool = False + low_memory: bool = True # targets can be big, and use a lot of RAM when reading + + @classmethod + def from_path(cls, path, low_memory): + """ Create a Target object from a path. Cheaply detect file format and headers. 
""" + try: + with open(path, 'r') as f: + file_format = _get_format(f) + compressed = False + except UnicodeDecodeError: + logger.error("Can't open target as a text file, so trying to read zstd compressed binary file") + with open(path, 'rb') as f: + dctx = zstandard.ZstdDecompressor() + stream_reader = dctx.stream_reader(f) + text_stream = io.TextIOWrapper(stream_reader, encoding='utf-8') + file_format = _get_format(text_stream) + compressed = True + + return cls(file_format=file_format, path=path, compressed=compressed, low_memory=low_memory) + + # @profile # decorator needed to annotate memory profiles, but will cause NameErrors outside of profiling + def read(self): + if self.low_memory: + if self.compressed: + logger.debug("Reading compressed chunks from target genome (slower, lower RAM usage)") + return self._read_compressed_chunks() + else: + logger.debug("Reading uncompressed chunks from target genome (slower, lower RAM usage)") + return self._read_uncompressed_chunks() + else: + if self.compressed: + logger.debug("Reading compressed target genome (fast mode, high RAM usage)") + return self._read_compressed() + else: + logger.debug("Reading uncompressed target genome (fast mode, high RAM usage)") + return self._read_uncompressed() + + def _read_compressed(self) -> pl.DataFrame: + """ Read a zst compressed target as quickly as possible """ + with open(self.path, 'rb') as fh: + dctx = zstandard.ZstdDecompressor() + with dctx.stream_reader(fh) as reader: + dtypes = _get_col_dtypes(self.file_format) + col_idxs, new_col_names = _default_cols(self.file_format) + return (pl.read_csv(reader, sep='\t', has_header=False, comment_char='#', + dtype=dtypes, + columns=col_idxs, + new_columns=new_col_names, + n_threads=config.POLARS_MAX_THREADS)) + + def _read_uncompressed(self) -> pl.DataFrame: + """ Read an uncompressed target as quickly as possible. Uses up to 16GB RAM on 1000 genomes pvar. """ + dtypes = _get_col_dtypes(self.file_format) + col_idxs, new_col_names = _default_cols(self.file_format) + return (pl.read_csv(self.path, sep='\t', has_header=False, comment_char='#', + dtype=dtypes, + columns=col_idxs, + new_columns=new_col_names, + n_threads=config.POLARS_MAX_THREADS)) + + def _read_uncompressed_chunks(self) -> pl.DataFrame: + """ Read a CSV using a BufferedReader in batches to reduce memory usage. + + Reads 1 million variant chunks and immediately writes to feather format in a temporary directory. + + Read all temporary feather files and return a big pl.DataFrame. Reading feather is fast, and preserves dtypes. 
+ + Uses ~ 2GB + """ + dtypes = _get_col_dtypes(self.file_format) + col_idxs, new_col_names = _default_cols(self.file_format) + with TemporaryDirectory() as temp_dir: + batch_n = 0 + batch_size = int(1e6) + with open(self.path, 'rb') as f: + while True: + line_batch = b''.join(islice(f, batch_size)) + if not line_batch: + break + + out_path = os.path.join(temp_dir, str(batch_n) + '.ipc') + + (pl.read_csv(line_batch, sep='\t', has_header=False, comment_char='#', + dtype=dtypes, + columns=col_idxs, + new_columns=new_col_names, + n_threads=config.POLARS_MAX_THREADS).write_ipc(out_path)) + batch_n += 1 + + gc.collect() # just to be safe + logger.debug(f"{batch_n} batches staged in temporary directory {temp_dir}") + return pl.read_ipc(os.path.join(temp_dir, "*.ipc")) + + def _read_compressed_chunks(self) -> pl.DataFrame: + """ Like _read_uncompressed_chunks, but read chunks of bytes and handle incomplete rows + + zstd returns chunks of bytes, not lines, but encoding utf-8 will be faster in rust and polars + """ + logger.debug("Started reading zstd compressed data") + dtypes = _get_col_dtypes(self.file_format) + columns, new_col_names = _default_cols(self.file_format) + + n_chunks = 0 + + with TemporaryDirectory() as temp_dir: + with open(self.path, 'rb') as fh: + dctx = zstandard.ZstdDecompressor() + chunk_buffer = b'' + + for chunk in dctx.read_to_iter(fh, read_size=int(1e8), write_size=int(1e8)): + if not chunk: + logger.debug("Finished reading zstd compressed chunks") + break + + end = chunk.rfind(b'\n') + 1 # only want to read complete rows, which end in \n + if chunk_buffer: + row_chunk = b''.join([chunk_buffer, chunk[:end]]) + chunk_buffer = b'' + else: + row_chunk = chunk[:end] + + out_path = os.path.join(temp_dir, str(n_chunks) + ".ipc") + (pl.read_csv(row_chunk, sep='\t', has_header=False, comment_char='#', + dtype=dtypes, + columns=columns, + new_columns=new_col_names, + n_threads=config.POLARS_MAX_THREADS) + .write_ipc(out_path)) + + chunk_buffer = b''.join([chunk_buffer, chunk[end:]]) + n_chunks += 1 + + gc.collect() # just to be safe + logger.debug(f"{n_chunks} chunks") # write_size will change n_chunks + return pl.read_ipc(os.path.join(temp_dir, "*.ipc")) + + +def _get_col_dtypes(file_format): + """ Manually set up dtypes to save memory. Repeated strings like REF / ALT / CHROM work best as pl.Categorical. + + ID shouldn't be pl.Categorical, or you'll create a massive string cache and waste RAM """ + match file_format: + case 'bim': + # 1. Chromosome code (either an integer, or 'X'/'Y'/'XY'/'MT'; '0' indicates unknown) or name + # 2. Variant identifier + # 3. Position in morgans or centimorgans (safe to use dummy value of '0') + # 4. Base-pair coordinate (1-based; limited to 231-2) + # 5. Allele 1 (corresponding to clear bits in .bed; usually minor) + # 6. Allele 2 (corresponding to set bits in .bed; usually major) + d = {'column_1': pl.Categorical, 'column_2': pl.Utf8, 'column_3': pl.Float64, 'column_4': pl.UInt64, + 'column_5': pl.Categorical, 'column_6': pl.Utf8} + case 'pvar': + # 1. CHROM + # 2. POS (base-pair coordinate) + # 3. ID (variant ID; required) + # 4. REF (reference allele) + # 5. ALT (alternate alleles, comma-separated) + # 6. QUAL (phred-scaled quality score for whether the locus is variable at all) + # 7. FILTER ('PASS', '.', or semicolon-separated list of failing filter codes) + # 8. 
INFO (semicolon-separated list of flags and key-value pairs, with types declared in header) + d = {'column_1': pl.Categorical, 'column_2': pl.UInt64, 'column_3': pl.Utf8, 'column_4': pl.Categorical, + 'column_5': pl.Utf8, 'column_6': pl.Float32, 'column_7': pl.Utf8, 'column_8': pl.Utf8} + # can't cast ALT to cat yet, because of multiallelic variants! + case _: + logger.critical("Trying to set header dtypes for an invalid file format, time to explode") + raise Exception + return d + + +def _get_format(fh) -> str: + file_format = None + logger.debug(f"Scanning header to get file format") + for line in fh: + if line.startswith('#'): + logger.debug("pvar format detected") + file_format = 'pvar' + break + else: + logger.debug("bim format detected") + file_format = 'bim' + break + + return file_format + + +def _default_cols(file_format) -> tuple[list[int], list[str]]: + """ Return a list of column integers to keep, assuming plink default column sets """ + match file_format: + case 'bim': + idxs = [0, 1, 3, 4, 5] # see _get_col_dtypes, dropping centimorgans + names = ['#CHROM', 'ID', 'POS', 'REF', 'ALT'] # technically A1/A2, but it's ok + return idxs, names + case 'pvar': + idxs = [0, 1, 2, 3, 4] # dropping QUAL FILTER INFO etc + names = ['#CHROM', 'POS', 'ID', 'REF', 'ALT'] + return idxs, names + case _: + logger.critical("Trying to get column idx for an invalid file format, TWENTY THREE NINETEEN") + raise Exception diff --git a/pgscatalog_utils/validate/__init__.py b/pgscatalog_utils/validate/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/pgscatalog_utils/validate/common_constants.py b/pgscatalog_utils/validate/common_constants.py new file mode 100644 index 0000000..768752a --- /dev/null +++ b/pgscatalog_utils/validate/common_constants.py @@ -0,0 +1,44 @@ +SNP_DSET = 'rsID' +CHR_DSET = 'chr_name' +BP_DSET = 'chr_position' +EFFECT_DSET = 'effect_allele' +OTH_DSET = 'other_allele' +EFFECT_WEIGHT_DSET = 'effect_weight' + +# Other columns +LOCUS_DSET = 'locus_name' +OR_DSET = 'OR' +HR_DSET = 'HR' +BETA_DSET = 'beta' +FREQ_DSET = 'allelefrequency_effect' +FLAG_INTERACTION_DSET = 'is_interaction' +FLAG_RECESSIVE_DSET = 'is_recessive' +FLAG_HAPLOTYPE_DSET = 'is_haplotype' +FLAG_DIPLOTYPE_DSET = 'is_diplotype' +METHOD_DSET = 'imputation_method' +SNP_DESC_DSET = 'variant_description' +INCLUSION_DSET = 'inclusion_criteria' +DOSAGE_0_WEIGHT = 'dosage_0_weight' +DOSAGE_1_WEIGHT = 'dosage_1_weight' +DOSAGE_2_WEIGHT = 'dosage_2_weight' +# hmPOS +HM_SOURCE_DSET = 'hm_source' +HM_SNP_DSET = 'hm_rsID' +HM_CHR_DSET = 'hm_chr' +HM_BP_DSET = 'hm_pos' +HM_OTH_DSET = 'hm_inferOtherAllele' +HM_MATCH_CHR_DSET = 'hm_match_chr' +HM_MATCH_BP_DSET = 'hm_match_pos' +# hmFinal +VARIANT_DSET = 'variant_id' +HM_CODE_DSET = 'hm_code' +HM_INFO_DSET = 'hm_info' + + +DSET_TYPES = {SNP_DSET: str, CHR_DSET: str, BP_DSET: int, EFFECT_DSET: str, OTH_DSET: str, + EFFECT_WEIGHT_DSET: float, VARIANT_DSET: str, HM_CODE_DSET: int, HM_INFO_DSET: str, LOCUS_DSET: str, OR_DSET: float, HR_DSET: float, BETA_DSET: float, FREQ_DSET: float, + FLAG_INTERACTION_DSET: str, FLAG_RECESSIVE_DSET: str, FLAG_HAPLOTYPE_DSET: str, FLAG_DIPLOTYPE_DSET: str, + METHOD_DSET: str, SNP_DESC_DSET: str, INCLUSION_DSET: str, DOSAGE_0_WEIGHT: float, DOSAGE_1_WEIGHT: float, DOSAGE_2_WEIGHT: float, + HM_SOURCE_DSET:str, HM_SNP_DSET: str, HM_CHR_DSET: str, HM_BP_DSET: int, HM_OTH_DSET: str, HM_MATCH_CHR_DSET: str, HM_MATCH_BP_DSET: int} + +TO_DISPLAY_ORDER = [ SNP_DSET, CHR_DSET, BP_DSET, EFFECT_DSET, OTH_DSET, EFFECT_WEIGHT_DSET, LOCUS_DSET, 
OR_DSET, HR_DSET, HM_CODE_DSET, HM_INFO_DSET, HM_SOURCE_DSET, HM_SNP_DSET, HM_BP_DSET, HM_OTH_DSET, HM_MATCH_CHR_DSET, HM_MATCH_BP_DSET] \ No newline at end of file diff --git a/pgscatalog_utils/validate/formatted/__init__.py b/pgscatalog_utils/validate/formatted/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/pgscatalog_utils/validate/formatted/validator.py b/pgscatalog_utils/validate/formatted/validator.py new file mode 100644 index 0000000..1e42336 --- /dev/null +++ b/pgscatalog_utils/validate/formatted/validator.py @@ -0,0 +1,197 @@ +import gzip +import re +from pandas_schema import Schema +from pgscatalog_utils.validate.schemas import * +from pgscatalog_utils.validate.validator_base import * + +''' +PGS Catalog Harmonized file validator +- using pandas_schema https://github.com/TMiguelT/PandasSchema +''' + +class ValidatorFormatted(ValidatorBase): + + def __init__(self, file, score_dir=None, logfile="VALIDATE.log", error_limit=0): + super().__init__(file, score_dir, logfile, error_limit) + self.score_dir=None + self.meta_format = FORMATTED_META_GENERIC + self.schema_validators = FORMATTED_VALIDATORS + self.valid_cols = VALID_COLS_FORMATTED + self.valid_type = VALID_TYPE_FORMATTED + self.setup_field_validation() + + + def extract_specific_metadata(self,line): + ''' Extract some of the metadata. ''' + match_variants_number = re.search(r'#variants_number=(\d+)', line) + if match_variants_number: + self.variants_number = int(match_variants_number.group(1)) + + + def get_and_check_variants_number(self): + ''' Verify that the number of variant lines corresponds to the number of variants in the headers ''' + variant_lines = 0 + + with gzip.open( self.file, 'rb') as f: + line_number = 0 + for line in f: + line_number += 1 + line = line.decode('utf-8').rstrip() + if line.startswith('#'): + match_variants_number = re.search(r'#variants_number=(\d+)', line) + if match_variants_number: + self.variants_number = int(match_variants_number.group(1)) + else: + variant_lines += 1 + if re.search(r'\w+', line): # Line not empty + cols = line.split(self.sep) + has_trailing_spaces = self.check_leading_trailing_spaces(cols,line_number) + if has_trailing_spaces: + self.global_errors += 1 + else: + self.logger.error(f'- Line {line_number} is empty') + self.global_errors += 1 + + if self.variants_number: + variant_lines -= 1 # Remove the header line from the count + if self.variants_number != variant_lines: + self.logger.error(f'- The number of variants lines in the file ({variant_lines}) and the number of variants declared in the headers ({self.variants_number}) are different') + self.global_errors += 1 + else: + self.logger.error("- Can't retrieve the number of variants from the headers") + self.global_errors += 1 + + + def detect_duplicated_rows(self,dataframe_chunk): + ''' Detect duplicated rows in the scoring file. 
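get_and_check_variants_number() above cross-checks the #variants_number header against the number of data rows actually present in the file. The same consistency check in isolation, assuming a gzipped formatted scoring file ('PGS000001.txt.gz' is a placeholder):

import gzip
import re


def check_variants_number(path: str) -> bool:
    declared = None
    data_lines = 0
    with gzip.open(path, 'rt') as f:
        for line in f:
            line = line.rstrip()
            if line.startswith('#'):
                match = re.search(r'#variants_number=(\d+)', line)
                if match:
                    declared = int(match.group(1))
            elif line:
                data_lines += 1
    data_lines -= 1  # don't count the column header row
    return declared is not None and declared == data_lines


# check_variants_number('PGS000001.txt.gz')  # placeholder scoring file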
''' + # Columns of interest to compare the different rows + cols_sel = [] + for col in ['rsID','chr_name','chr_position','effect_allele','other_allele']: + if col in self.cols_to_validate: + cols_sel.append(col) + + duplicate_status = dataframe_chunk.duplicated(cols_sel) + if any(duplicate_status): + duplicated_rows = dataframe_chunk[duplicate_status] + self.logger.error(f'Duplicated row(s) found: {len(duplicated_rows.index)}\n\t-> {duplicated_rows.to_string(header=False,index=False)}') + self.global_errors += 1 + for index in duplicated_rows.index: + self.bad_rows.append(index) + + + def validate_data(self) -> bool: + ''' Validate the file: data format and data content ''' + self.logger.info("Validating data...") + if not self.open_file_and_check_for_squareness(): + self.logger.error("Please fix the table. Some rows have different numbers of columns to the header") + self.logger.info("Rows with different numbers of columns to the header are not validated") + # Check the consitence between the declared variants number and the actual number of variants in the file + self.get_and_check_variants_number() + + for chunk in self.df_iterator(self.file): + dataframe_to_validate = chunk[self.cols_to_read] + dataframe_to_validate.columns = self.cols_to_validate # sets the headers to standard format if neeeded + + # Detect duplicated rows + self.detect_duplicated_rows(dataframe_to_validate) + + # validate the snp column if present + if SNP_DSET in self.header: + sub_schema = FORMATTED_VALIDATORS_SNP + if CHR_DSET and BP_DSET in self.header: + sub_schema = FORMATTED_VALIDATORS_SNP_EMPTY + self.validate_schema(sub_schema,dataframe_to_validate) + + if CHR_DSET and BP_DSET in self.header: + self.validate_schema(FORMATTED_VALIDATORS_POS, dataframe_to_validate) + + if OR_DSET in self.header: + self.validate_schema(FORMATTED_VALIDATORS_OR,dataframe_to_validate) + + if HR_DSET in self.header: + self.validate_schema(FORMATTED_VALIDATORS_HR,dataframe_to_validate) + + self.process_errors() + if len(self.bad_rows) >= self.error_limit: + break + if not self.bad_rows and not self.global_errors: + if self.is_file_valid(): + self.logger.info("File is valid") + else: + self.logger.info("File is invalid") + else: + self.logger.info("File is invalid - {} bad rows, limit set to {}".format(len(self.bad_rows), self.error_limit)) + self.set_file_is_invalid() + return self.is_file_valid() + + + def validate_filename(self) -> bool: + ''' Validate the file name structure. ''' + self.logger.info("Validating file name...") + filename = self.file.split('/')[-1].split('.')[0] + is_valid_filename = True + if not re.match(r'^PGS\d{6}$', filename): + self.logger.info("Invalid filename: {}".format(self.file)) + self.logger.error("Filename: {} should follow the pattern 'PGSXXXXXX.txt.gz', where the 'X' are the 6 digits of the PGS identifier (e.g. PGS000001)".format(filename)) + is_valid_filename = False + self.set_file_is_invalid() + + return is_valid_filename + + + def validate_headers(self) -> bool: + ''' Validate the list of column names. 
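detect_duplicated_rows() above relies on pandas DataFrame.duplicated() over the columns that identify a variant; everything after the first occurrence is reported. A toy example of that call (made-up data):

import pandas as pd

df = pd.DataFrame({
    'rsID': ['rs1', 'rs2', 'rs1'],
    'chr_name': ['1', '2', '1'],
    'chr_position': [100, 200, 100],
    'effect_allele': ['A', 'C', 'A'],
    'effect_weight': [0.1, 0.2, 0.3],
})

id_cols = ['rsID', 'chr_name', 'chr_position', 'effect_allele']
duplicated_rows = df[df.duplicated(id_cols)]   # keep='first' by default
print(len(duplicated_rows))                    # 1: the third row repeats the first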
''' + self.logger.info("Validating headers...") + self.detect_genomebuild_with_rsid() + required_is_subset = set(STD_COLS_VAR_FORMATTED).issubset(self.header) + if not required_is_subset: + # check if everything but snp: + required_is_subset = set(CHR_COLS_VAR_FORMATTED).issubset(self.header) + if not required_is_subset: + required_is_subset = set(SNP_COLS_VAR_FORMATTED).issubset(self.header) + if not required_is_subset: + self.logger.error("Required headers: {} are not in the file header: {}".format(STD_COLS_VAR_FORMATTED, self.header)) + + # Check if at least one of the effect columns is there + has_effect_col = 0 + for col in STD_COLS_EFFECT_FORMATTED: + if set([col]).issubset(self.header): + has_effect_col = 1 + break + if not has_effect_col: + self.logger.error("Required headers: at least one of the columns '{}' must be in the file header: {}".format(STD_COLS_EFFECT_FORMATTED, self.header)) + required_is_subset = None + + if not required_is_subset: + self.logger.info("Invalid headers...exiting before any further checks") + self.set_file_is_invalid() + + return required_is_subset + + + def detect_genomebuild_with_rsid(self): + ''' The column "rsID" should always be in the scoring file when the genome build is not reported (i.e. "NR") ''' + self.get_genomebuild() + if self.genomebuild == 'NR': + if SNP_DSET not in self.header: + self.logger.error(f"- The combination: Genome Build = '{self.genomebuild}' & the missing column '{SNP_DSET}' in the header is not allowed as we have to manually guess the genome build.") + self.global_errors += 1 + + + def get_genomebuild(self): + ''' Retrieve the Genome Build from the comments ''' + with gzip.open(self.file, 'rb') as f_in: + for f_line in f_in: + line = f_line.decode() + # Update header + if line.startswith('#genome_build'): + gb = (line.split('='))[1] + self.genomebuild = gb.strip() + return + + +################################################################## + +def init_validator(file, logfile, score_dir=None) -> ValidatorFormatted: + validator = ValidatorFormatted(file=file, score_dir=score_dir, logfile=logfile) + return validator \ No newline at end of file diff --git a/pgscatalog_utils/validate/harmonized_position/__init__.py b/pgscatalog_utils/validate/harmonized_position/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/pgscatalog_utils/validate/harmonized_position/validator.py b/pgscatalog_utils/validate/harmonized_position/validator.py new file mode 100644 index 0000000..87b9346 --- /dev/null +++ b/pgscatalog_utils/validate/harmonized_position/validator.py @@ -0,0 +1,98 @@ +import re +from pgscatalog_utils.validate.schemas import * +from pgscatalog_utils.validate.validator_base import * + +''' +PGS Catalog Harmonized file validator +- using pandas_schema https://github.com/TMiguelT/PandasSchema +''' + +class ValidatorPos(ValidatorBase): + ''' Validator for the HmPOS Harmonized file format. ''' + + def __init__(self, file, score_dir=None, logfile="VALIDATE.log", error_limit=0): + super().__init__(file, score_dir, logfile, error_limit) + self.meta_format = HM_META_POS + self.schema_validators = POS_VALIDATORS + self.valid_cols = VALID_COLS_POS + self.valid_type = VALID_TYPE_POS + self.setup_field_validation() + + + def extract_specific_metadata(self,line): + ''' Extract some of the metadata. 
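validate_headers() above accepts any one of three variant column combinations plus at least one effect column. The same rule condensed into plain set operations (column names copied from common_constants.py; headers_ok is a made-up helper name):

VARIANT_COL_SETS = [
    ('effect_allele', 'chr_name', 'chr_position', 'rsID'),   # full set
    ('effect_allele', 'chr_name', 'chr_position'),           # positions only
    ('effect_allele', 'rsID'),                               # rsIDs only
]
EFFECT_COLS = ('effect_weight', 'OR', 'HR')


def headers_ok(header: list[str]) -> bool:
    has_variant_cols = any(set(cols).issubset(header) for cols in VARIANT_COL_SETS)
    has_effect_col = any(col in header for col in EFFECT_COLS)
    return has_variant_cols and has_effect_col


print(headers_ok(['rsID', 'effect_allele', 'effect_weight']))  # True
print(headers_ok(['rsID', 'effect_allele']))                   # False: no effect column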
''' + match_variants_number = re.search(r'#variants_number=(\d+)', line) + if match_variants_number: + self.variants_number = int(match_variants_number.group(1)) + + + def validate_line_content(self,cols_content,var_line_number): + ''' Populate the abstract method from ValidatorBase, to check some data in esch row. ''' + # Check lines + line_dict = dict(zip(self.header, cols_content)) + line_cols = line_dict.keys() + # Check each chromosome data is consistent + chr_cols = ['chr_name', 'hm_chr', 'hm_match_chr'] + if all(col_name in line_cols for col_name in chr_cols): + if line_dict['chr_name'] == line_dict['hm_chr'] and line_dict['hm_match_chr'] != 'True': + self.logger.error(f"- Variant line {var_line_number} | 'hm_match_chr' should be 'True': same chromosome ('chr_name={line_dict['chr_name']}' vs 'hm_chr={line_dict['hm_chr']}')") + # Check each position data is consistent + pos_cols = ['chr_position', 'hm_pos', 'hm_match_pos'] + if all(col_name in line_cols for col_name in pos_cols): + if line_dict['chr_position'] == line_dict['hm_pos'] and line_dict['hm_match_pos'] != 'True': + self.logger.error(f"- Variant line {var_line_number} | 'hm_match_pos' should be 'True': same position ('chr_position={line_dict['chr_position']}' vs 'hm_pos={line_dict['hm_pos']}')") + + + def validate_filename(self) -> bool: + ''' Validate the file name structure. ''' + self.logger.info("Validating file name...") + pgs_id, build = None, None + is_valid_filename = True + # hmPOS + filename = self.file.split('/')[-1].split('.')[0] + filename_parts = filename.split('_hmPOS_') + if len(filename_parts) != 2: + self.logger.error("Filename: {} should follow the pattern _hmPOS_.txt.gz [build=GRChXX]".format(filename)) + self.set_file_is_invalid() + is_valid_filename = False + else: + pgs_id, build = filename_parts + self.file_pgs_id = pgs_id + self.file_genomebuild = build + if not self.check_build_is_legit(build): + self.logger.error("Build: {} is not an accepted build value".format(build)) + self.set_file_is_invalid() + is_valid_filename = False + + return is_valid_filename + + + def validate_headers(self) -> bool: + ''' Validate the list of column names. 
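ValidatorPos.validate_filename() above splits the file name on '_hmPOS_' and checks the genome build part. A small sketch of that parsing; the extra PGS identifier regex is illustrative and not part of the validator above:

import re

ACCEPTED_BUILDS = ('GRCh37', 'GRCh38')


def parse_hmpos_filename(filename: str):
    stem = filename.split('/')[-1].split('.')[0]   # e.g. PGS000001_hmPOS_GRCh38
    parts = stem.split('_hmPOS_')
    if len(parts) != 2:
        return None
    pgs_id, build = parts
    if not re.fullmatch(r'PGS\d{6}', pgs_id) or build not in ACCEPTED_BUILDS:
        return None
    return pgs_id, build


print(parse_hmpos_filename('PGS000001_hmPOS_GRCh38.txt.gz'))  # ('PGS000001', 'GRCh38')
print(parse_hmpos_filename('PGS000001.txt.gz'))               # None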
''' + self.logger.info("Validating headers...") + # Check if it has at least a "SNP" column or a "chromosome" column + required_is_subset = set(STD_COLS_VAR_POS).issubset(self.header) + if not required_is_subset: + self.logger.error("Required headers: {} are not in the file header: {}".format(STD_COLS_VAR_POS, self.header)) + + # Check if it has at least a "SNP" column or a "chromosome" column + required_pos = set(SNP_COLS_VAR_POS).issubset(self.header) + if not required_pos: + # check if everything but snp: + required_pos = set(CHR_COLS_VAR_POS).issubset(self.header) + if not required_pos: + self.logger.error("One of the following required header is missing: '{}' and/or '{}' are not in the file header: {}".format(SNP_COLS_VAR_POS, CHR_COLS_VAR_POS, self.header)) + required_is_subset = required_pos + + if not required_is_subset: + self.logger.info("Invalid headers...exiting before any further checks") + self.set_file_is_invalid() + + return required_is_subset + + +################################################################## + +def init_validator(file, logfile, score_dir=None) -> ValidatorPos: + validator = ValidatorPos(file=file, score_dir=score_dir, logfile=logfile) + return validator \ No newline at end of file diff --git a/pgscatalog_utils/validate/helpers.py b/pgscatalog_utils/validate/helpers.py new file mode 100644 index 0000000..7d786e5 --- /dev/null +++ b/pgscatalog_utils/validate/helpers.py @@ -0,0 +1,29 @@ +import math +import pandas as pd +from pandas_schema.validation import _SeriesValidation + + +class InInclusiveRangeValidation(_SeriesValidation): + """ + Checks that each element in the series is within a given inclusive numerical range. + Doesn't care if the values are not numeric - it will try anyway. + """ + def __init__(self, min: float = -math.inf, max: float = math.inf, **kwargs): + """ + :param min: The minimum (inclusive) value to accept + :param max: The maximum (inclusive) value to accept + """ + self.min = min + self.max = max + super().__init__(**kwargs) + + @property + def default_message(self): + return 'was not in the range [{}, {})'.format(self.min, self.max) + + def validate(self, series: pd.Series) -> pd.Series: + series = pd.to_numeric(series, errors='coerce') + return (series >= self.min) & (series <= self.max) + + + diff --git a/pgscatalog_utils/validate/schemas.py b/pgscatalog_utils/validate/schemas.py new file mode 100644 index 0000000..43e8e27 --- /dev/null +++ b/pgscatalog_utils/validate/schemas.py @@ -0,0 +1,157 @@ +import numpy as np +from pandas_schema import Column +from pandas_schema.validation import MatchesPatternValidation, InListValidation, CanConvertValidation, LeadingWhitespaceValidation, TrailingWhitespaceValidation, CustomElementValidation +from pgscatalog_utils.validate.helpers import InInclusiveRangeValidation +from pgscatalog_utils.validate.common_constants import * + + +#### Validation types #### + +VALID_TYPE_FORMATTED = 'formatted' +VALID_TYPE_POS = 'hm_pos' + + +#### Columns #### + +# Formatted scoring files +STD_COLS_VAR_FORMATTED = (EFFECT_DSET, CHR_DSET, BP_DSET, SNP_DSET) #OR_DSET, RANGE_L_DSET, RANGE_U_DSET, BETA_DSET, SE_DSET, FREQ_DSET , EFFECT_DSET, OTH_DSET) + +SNP_COLS_VAR_FORMATTED = (EFFECT_DSET, CHR_DSET, BP_DSET) +CHR_COLS_VAR_FORMATTED = (EFFECT_DSET, SNP_DSET) + +STD_COLS_EFFECT_FORMATTED = (EFFECT_WEIGHT_DSET,OR_DSET,HR_DSET) + +VALID_COLS_FORMATTED = (EFFECT_WEIGHT_DSET, OR_DSET, HR_DSET, BETA_DSET, FREQ_DSET, LOCUS_DSET, EFFECT_DSET, OTH_DSET, CHR_DSET, BP_DSET, SNP_DSET) + +# Harmonized scoring files - POS 
+STD_COLS_VAR_POS = (HM_SOURCE_DSET, HM_CHR_DSET, HM_BP_DSET) + +SNP_COLS_VAR_POS = (SNP_DSET, HM_SNP_DSET) +CHR_COLS_VAR_POS = (CHR_DSET,) + +VALID_COLS_POS = (HM_SOURCE_DSET, HM_SNP_DSET, HM_CHR_DSET, HM_BP_DSET, HM_OTH_DSET, HM_MATCH_CHR_DSET, HM_MATCH_BP_DSET) + +# Harmonized scoring files - Final +STD_COLS_VAR_FINAL = (EFFECT_DSET, EFFECT_WEIGHT_DSET, HM_CODE_DSET, HM_INFO_DSET) + +SNP_COLS_VAR_FINAL = (VARIANT_DSET,) +CHR_COLS_VAR_FINAL = (CHR_DSET, HM_CHR_DSET) + +VALID_COLS_FINAL = (SNP_DSET, CHR_DSET, BP_DSET, EFFECT_DSET, OTH_DSET, EFFECT_WEIGHT_DSET, LOCUS_DSET, HM_CODE_DSET, HM_SNP_DSET, HM_CHR_DSET, HM_BP_DSET, HM_OTH_DSET, HM_MATCH_CHR_DSET, HM_MATCH_BP_DSET) + + +#### Global variables #### + +VALID_CHROMOSOMES = ['1', '2', '3', '4', '5', '6', '7', '8', + '9', '10', '11', '12', '13', '14', '15', '16', + '17', '18', '19', '20', '21', '22', + 'X', 'x', 'Y', 'y', 'XY', 'xy', 'MT', 'Mt', 'mt'] + +VALID_FILE_EXTENSIONS = [".txt", ".txt.gz"] + +# For the harmonized files +VALID_SOURCES = ['ENSEMBL','Author-reported'] +# VALID_CODES = ['5','4','3','1','0','-1','-4','-5'] +BUILD_LIST = ['GRCh37','GRCh38'] + + +error_msg = 'this column cannot be null/empty' +null_validation = CustomElementValidation(lambda d: d is not np.nan and d != '', error_msg) + + +#### Validators #### + +# Generic/shared validators +GENERIC_VALIDATORS = { + CHR_DSET: Column(CHR_DSET, [InListValidation(VALID_CHROMOSOMES)], allow_empty=True), + BP_DSET: Column(BP_DSET, [CanConvertValidation(DSET_TYPES[BP_DSET]), InInclusiveRangeValidation(1, 999999999)], allow_empty=True), + EFFECT_WEIGHT_DSET: Column(EFFECT_WEIGHT_DSET, [CanConvertValidation(DSET_TYPES[EFFECT_WEIGHT_DSET]), null_validation], allow_empty=False), + EFFECT_DSET: Column(EFFECT_DSET, [MatchesPatternValidation(r'^[ACTGN\-]+$')], allow_empty=False), + OTH_DSET: Column(OTH_DSET, [MatchesPatternValidation(r'^[ACTGN\-]+$')], allow_empty=True), + LOCUS_DSET: Column(LOCUS_DSET, [CanConvertValidation(DSET_TYPES[LOCUS_DSET]), LeadingWhitespaceValidation(), TrailingWhitespaceValidation(), null_validation], allow_empty=True) +} + +# Formatted validators +FORMATTED_VALIDATORS = {k:v for k,v in GENERIC_VALIDATORS.items()} +FORMATTED_VALIDATORS[SNP_DSET] = Column(SNP_DSET, [CanConvertValidation(DSET_TYPES[SNP_DSET]), MatchesPatternValidation(r'^(rs|HLA\-\w+\*)[0-9]+$')], allow_empty=True) +FORMATTED_VALIDATORS[OR_DSET] = Column(OR_DSET, [CanConvertValidation(DSET_TYPES[OR_DSET]), null_validation], allow_empty=True) +FORMATTED_VALIDATORS[HR_DSET] = Column(HR_DSET, [CanConvertValidation(DSET_TYPES[HR_DSET]), null_validation], allow_empty=True) +FORMATTED_VALIDATORS[BETA_DSET] = Column(BETA_DSET, [CanConvertValidation(DSET_TYPES[BETA_DSET]), null_validation], allow_empty=True) +FORMATTED_VALIDATORS[FREQ_DSET] = Column(FREQ_DSET, [CanConvertValidation(DSET_TYPES[FREQ_DSET]), null_validation], allow_empty=True) +FORMATTED_VALIDATORS[DOSAGE_0_WEIGHT] = Column(DOSAGE_0_WEIGHT, [CanConvertValidation(DSET_TYPES[DOSAGE_0_WEIGHT]), null_validation], allow_empty=True) +FORMATTED_VALIDATORS[DOSAGE_1_WEIGHT] = Column(DOSAGE_1_WEIGHT, [CanConvertValidation(DSET_TYPES[DOSAGE_1_WEIGHT]), null_validation], allow_empty=True) +FORMATTED_VALIDATORS[DOSAGE_2_WEIGHT] = Column(DOSAGE_2_WEIGHT, [CanConvertValidation(DSET_TYPES[DOSAGE_2_WEIGHT]), null_validation], allow_empty=True) + +FORMATTED_VALIDATORS_SNP = {k:v for k,v in FORMATTED_VALIDATORS.items()} +FORMATTED_VALIDATORS_SNP[SNP_DSET] = Column(SNP_DSET, [CanConvertValidation(DSET_TYPES[SNP_DSET]), 
MatchesPatternValidation(r'^(rs|HLA\-\w+\*)[0-9]+$')], allow_empty=False) + +FORMATTED_VALIDATORS_SNP_EMPTY = {k:v for k,v in FORMATTED_VALIDATORS.items()} +FORMATTED_VALIDATORS_SNP_EMPTY[SNP_DSET] = Column(SNP_DSET, [CanConvertValidation(DSET_TYPES[SNP_DSET]), MatchesPatternValidation(r'^(rs[0-9]+|HLA\-\w+\*[0-9]+|nan)$')], allow_empty=False) +FORMATTED_VALIDATORS_SNP_EMPTY[CHR_DSET] = Column(CHR_DSET, [InListValidation(VALID_CHROMOSOMES)], allow_empty=False) +FORMATTED_VALIDATORS_SNP_EMPTY[BP_DSET] = Column(BP_DSET, [CanConvertValidation(DSET_TYPES[BP_DSET]), InInclusiveRangeValidation(1, 999999999)], allow_empty=False) + +FORMATTED_VALIDATORS_POS = {k:v for k,v in FORMATTED_VALIDATORS.items()} +FORMATTED_VALIDATORS_POS[CHR_DSET] = Column(CHR_DSET, [InListValidation(VALID_CHROMOSOMES)], allow_empty=False) +FORMATTED_VALIDATORS_POS[BP_DSET] = Column(BP_DSET, [CanConvertValidation(DSET_TYPES[BP_DSET]), InInclusiveRangeValidation(1, 999999999)], allow_empty=False) + +FORMATTED_VALIDATORS_OR = {k:v for k,v in FORMATTED_VALIDATORS.items()} +FORMATTED_VALIDATORS_OR[OR_DSET] = Column(OR_DSET, [CanConvertValidation(DSET_TYPES[OR_DSET])], allow_empty=False) + +FORMATTED_VALIDATORS_HR = {k:v for k,v in FORMATTED_VALIDATORS.items()} +FORMATTED_VALIDATORS_HR[HR_DSET] = Column(HR_DSET, [CanConvertValidation(DSET_TYPES[HR_DSET])], allow_empty=False) + +# Position validators +POS_VALIDATORS = {} +POS_VALIDATORS[HR_DSET] = Column(HR_DSET, [CanConvertValidation(DSET_TYPES[HR_DSET]), null_validation], allow_empty=True) +POS_VALIDATORS[HM_SOURCE_DSET] = Column(HM_SOURCE_DSET, [CanConvertValidation(DSET_TYPES[HM_SOURCE_DSET]), InListValidation(VALID_SOURCES), LeadingWhitespaceValidation(), TrailingWhitespaceValidation(), null_validation], allow_empty=False) +POS_VALIDATORS[HM_SNP_DSET] = Column(HM_SNP_DSET, [CanConvertValidation(DSET_TYPES[HM_SNP_DSET]), MatchesPatternValidation(r'^(rs|HLA\-\w+\*)[0-9]+$')], allow_empty=True) +POS_VALIDATORS[HM_CHR_DSET] = Column(HM_CHR_DSET, [InListValidation(VALID_CHROMOSOMES)], allow_empty=True) +POS_VALIDATORS[HM_BP_DSET] = Column(HM_BP_DSET, [CanConvertValidation(DSET_TYPES[HM_BP_DSET]), InInclusiveRangeValidation(1, 999999999)], allow_empty=True) +POS_VALIDATORS[HM_OTH_DSET] = Column(HM_OTH_DSET, [MatchesPatternValidation(r'^[ACTGN\-\/]+$')], allow_empty=True) +POS_VALIDATORS[HM_MATCH_CHR_DSET] = Column(HM_MATCH_CHR_DSET, [InListValidation(['True', 'False'])], allow_empty=True) +POS_VALIDATORS[HM_MATCH_BP_DSET] = Column(HM_MATCH_BP_DSET, [InListValidation(['True', 'False'])], allow_empty=True) + +# Final validator +# FINAL_VALIDATORS = {k:v for k,v in GENERIC_VALIDATORS.items()} +# FINAL_VALIDATORS[EFFECT_DSET] = Column(EFFECT_DSET, [MatchesPatternValidation(r'^[ACTGN\-]+$')], allow_empty=True) +# FINAL_VALIDATORS[OTH_DSET] = Column(OTH_DSET, [MatchesPatternValidation(r'^[ACTGN\-\.]+$')], allow_empty=True) +# FINAL_VALIDATORS[VARIANT_DSET] = Column(VARIANT_DSET, [CanConvertValidation(DSET_TYPES[VARIANT_DSET]), MatchesPatternValidation(r'^((rs|HLA\-\w+\*)[0-9]+|\.)$')], allow_empty=True) +# FINAL_VALIDATORS[HM_CODE_DSET] = Column(HM_CODE_DSET, [InListValidation(VALID_CODES), null_validation], allow_empty=True) +# FINAL_VALIDATORS[HM_INFO_DSET] = Column(HM_INFO_DSET, [CanConvertValidation(DSET_TYPES[HM_INFO_DSET]), null_validation], allow_empty=True) + + +#### Metadata entries #### + +FORMATTED_META_GENERIC = [ + '###PGS CATALOG SCORING FILE', + '#format_version', + '##POLYGENIC SCORE', + '#pgs_id', + '#pgs_name', + '#trait_reported', + '#trait_mapped', + '#trait_efo', 
+ '#genome_build', + '#variants_number', + '#weight_type', + '##SOURCE INFORMATION', + '#pgp_id', + '#citation' +] + +HM_META_GENERIC = [ x for x in FORMATTED_META_GENERIC ] +HM_META_GENERIC.append('##HARMONIZATION DETAILS') + +HM_META_POS = [ x for x in HM_META_GENERIC ] +HM_META_POS.append('#HmPOS_build') +HM_META_POS.append('#HmPOS_date') +HM_META_POS.append('#HmPOS_match_chr') +HM_META_POS.append('#HmPOS_match_pos') + +# HM_META_FINAL = [ x for x in HM_META_GENERIC ] +# HM_META_FINAL.append('#Hm_file_version') +# HM_META_FINAL.append('#Hm_genome_build') +# HM_META_FINAL.append('#Hm_reference_source') +# HM_META_FINAL.append('#Hm_creation_date') +# HM_META_FINAL.append('#Hm_variants_number_matched') +# HM_META_FINAL.append('#Hm_variants_number_unmapped') \ No newline at end of file diff --git a/pgscatalog_utils/validate/validate_scorefile.py b/pgscatalog_utils/validate/validate_scorefile.py new file mode 100644 index 0000000..80294c3 --- /dev/null +++ b/pgscatalog_utils/validate/validate_scorefile.py @@ -0,0 +1,171 @@ +import os, glob, re +import argparse +import logging +import textwrap + +data_sum = {'valid': [], 'invalid': [], 'other': []} + +val_types = ('formatted', 'hm_pos') + +logging.basicConfig(level=logging.INFO, format='(%(levelname)s): %(message)s') + + +def validate_scorefile() -> None: + global data_sum, score_dir + args = _parse_args() + _check_args(args) + + # Check PGS Catalog file name nomenclature + check_filename = False + if args.check_filename: + check_filename = True + else: + print("WARNING: the parameter '--check_filename' is not present in the submitted command line, therefore the validation of the scoring file name(s) won't be performed.") + + validator_type = args.t + files_dir = args.dir + log_dir = args.log_dir + + ## Select validator class ## + if validator_type == 'formatted': + import pgscatalog_utils.validate.formatted.validator as validator_package + elif validator_type == 'hm_pos': + import pgscatalog_utils.validate.harmonized_position.validator as validator_package + + ## Run validator ## + # One file + if args.f: + _run_validator(args.f,log_dir,score_dir,validator_package,check_filename,validator_type) + # Content of the directory + elif files_dir: + count_files = 0 + # Browse directory: for each file run validator + for filepath in sorted(glob.glob(files_dir+"/*.*")): + _run_validator(filepath,log_dir,score_dir,validator_package,check_filename,validator_type) + count_files += 1 + + # Print summary + results + print("\nSummary:") + if data_sum['valid']: + print(f"- Valid: {len(data_sum['valid'])}/{count_files}") + if data_sum['invalid']: + print(f"- Invalid: {len(data_sum['invalid'])}/{count_files}") + if data_sum['other']: + print(f"- Other issues: {len(data_sum['other'])}/{count_files}") + + if data_sum['invalid']: + print("Invalid files:") + print("\n".join(data_sum['invalid'])) + + +def _read_last_line(file: str) -> str: + ''' + Return the last line of the file + ''' + fileHandle = open ( file,"r" ) + lineList = fileHandle.readlines() + fileHandle.close() + return lineList[-1] + + +def _file_validation_state(filename: str, log_file: str) -> None: + global data_sum + if os.path.exists(log_file): + log_result = _read_last_line(log_file) + if re.search("File is valid", log_result): + print("> valid\n") + data_sum['valid'].append(filename) + elif re.search("File is invalid", log_result): + print("#### invalid! ####\n") + data_sum['invalid'].append(filename) + else:# + print("!! validation process had an issue. 
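_read_last_line() above loads every line of the log just to return the final one, which is fine for small validation logs. A lighter-weight alternative sketch (an option, not what the module above does) that reads only the file's tail:

import os


def read_last_line(path: str) -> str:
    # Scan backwards in 4 KiB steps until the buffer holds the final line
    with open(path, 'rb') as f:
        f.seek(0, os.SEEK_END)
        pos = f.tell()
        buf = b''
        while pos > 0 and buf.count(b'\n') < 2:
            step = min(4096, pos)
            pos -= step
            f.seek(pos)
            buf = f.read(step) + buf
        lines = buf.splitlines()
        return lines[-1].decode() if lines else ''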
Please look at the logs.\n") + data_sum['other'].append(filename) + else: + print("!! validation process had an issue: the log file can't be found") + data_sum['other'].append(filename) + + +def _check_args(args: argparse.Namespace) -> None: + global score_dir + + ## Check parameters ## + # Type of validator + if args.t not in val_types: + print(f"Error: Validator type (option -t) '{args.t}' is not in the list of recognized types: {val_types}.") + exit(1) + # Logs dir + if not os.path.isdir(args.log_dir): + print(f"Error: Log dir '{args.log_dir}' can't be found!") + exit(1) + # File and directory parameters (only one of the '-f' and '--dir' can be used) + if args.f and args.dir: + print("Error: you can't use both options [-f] - single scoring file and [--dir] - directory of scoring files. Please use only 1 of these 2 options!") + exit(1) + elif not args.f and not args.dir: + print("Error: you need to provide a scoring file [-f] or a directory of scoring files [--dir]!") + exit(1) + elif args.f and not os.path.isfile(args.f): + print(f"Error: Scoring file '{args.f}' can't be found!") + exit(1) + elif args.dir and not os.path.isdir(args.dir): + print(f"Error: the scoring file directory '{args.dir}' can't be found!") + exit(1) + # Scoring files directory (only to compare with the harmonized files) + score_dir = None + if args.score_dir: + score_dir = args.score_dir + if not os.path.isdir(score_dir): + print(f"Error: Scoring file directory '{score_dir}' can't be found!") + exit(1) + elif args.t != 'formatted': + print("WARNING: the parameter '--score_dir' is not present in the submitted command line, therefore the comparison of the number of data rows between the formatted scoring file(s) and the harmonized scoring file(s) won't be performed.") + + +def _run_validator(filepath: str, log_dir: str, score_dir: str, validator_package: object, check_filename: bool, validator_type: str) -> None: + ''' Run the file validator ''' + file = os.path.basename(filepath) + filename = file.split('.')[0] + print(f"# Filename: {file}") + log_file = f'{log_dir}/{filename}_log.txt' + + # Run validator + validator = validator_package.init_validator(filepath,log_file,score_dir) + if check_filename: + validator.run_validator() + else: + validator.run_validator_skip_check_filename() + + # Check log + _file_validation_state(file,log_file) + + +def _description_text() -> str: + return textwrap.dedent('''\ + Validate a set of scoring files to match the PGS Catalog scoring file formats. + It can validate: + - The formatted scoring file format (https://www.pgscatalog.org/downloads/#dl_ftp_scoring) + - The harmonized (Position) scoring file format (https://www.pgscatalog.org/downloads/#dl_ftp_scoring_hm_pos) + ''') + + +def _epilog_text() -> str: + return textwrap.dedent(f'''\ + You need to specify the type of file format to validate, using the paramter '-t' ({' or '.join(val_types)}). 
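Most of _check_args() above hand-rolls the rule that exactly one of -f and --dir must be supplied; argparse can enforce that directly with a mutually exclusive group. A sketch under the assumption that the same flags are kept (an alternative, not the module's current behaviour):

import argparse


def parse_args(args=None) -> argparse.Namespace:
    parser = argparse.ArgumentParser(description='Validate PGS Catalog scoring files')
    target = parser.add_mutually_exclusive_group(required=True)
    target.add_argument('-f', metavar='SCORING_FILE_NAME',
                        help='A single scoring file to validate')
    target.add_argument('--dir', help='A directory of scoring files to validate')
    parser.add_argument('--log_dir', required=True,
                        help='Directory where the log files are written')
    return parser.parse_args(args)


# argparse now rejects "both -f and --dir" as well as "neither" on its own:
print(parse_args(['-f', 'PGS000001.txt.gz', '--log_dir', 'logs']))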
+ ''') + + +def _parse_args(args=None) -> argparse.Namespace: + parser = argparse.ArgumentParser(description=_description_text(), epilog=_epilog_text(), + formatter_class=argparse.RawDescriptionHelpFormatter) + parser.add_argument("-t", help=f"Type of validator: {' or '.join(val_types)}", metavar='VALIDATOR_TYPE') + parser.add_argument("-f", help='The path to the polygenic scoring file to be validated (no need to use the [--dir] option)', metavar='SCORING_FILE_NAME') + parser.add_argument('--dir', help='The name of the directory containing the files that need to processed (no need to use the [-f] option') + parser.add_argument('--score_dir', help=' The name of the directory containing the formatted scoring files to compare with harmonized scoring files') + parser.add_argument('--log_dir', help='The name of the log directory where the log file(s) will be stored', required=True) + parser.add_argument('--check_filename', help=' Check that the file name match the PGS Catalog nomenclature', required=False, action='store_true') + return parser.parse_args(args) + + +if __name__ == '__main__': + validate_scorefile() diff --git a/pgscatalog_utils/validate/validator_base.py b/pgscatalog_utils/validate/validator_base.py new file mode 100644 index 0000000..ddfbc59 --- /dev/null +++ b/pgscatalog_utils/validate/validator_base.py @@ -0,0 +1,429 @@ +import os, sys, gc +import gzip +import csv +import pathlib +import logging +import re +from typing import List +import pandas as pd +import pandas_schema +import warnings +from pgscatalog_utils.validate.schemas import * + +''' +PGS Catalog file validator +- using pandas_schema https://github.com/TMiguelT/PandasSchema +''' + + +csv.field_size_limit(sys.maxsize) + +class ValidatorBase: + + valid_extensions = VALID_FILE_EXTENSIONS + schema_validators = GENERIC_VALIDATORS + valid_cols = [] + valid_type = '' + sep = '\t' + + def __init__(self, file, score_dir=None, logfile="VALIDATE.log", error_limit=0): + self.file = file + self.score_dir = score_dir + self.schema = None + self.header = [] + self.genomebuild = None + self.comment_lines_count = 1 # Counting the header line + self.cols_to_validate = [] + self.cols_to_read = [] + self.bad_rows = [] + self.row_errors = [] + self.errors_seen = {} + self.logfile = logfile + self.error_limit = int(error_limit) + self.is_valid = True + + # Logging variables + self.logger = logging.getLogger(__name__) + self.handler = logging.FileHandler(self.logfile, 'w+') + self.handler.setLevel(logging.INFO) + self.logger.addHandler(self.handler) + self.logger.propagate = False + + self.global_errors = 0 + self.variants_number = 0 + + + def validate_schema(self, schema: dict, dataframe_to_validate: pd.core.frame.DataFrame): + ''' + Run the pandas_schema validation using the provided Schema and DataFrame + ''' + self.schema = pandas_schema.Schema([schema[h] for h in self.cols_to_validate]) + with warnings.catch_warnings(): + # Ignore python warningd raised in the pandas_schema code + warnings.simplefilter('ignore', UserWarning) + errors = self.schema.validate(dataframe_to_validate) + self.store_errors(errors) + + + def setup_field_validation(self): + ''' + Fetch the header and build the list of column to check/validate + ''' + self.header = self.get_header() + self.cols_to_validate = [h for h in self.header if h in self.valid_cols] + self.cols_to_read = [h for h in self.header if h in self.valid_cols] + + + def get_header(self): + ''' + Fetch the header (i.e. 
column names) information from the harmonized scoring file and store the list in a variable + ''' + first_row = pd.read_csv(self.file, sep=self.sep, comment='#', nrows=1, index_col=False) + # Check if the column headers have leading and/or trailing spaces + # The leading/trailing spaces should raise an error during the header validation + has_trailing_spaces = self.check_leading_trailing_spaces(first_row.columns.values) + if has_trailing_spaces: + self.global_errors += 1 + return first_row.columns.values + + + def get_genomebuild(self): + ''' Retrieve the Genome Build from the comments ''' + if self.valid_type == 'hm_pos': + self.genomebuild = self.get_comments_info('#HmPOS_build') + else: + self.genomebuild = self.get_comments_info('#Hm_genome_build') + + + def get_pgs_id(self): + ''' Retrieve the PGS ID from the comments ''' + self.pgs_id = self.get_comments_info('#pgs_id') + + + def validate_content(self): + ''' Validate the file content and verify that the number of variant lines corresponds to the number of variants in the headers ''' + variant_lines_count = 0 + meta_lines_count = 0 + + with gzip.open( self.file, 'rb') as f: + line_number = 0 + file_meta = [] + for line in f: + line_number += 1 + line = line.decode('utf-8').rstrip() + # Check Metadata + if line.startswith('#'): + self.extract_specific_metadata(line) + # Check that we have all the meta information + for meta in self.meta_format: + if line.startswith(meta): + file_meta.append(meta) + meta_lines_count += 1 + break + + # Check data + else: + variant_lines_count += 1 + if re.search(r'\w+', line): # Line not empty + cols_content = line.split(self.sep) + has_trailing_spaces = self.check_leading_trailing_spaces(cols_content,line_number) + if has_trailing_spaces: + self.global_errors += 1 + + if line.startswith('rsID') or line.startswith('chr_name'): + continue + + self.validate_line_content(cols_content,variant_lines_count) + else: + self.logger.error(f'- Line {line_number} is empty') + self.global_errors += 1 + + # Compare the number of metadata lines: read vs expected + if meta_lines_count != len(self.meta_format): + self.logger.error(f'- The number of metadata lines [i.e. starting with the "#" character] in the file ({meta_lines_count}) and the expected number of metadata lines ({len(self.meta_format)}) are different') + diff_list = list(set(self.meta_format).difference(file_meta)) + self.logger.error(f" > Missing metadata line(s): {', '.join(diff_list)}") + self.global_errors += 1 + + + def validate_data(self) -> bool: + ''' Validate the file: data format and data content ''' + self.logger.info("Validating data...") + if not self.open_file_and_check_for_squareness(): + self.logger.error("Please fix the table. 
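validate_content() above ticks off each expected metadata key and reports anything missing via a set difference. The same bookkeeping in isolation (toy header block and a shortened key list):

EXPECTED_META = ['###PGS CATALOG SCORING FILE', '#format_version', '#pgs_id',
                 '#genome_build', '#variants_number']  # shortened for the example

header_block = [
    '###PGS CATALOG SCORING FILE - toy example',
    '#format_version=2.0',
    '#pgs_id=PGS000001',
    '#variants_number=77',
]

found = [meta for meta in EXPECTED_META
         if any(line.startswith(meta) for line in header_block)]
missing = set(EXPECTED_META).difference(found)
print(f"Missing metadata line(s): {', '.join(sorted(missing))}")  # -> #genome_build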
Some rows have different numbers of columns to the header") + self.logger.info("Rows with different numbers of columns to the header are not validated") + + # Validate data content and check the consitence between the declared variants number and the actual number of variants in the file + self.validate_content() + for chunk in self.df_iterator(self.file): + dataframe_to_validate = chunk[self.cols_to_read] + dataframe_to_validate.columns = self.cols_to_validate # sets the headers to standard format if neeeded + + # Schema validation + self.validate_schema(self.schema_validators,dataframe_to_validate) + + self.process_errors() + if len(self.bad_rows) >= self.error_limit: + break + + if not self.bad_rows and not self.global_errors and self.is_valid: + self.logger.info("File is valid") + else: + self.logger.info("File is invalid - {} bad rows, limit set to {}".format(len(self.bad_rows), self.error_limit)) + self.set_file_is_invalid() + return self.is_valid + + + def is_file_valid(self) -> bool: + ''' Method returning the boolean value: True if the file is valid, False if the file is invalid. ''' + return self.is_valid + + def set_file_is_invalid(self): + ''' Set the flag "is_valid" to False. ''' + self.is_valid = False + + + def process_errors(self): + ''' Populate the logger error and the list of bad rows with the errors found. ''' + for error in self.row_errors: + if len(self.bad_rows) < self.error_limit or self.error_limit < 1: + self.logger.error(error) + if error.row not in self.bad_rows: + self.bad_rows.append(error.row) + self.row_errors = [] + + + def store_errors(self, errors: List[pandas_schema.validation_warning.ValidationWarning]): + ''' Capture the errors found into a temporary structure before being processed. ''' + for error in errors: + seen = 0 + row_number = error.row + file_line_number = row_number + self.comment_lines_count + 1 # rows are 0 indexes + error.row = str(row_number) + " (line "+str(file_line_number)+")" + col = error.column + # Avoid duplication as the errors can be detected several times + if row_number in self.errors_seen.keys(): + if col in self.errors_seen[row_number].keys(): + seen = 1 + else: + self.errors_seen[row_number][col] = 1 + else: + self.errors_seen[row_number] = { col : 1 } + if seen == 0: + self.row_errors.append(error) + + + def validate_file_extension(self): + ''' Check/validate the file name extension. ''' + self.logger.info("Validating file extension...") + check_exts = [self.check_ext(ext) for ext in self.valid_extensions] + if not any(check_exts): + self.valid_ext = False + self.set_file_is_invalid() + self.logger.info("Invalid file extension: {}".format(self.file)) + self.logger.error("File extension should be in {}".format(self.valid_extensions)) + else: + self.valid_ext = True + return self.valid_ext + + + def compare_number_of_rows(self): + ''' Compare the number of data rows between the harmonized and the formatted scoring files. 
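The chunked validation above builds a pandas_schema Schema from the Column validators that match the file header and applies it to each chunk. A self-contained example of that API on toy data (two columns only; the validators mirror the ones defined in schemas.py):

import pandas as pd
from pandas_schema import Column, Schema
from pandas_schema.validation import CanConvertValidation, MatchesPatternValidation

schema = Schema([
    Column('effect_allele', [MatchesPatternValidation(r'^[ACTGN\-]+$')]),
    Column('effect_weight', [CanConvertValidation(float)]),
])

df = pd.DataFrame({'effect_allele': ['A', 'Z'],
                   'effect_weight': ['0.5', 'not_a_number']})

for error in schema.validate(df):
    # each ValidationWarning carries the offending row and column,
    # which store_errors() above uses for de-duplication
    print(error.row, error.column, error)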
''' + # Harmonization file - length + hm_rows_count = 0 + for chunk in self.df_iterator(self.file): + hm_rows_count += len(chunk.index) + gc.collect() + + # Formatted scoring file - length + scoring_rows_count = 0 + scoring_file = f'{self.score_dir}/{self.pgs_id}.txt.gz' + if os.path.isfile(scoring_file): + for score_chunk in self.df_iterator(scoring_file): + scoring_rows_count += len(score_chunk.index) + gc.collect() + + comparison_status = True + if scoring_rows_count == 0: + self.logger.error(f"Can't find the Scoring file '{scoring_file}' to compare the number of rows with the harmonization file!") + comparison_status = False + elif hm_rows_count != scoring_rows_count: + self.logger.error(f'The number of data rows between the Scoring file ({scoring_rows_count}) and the Harmonization POS file ({hm_rows_count}) are different') + comparison_status = False + return comparison_status + + + def compare_with_filename(self): + ''' Check that the filename matches the information present in the file metadata (PGS ID, genome build). ''' + self.logger.info("Comparing filename with metadata...") + comparison_status = True + if hasattr(self,'file_genomebuild') and hasattr(self,'file_pgs_id'): + # Extract some metadata + self.get_genomebuild() + self.get_pgs_id() + # Compare metadata with filename information + if self.file_genomebuild != self.genomebuild: + self.logger.error("Build: the genome build in the HmPOS_build header ({}) is different from the one on the filename ({})".format(self.genomebuild,self.file_genomebuild)) + comparison_status = False + if self.file_pgs_id != self.pgs_id: + self.logger.error("ID: the PGS ID of the header ({}) is different from the one on the filename ({})".format(self.pgs_id,self.file_pgs_id)) + comparison_status = False + # Compare number of rows with Scoring file + if self.score_dir: + row_comparison_status = self.compare_number_of_rows() + if row_comparison_status == False: + comparison_status = row_comparison_status + else: + self.logger.info("Comparison of the number of rows between Harmonized and Scoring file skipped!") + if not comparison_status: + self.logger.info("Discrepancies between filename information and metadata: {}".format(self.file)) + self.set_file_is_invalid() + return comparison_status + + + def df_iterator(self, data_file: str): + ''' Setup a pandas dataframe iterator. ''' + df = pd.read_csv(data_file, + sep=self.sep, + dtype=str, + comment='#', + chunksize=1000000) + return df + + + def check_file_is_square(self, csv_file: str): + ''' Check that each row has the name number of columns. ''' + square = True + csv_file.seek(0) + reader = csv.reader(csv_file, delimiter=self.sep) + count = 1 + for row in reader: + if len(row) != 0: + if row[0].startswith('#'): + self.comment_lines_count += 1 + continue + if (len(row) != len(self.header)): + self.logger.error("Length of row {c} is: {l} instead of {h}".format(c=count, l=str(len(row)), h=str(len(self.header)))) + self.logger.error("ROW: "+str(row)) + square = False + count += 1 + del csv_file + return square + + + def open_file_and_check_for_squareness(self): + ''' Method to read the file in order to check that each row has the name number of columns. 
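check_file_is_square() above flags any row whose column count differs from the header before the chunked validation runs. A compact sketch of the same check using the csv module (inline toy TSV):

import csv
import io


def is_square(handle, sep='\t') -> bool:
    reader = csv.reader(handle, delimiter=sep)
    rows = [row for row in reader if row and not row[0].startswith('#')]
    header, *data = rows
    return all(len(row) == len(header) for row in data)


tsv = ("#pgs_id=PGS000001\n"
       "rsID\teffect_allele\teffect_weight\n"
       "rs1\tA\t0.5\n"
       "rs2\tC\n")
print(is_square(io.StringIO(tsv)))  # False: the last row is missing a column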
''' + if pathlib.Path(self.file).suffix in [".gz", ".gzip"]: + with gzip.open(self.file, 'rt') as f: + return self.check_file_is_square(f) + else: + with open(self.file) as f: + return self.check_file_is_square(f) + + + def check_leading_trailing_spaces(self, cols:str, line_number:str = None): + ''' + Check if the columns have leading and/or trailing spaces. + The leading/trailing spaces should raise an error during the validation. + ''' + leading_trailing_spaces = [] + found_trailing_spaces = False + for idx, col in enumerate(cols): + if col.startswith(' ') or col.endswith(' '): + leading_trailing_spaces.append(self.header[idx]+' => |'+str(col)+'|') + if len(leading_trailing_spaces): + if line_number: + line_name = f'line {line_number} has' + else: + line_name = 'following headers have' + self.logger.error("The "+line_name+" leading and/or trailing spaces: "+' ; '.join(leading_trailing_spaces)) + found_trailing_spaces = True + return found_trailing_spaces + + + def check_ext(self, ext:str) -> bool: + if self.file.endswith(ext): + return True + return False + + + def check_build_is_legit(self, build:str) -> bool: + if build in BUILD_LIST: + return True + return False + + + def get_comments_info(self, type:str) -> str: + ''' Retrieve information from the comments ''' + with gzip.open(self.file, 'rb') as f_in: + for f_line in f_in: + line = f_line.decode() + # Update header + if line.startswith(type): + info = (line.split('='))[1] + return info.strip() + + def run_generic_validator(self,check_filename): + self.logger.propagate = False + + # Check files exist + if not self.file or not self.logfile: + self.logger.info("Missing file and/or logfile") + self.set_file_is_invalid() + elif self.file and not os.path.exists(self.file): + self.logger.info("Error: the file '"+self.file+"' can't be found") + self.set_file_is_invalid() + + # Validate file extension + self.validate_file_extension() + + # Validate file name nomenclature + if self.is_file_valid() and check_filename: + self.validate_filename() + + # Only for harmonized files + if self.is_file_valid() and type(self).__name__ != 'ValidatorFormatted': + self.compare_with_filename() + + # Validate column headers + if self.is_file_valid(): + self.validate_headers() + + # Validate data content + if self.is_file_valid(): + self.validate_data() + + # Close log handler + self.logger.removeHandler(self.handler) + self.handler.close() + + def run_validator(self): + self.run_generic_validator(True) + + def run_validator_skip_check_filename(self): + self.run_generic_validator(False) + + + def validate_filename(self): + ''' Validate the file name structure. ''' + print("To be implemented in inherited classes") + pass + + + def validate_headers(self): + ''' Validate the list of column names. ''' + print("To be implemented in inherited classes") + pass + + + def validate_line_content(self, cols_content:str, var_line_number:int): + ''' Validate each data row. ''' + print("To be implemented in inherited classes") + pass + + + def extract_specific_metadata(self, line:str): + ''' Extra method to extract and validate specific data. ''' + print("To be implemented in inherited classes") + pass + diff --git a/poetry.lock b/poetry.lock index d776774..2ae26df 100644 --- a/poetry.lock +++ b/poetry.lock @@ -1,11 +1,3 @@ -[[package]] -name = "atomicwrites" -version = "1.4.1" -description = "Atomic file writes." 
-category = "dev" -optional = false -python-versions = ">=2.7, !=3.0.*, !=3.1.*, !=3.2.*, !=3.3.*" - [[package]] name = "attrs" version = "22.1.0" @@ -22,15 +14,26 @@ tests_no_zope = ["coverage[toml] (>=5.0.2)", "hypothesis", "pympler", "pytest (> [[package]] name = "certifi" -version = "2022.6.15" +version = "2022.9.24" description = "Python package for providing Mozilla's CA Bundle." category = "main" optional = false python-versions = ">=3.6" +[[package]] +name = "cffi" +version = "1.15.1" +description = "Foreign Function Interface for Python calling C code." +category = "main" +optional = false +python-versions = "*" + +[package.dependencies] +pycparser = "*" + [[package]] name = "charset-normalizer" -version = "2.1.0" +version = "2.1.1" description = "The Real First Universal Charset Detector. Open, modern and actively maintained alternative to Chardet." category = "main" optional = false @@ -47,9 +50,27 @@ category = "dev" optional = false python-versions = ">=2.7, !=3.0.*, !=3.1.*, !=3.2.*, !=3.3.*, !=3.4.*" +[[package]] +name = "contourpy" +version = "1.0.5" +description = "Python library for calculating contours of 2D quadrilateral grids" +category = "dev" +optional = false +python-versions = ">=3.7" + +[package.dependencies] +numpy = ">=1.16" + +[package.extras] +test-no-codebase = ["pillow", "matplotlib", "pytest"] +test-minimal = ["pytest"] +test = ["isort", "flake8", "pillow", "matplotlib", "pytest"] +docs = ["sphinx-rtd-theme", "sphinx", "docutils (<0.18)"] +bokeh = ["selenium", "bokeh"] + [[package]] name = "coverage" -version = "6.4.4" +version = "6.5.0" description = "Code coverage measurement for Python" category = "dev" optional = false @@ -61,9 +82,39 @@ tomli = {version = "*", optional = true, markers = "python_full_version <= \"3.1 [package.extras] toml = ["tomli"] +[[package]] +name = "cycler" +version = "0.11.0" +description = "Composable style cycles" +category = "dev" +optional = false +python-versions = ">=3.6" + +[[package]] +name = "fonttools" +version = "4.37.4" +description = "Tools to manipulate font files" +category = "dev" +optional = false +python-versions = ">=3.7" + +[package.extras] +all = ["fs (>=2.2.0,<3)", "lxml (>=4.0,<5)", "zopfli (>=0.1.4)", "lz4 (>=1.7.4.2)", "matplotlib", "sympy", "skia-pathops (>=0.5.0)", "uharfbuzz (>=0.23.0)", "brotlicffi (>=0.8.0)", "scipy", "brotli (>=1.0.1)", "munkres", "unicodedata2 (>=14.0.0)", "xattr"] +graphite = ["lz4 (>=1.7.4.2)"] +interpolatable = ["scipy", "munkres"] +lxml = ["lxml (>=4.0,<5)"] +pathops = ["skia-pathops (>=0.5.0)"] +plot = ["matplotlib"] +repacker = ["uharfbuzz (>=0.23.0)"] +symfont = ["sympy"] +type1 = ["xattr"] +ufo = ["fs (>=2.2.0,<3)"] +unicode = ["unicodedata2 (>=14.0.0)"] +woff = ["zopfli (>=0.1.4)", "brotlicffi (>=0.8.0)", "brotli (>=1.0.1)"] + [[package]] name = "idna" -version = "3.3" +version = "3.4" description = "Internationalized Domain Names in Applications (IDNA)" category = "main" optional = false @@ -79,15 +130,54 @@ python-versions = "*" [[package]] name = "jq" -version = "1.2.2" +version = "1.3.0" description = "jq is a lightweight and flexible JSON processor." 
category = "main" optional = false python-versions = ">=3.5" +[[package]] +name = "kiwisolver" +version = "1.4.4" +description = "A fast implementation of the Cassowary constraint solver" +category = "dev" +optional = false +python-versions = ">=3.7" + +[[package]] +name = "matplotlib" +version = "3.6.0" +description = "Python plotting package" +category = "dev" +optional = false +python-versions = ">=3.8" + +[package.dependencies] +contourpy = ">=1.0.1" +cycler = ">=0.10" +fonttools = ">=4.22.0" +kiwisolver = ">=1.0.1" +numpy = ">=1.19" +packaging = ">=20.0" +pillow = ">=6.2.0" +pyparsing = ">=2.2.1" +python-dateutil = ">=2.7" +setuptools_scm = ">=7" + +[[package]] +name = "memory-profiler" +version = "0.60.0" +description = "A module for monitoring memory usage of a python program" +category = "dev" +optional = false +python-versions = ">=3.4" + +[package.dependencies] +psutil = "*" + [[package]] name = "numpy" -version = "1.23.1" +version = "1.23.3" description = "NumPy is the fundamental package for array computing with Python." category = "main" optional = false @@ -97,7 +187,7 @@ python-versions = ">=3.8" name = "packaging" version = "21.3" description = "Core utilities for Python packages" -category = "dev" +category = "main" optional = false python-versions = ">=3.6" @@ -106,7 +196,7 @@ pyparsing = ">=2.0.2,<3.0.5 || >3.0.5" [[package]] name = "pandas" -version = "1.4.3" +version = "1.5.0" description = "Powerful data structures for data analysis, time series, and statistics" category = "main" optional = false @@ -118,7 +208,32 @@ python-dateutil = ">=2.8.1" pytz = ">=2020.1" [package.extras] -test = ["hypothesis (>=5.5.3)", "pytest (>=6.0)", "pytest-xdist (>=1.31)"] +test = ["pytest-xdist (>=1.31)", "pytest (>=6.0)", "hypothesis (>=5.5.3)"] + +[[package]] +name = "pandas-schema" +version = "0.3.6" +description = "A validation library for Pandas data frames using user-friendly schemas" +category = "main" +optional = false +python-versions = "*" + +[package.dependencies] +numpy = "*" +packaging = "*" +pandas = ">=0.19" + +[[package]] +name = "pillow" +version = "9.2.0" +description = "Python Imaging Library (Fork)" +category = "dev" +optional = false +python-versions = ">=3.7" + +[package.extras] +docs = ["furo", "olefile", "sphinx (>=2.4)", "sphinx-copybutton", "sphinx-issues (>=3.0.1)", "sphinx-removed-in", "sphinxext-opengraph"] +tests = ["check-manifest", "coverage", "defusedxml", "markdown2", "olefile", "packaging", "pyroma", "pytest", "pytest-cov", "pytest-timeout"] [[package]] name = "pluggy" @@ -134,20 +249,33 @@ testing = ["pytest", "pytest-benchmark"] [[package]] name = "polars" -version = "0.14.9" +version = "0.14.17" description = "Blazingly fast DataFrame library" category = "main" optional = false python-versions = ">=3.7" [package.extras] -pandas = ["pyarrow (>=4.0)", "pandas"] +pandas = ["pyarrow (>=4.0.0)", "pandas"] connectorx = ["connectorx"] -numpy = ["numpy (>=1.16.0)"] -fsspec = ["fsspec"] xlsx2csv = ["xlsx2csv (>=0.8.0)"] -pytz = ["pytz"] -pyarrow = ["pyarrow (>=4.0)"] +timezone = ["backports.zoneinfo", "tzdata"] +matplotlib = ["matplotlib"] +fsspec = ["fsspec"] +numpy = ["numpy (>=1.16.0)"] +all = ["polars"] +pyarrow = ["pyarrow (>=4.0.0)"] + +[[package]] +name = "psutil" +version = "5.9.2" +description = "Cross-platform lib for process and system monitoring in Python." 
+category = "dev" +optional = false +python-versions = ">=2.7, !=3.0.*, !=3.1.*, !=3.2.*, !=3.3.*" + +[package.extras] +test = ["ipaddress", "mock", "enum34", "pywin32", "wmi"] [[package]] name = "py" @@ -157,6 +285,14 @@ category = "dev" optional = false python-versions = ">=2.7, !=3.0.*, !=3.1.*, !=3.2.*, !=3.3.*, !=3.4.*" +[[package]] +name = "pycparser" +version = "2.21" +description = "C parser in Python" +category = "main" +optional = false +python-versions = ">=2.7, !=3.0.*, !=3.1.*, !=3.2.*, !=3.3.*" + [[package]] name = "pyliftover" version = "0.4" @@ -169,7 +305,7 @@ python-versions = "*" name = "pyparsing" version = "3.0.9" description = "pyparsing module - Classes and methods to define and execute parsing grammars" -category = "dev" +category = "main" optional = false python-versions = ">=3.6.8" @@ -186,14 +322,13 @@ python-versions = ">=3" [[package]] name = "pytest" -version = "7.1.2" +version = "7.1.3" description = "pytest: simple powerful testing with Python" category = "dev" optional = false python-versions = ">=3.7" [package.dependencies] -atomicwrites = {version = ">=1.0", markers = "sys_platform == \"win32\""} attrs = ">=19.2.0" colorama = {version = "*", markers = "sys_platform == \"win32\""} iniconfig = "*" @@ -233,7 +368,7 @@ six = ">=1.5" [[package]] name = "pytz" -version = "2022.1" +version = "2022.4" description = "World timezone definitions, modern and historical" category = "main" optional = false @@ -257,6 +392,23 @@ urllib3 = ">=1.21.1,<1.27" socks = ["PySocks (>=1.5.6,!=1.5.7)"] use_chardet_on_py3 = ["chardet (>=3.0.2,<6)"] +[[package]] +name = "setuptools-scm" +version = "7.0.5" +description = "the blessed package to manage your versions by scm tags" +category = "dev" +optional = false +python-versions = ">=3.7" + +[package.dependencies] +packaging = ">=20.0" +tomli = ">=1.0.0" +typing-extensions = "*" + +[package.extras] +test = ["pytest (>=6.2)", "virtualenv (>20)"] +toml = ["setuptools (>=42)"] + [[package]] name = "six" version = "1.16.0" @@ -273,9 +425,17 @@ category = "dev" optional = false python-versions = ">=3.7" +[[package]] +name = "typing-extensions" +version = "4.3.0" +description = "Backported and Experimental Type Hints for Python 3.7+" +category = "dev" +optional = false +python-versions = ">=3.7" + [[package]] name = "urllib3" -version = "1.26.11" +version = "1.26.12" description = "HTTP library with thread-safe connection pooling, file post, and more." 
category = "main" optional = false @@ -283,64 +443,63 @@ python-versions = ">=2.7, !=3.0.*, !=3.1.*, !=3.2.*, !=3.3.*, !=3.4.*, !=3.5.*, [package.extras] brotli = ["brotlicffi (>=0.8.0)", "brotli (>=1.0.9)", "brotlipy (>=0.6.0)"] -secure = ["pyOpenSSL (>=0.14)", "cryptography (>=1.3.4)", "idna (>=2.0.0)", "certifi", "ipaddress"] +secure = ["pyOpenSSL (>=0.14)", "cryptography (>=1.3.4)", "idna (>=2.0.0)", "certifi", "urllib3-secure-extra", "ipaddress"] socks = ["PySocks (>=1.5.6,!=1.5.7,<2.0)"] +[[package]] +name = "zstandard" +version = "0.18.0" +description = "Zstandard bindings for Python" +category = "main" +optional = false +python-versions = ">=3.6" + +[package.dependencies] +cffi = {version = ">=1.11", markers = "platform_python_implementation == \"PyPy\""} + +[package.extras] +cffi = ["cffi (>=1.11)"] + [metadata] lock-version = "1.1" python-versions = "^3.10" -content-hash = "607d2d543f52a4ecc116c0b912c499a83cd1c740244323c81fdfe89ba27a55eb" +content-hash = "84b4520b176bb1b892c870fe894814cd05e217a86d7b4fadfa638b91a919bae5" [metadata.files] -atomicwrites = [] attrs = [] certifi = [] +cffi = [] charset-normalizer = [] colorama = [] +contourpy = [] coverage = [] -idna = [ - {file = "idna-3.3-py3-none-any.whl", hash = "sha256:84d9dd047ffa80596e0f246e2eab0b391788b0503584e8945f2368256d2735ff"}, - {file = "idna-3.3.tar.gz", hash = "sha256:9d643ff0a55b762d5cdb124b8eaa99c66322e2157b69160bc32796e824360e6d"}, -] +cycler = [] +fonttools = [] +idna = [] iniconfig = [ {file = "iniconfig-1.1.1-py2.py3-none-any.whl", hash = "sha256:011e24c64b7f47f6ebd835bb12a743f2fbe9a26d4cecaa7f53bc4f35ee9da8b3"}, {file = "iniconfig-1.1.1.tar.gz", hash = "sha256:bc3af051d7d14b2ee5ef9969666def0cd1a000e121eaea580d4a313df4b37f32"}, ] jq = [] +kiwisolver = [] +matplotlib = [] +memory-profiler = [] numpy = [] packaging = [] -pandas = [ - {file = "pandas-1.4.3-cp310-cp310-macosx_10_9_universal2.whl", hash = "sha256:d51674ed8e2551ef7773820ef5dab9322be0828629f2cbf8d1fc31a0c4fed640"}, - {file = "pandas-1.4.3-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:16ad23db55efcc93fa878f7837267973b61ea85d244fc5ff0ccbcfa5638706c5"}, - {file = "pandas-1.4.3-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:958a0588149190c22cdebbc0797e01972950c927a11a900fe6c2296f207b1d6f"}, - {file = "pandas-1.4.3-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:e48fbb64165cda451c06a0f9e4c7a16b534fcabd32546d531b3c240ce2844112"}, - {file = "pandas-1.4.3-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:6f803320c9da732cc79210d7e8cc5c8019aad512589c910c66529eb1b1818230"}, - {file = "pandas-1.4.3-cp310-cp310-win_amd64.whl", hash = "sha256:2893e923472a5e090c2d5e8db83e8f907364ec048572084c7d10ef93546be6d1"}, - {file = "pandas-1.4.3-cp38-cp38-macosx_10_9_universal2.whl", hash = "sha256:24ea75f47bbd5574675dae21d51779a4948715416413b30614c1e8b480909f81"}, - {file = "pandas-1.4.3-cp38-cp38-macosx_10_9_x86_64.whl", hash = "sha256:d5ebc990bd34f4ac3c73a2724c2dcc9ee7bf1ce6cf08e87bb25c6ad33507e318"}, - {file = "pandas-1.4.3-cp38-cp38-macosx_11_0_arm64.whl", hash = "sha256:d6c0106415ff1a10c326c49bc5dd9ea8b9897a6ca0c8688eb9c30ddec49535ef"}, - {file = "pandas-1.4.3-cp38-cp38-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:78b00429161ccb0da252229bcda8010b445c4bf924e721265bec5a6e96a92e92"}, - {file = "pandas-1.4.3-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:6dfbf16b1ea4f4d0ee11084d9c026340514d1d30270eaa82a9f1297b6c8ecbf0"}, - {file = "pandas-1.4.3-cp38-cp38-win32.whl", 
hash = "sha256:48350592665ea3cbcd07efc8c12ff12d89be09cd47231c7925e3b8afada9d50d"}, - {file = "pandas-1.4.3-cp38-cp38-win_amd64.whl", hash = "sha256:605d572126eb4ab2eadf5c59d5d69f0608df2bf7bcad5c5880a47a20a0699e3e"}, - {file = "pandas-1.4.3-cp39-cp39-macosx_10_9_universal2.whl", hash = "sha256:a3924692160e3d847e18702bb048dc38e0e13411d2b503fecb1adf0fcf950ba4"}, - {file = "pandas-1.4.3-cp39-cp39-macosx_10_9_x86_64.whl", hash = "sha256:07238a58d7cbc8a004855ade7b75bbd22c0db4b0ffccc721556bab8a095515f6"}, - {file = "pandas-1.4.3-cp39-cp39-macosx_11_0_arm64.whl", hash = "sha256:755679c49460bd0d2f837ab99f0a26948e68fa0718b7e42afbabd074d945bf84"}, - {file = "pandas-1.4.3-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:41fc406e374590a3d492325b889a2686b31e7a7780bec83db2512988550dadbf"}, - {file = "pandas-1.4.3-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:1d9382f72a4f0e93909feece6fef5500e838ce1c355a581b3d8f259839f2ea76"}, - {file = "pandas-1.4.3-cp39-cp39-win32.whl", hash = "sha256:0daf876dba6c622154b2e6741f29e87161f844e64f84801554f879d27ba63c0d"}, - {file = "pandas-1.4.3-cp39-cp39-win_amd64.whl", hash = "sha256:721a3dd2f06ef942f83a819c0f3f6a648b2830b191a72bbe9451bcd49c3bd42e"}, - {file = "pandas-1.4.3.tar.gz", hash = "sha256:2ff7788468e75917574f080cd4681b27e1a7bf36461fe968b49a87b5a54d007c"}, -] +pandas = [] +pandas-schema = [] +pillow = [] pluggy = [ {file = "pluggy-1.0.0-py2.py3-none-any.whl", hash = "sha256:74134bbf457f031a36d68416e1509f34bd5ccc019f0bcc952c7b909d06b37bd3"}, {file = "pluggy-1.0.0.tar.gz", hash = "sha256:4224373bacce55f955a878bf9cfa763c1e360858e330072059e10bad68531159"}, ] polars = [] +psutil = [] py = [ {file = "py-1.11.0-py2.py3-none-any.whl", hash = "sha256:607c53218732647dff4acdfcd50cb62615cedf612e72d1724fb1a0cc6405b378"}, {file = "py-1.11.0.tar.gz", hash = "sha256:51c75c4126074b472f746a24399ad32f6053d1b34b68d2fa41e558e6f4a98719"}, ] +pycparser = [] pyliftover = [ {file = "pyliftover-0.4.tar.gz", hash = "sha256:72bcfb7de907569b0eb75e86c817840365297d63ba43a961da394187e399da41"}, ] @@ -352,11 +511,9 @@ python-dateutil = [ {file = "python-dateutil-2.8.2.tar.gz", hash = "sha256:0123cacc1627ae19ddf3c27a5de5bd67ee4586fbdd6440d9748f8abb483d3e86"}, {file = "python_dateutil-2.8.2-py2.py3-none-any.whl", hash = "sha256:961d03dc3453ebbc59dbdea9e4e11c5651520a876d0f4db161e8674aae935da9"}, ] -pytz = [ - {file = "pytz-2022.1-py2.py3-none-any.whl", hash = "sha256:e68985985296d9a66a881eb3193b0906246245294a881e7c8afe623866ac6a5c"}, - {file = "pytz-2022.1.tar.gz", hash = "sha256:1e760e2fe6a8163bc0b3d9a19c4f84342afa0a2affebfaa84b01b978a02ecaa7"}, -] +pytz = [] requests = [] +setuptools-scm = [] six = [ {file = "six-1.16.0-py2.py3-none-any.whl", hash = "sha256:8abb2f1d86890a2dfb989f9a77cfcfd3e47c2a354b01111771326f8aa26e0254"}, {file = "six-1.16.0.tar.gz", hash = "sha256:1e61c37477a1626458e36f7b1d82aa5c9b094fa4802892072e49de9c60c4c926"}, @@ -365,4 +522,6 @@ tomli = [ {file = "tomli-2.0.1-py3-none-any.whl", hash = "sha256:939de3e7a6161af0c887ef91b7d41a53e7c5a1ca976325f429cb46ea9bc30ecc"}, {file = "tomli-2.0.1.tar.gz", hash = "sha256:de526c12914f0c550d15924c62d72abc48d6fe7364aa87328337a31007fe8a4f"}, ] +typing-extensions = [] urllib3 = [] +zstandard = [] diff --git a/pyproject.toml b/pyproject.toml index b8262b2..18de317 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,9 +1,9 @@ [tool.poetry] name = "pgscatalog_utils" -version = "0.1.2" +version = "0.2.0" description = "Utilities for working with PGS Catalog API and scoring files" homepage = 
"https://github.com/PGScatalog/pgscatalog_utils" -authors = ["Benjamin Wingfield ", "Samuel Lambert "] +authors = ["Benjamin Wingfield ", "Samuel Lambert ", "Laurent Gil "] license = "Apache-2.0" readme = "README.md" @@ -11,19 +11,26 @@ readme = "README.md" combine_scorefiles = "pgscatalog_utils.scorefile.combine_scorefiles:combine_scorefiles" download_scorefiles = "pgscatalog_utils.download.download_scorefile:download_scorefile" match_variants = "pgscatalog_utils.match.match_variants:match_variants" +aggregate_scores = "pgscatalog_utils.aggregate.aggregate_scores:aggregate_scores" +validate_scorefiles = "pgscatalog_utils.validate.validate_scorefile:validate_scorefile" [tool.poetry.dependencies] python = "^3.10" +numpy = "^1.23.3" pandas = "^1.4.3" +pandas-schema = "^0.3.6" pyliftover = "^0.4" requests = "^2.28.1" jq = "^1.2.2" -polars = "0.14.9" +polars = "^0.14.9" +zstandard = "^0.18.0" [tool.poetry.dev-dependencies] pytest = "^7.1.2" pytest-cov = "^3.0.0" pysqlar = "^0.1.2" +memory-profiler = "^0.60.0" +matplotlib = "^3.6.0" [build-system] requires = ["poetry-core>=1.0.0"] diff --git a/tests/data/test_scoring_file_1.txt.gz b/tests/data/test_scoring_file_1.txt.gz new file mode 100644 index 0000000..cd46417 Binary files /dev/null and b/tests/data/test_scoring_file_1.txt.gz differ diff --git a/tests/data/test_scoring_file_2.txt.gz b/tests/data/test_scoring_file_2.txt.gz new file mode 100644 index 0000000..fc1c10c Binary files /dev/null and b/tests/data/test_scoring_file_2.txt.gz differ diff --git a/tests/data/test_scoring_file_3.txt.gz b/tests/data/test_scoring_file_3.txt.gz new file mode 100644 index 0000000..6a2fef3 Binary files /dev/null and b/tests/data/test_scoring_file_3.txt.gz differ diff --git a/tests/data/test_scoring_file_4.txt.gz b/tests/data/test_scoring_file_4.txt.gz new file mode 100644 index 0000000..7e57cfe Binary files /dev/null and b/tests/data/test_scoring_file_4.txt.gz differ diff --git a/tests/data/test_scoring_file_hmpos_37_1.txt.gz b/tests/data/test_scoring_file_hmpos_37_1.txt.gz new file mode 100644 index 0000000..64ca88d Binary files /dev/null and b/tests/data/test_scoring_file_hmpos_37_1.txt.gz differ diff --git a/tests/data/test_scoring_file_hmpos_37_2.txt.gz b/tests/data/test_scoring_file_hmpos_37_2.txt.gz new file mode 100644 index 0000000..8acaa73 Binary files /dev/null and b/tests/data/test_scoring_file_hmpos_37_2.txt.gz differ diff --git a/tests/data/test_scoring_file_hmpos_37_3.txt.gz b/tests/data/test_scoring_file_hmpos_37_3.txt.gz new file mode 100644 index 0000000..601865a Binary files /dev/null and b/tests/data/test_scoring_file_hmpos_37_3.txt.gz differ diff --git a/tests/data/test_scoring_file_hmpos_38_1.txt.gz b/tests/data/test_scoring_file_hmpos_38_1.txt.gz new file mode 100644 index 0000000..2c6ce5d Binary files /dev/null and b/tests/data/test_scoring_file_hmpos_38_1.txt.gz differ diff --git a/tests/data/test_scoring_file_hmpos_38_2.txt.gz b/tests/data/test_scoring_file_hmpos_38_2.txt.gz new file mode 100644 index 0000000..8c1ec18 Binary files /dev/null and b/tests/data/test_scoring_file_hmpos_38_2.txt.gz differ diff --git a/tests/data/test_scoring_file_hmpos_38_3.txt.gz b/tests/data/test_scoring_file_hmpos_38_3.txt.gz new file mode 100644 index 0000000..343e1b6 Binary files /dev/null and b/tests/data/test_scoring_file_hmpos_38_3.txt.gz differ diff --git a/tests/match/test_label.py b/tests/match/test_label.py index 8198335..ebe0c43 100644 --- a/tests/match/test_label.py +++ b/tests/match/test_label.py @@ -4,6 +4,7 @@ import pytest import polars as 
pl +from pgscatalog_utils.match.label import label_matches from pgscatalog_utils.match.match import get_all_matches from tests.match.test_match import _cast_cat @@ -29,37 +30,54 @@ def test_label(small_scorefile, small_target): scorefile, target = _cast_cat(small_scorefile, small_target) # get_all_matches calls label_matches - labelled = get_all_matches(scorefile, target, skip_flip=True, remove_ambiguous=True, keep_first_match=False) + params = {'skip_flip': True, 'remove_ambiguous': True, 'remove_multiallelic': False, 'keep_first_match': False} + labelled: pl.DataFrame = (get_all_matches(scorefile=scorefile, target=target) + .pipe(label_matches, params=params) + .collect()) logger.debug(labelled.select(['ID', 'match_type', 'best_match', 'ambiguous', 'match_status', 'exclude'])) - assert labelled['best_match'].to_list() == [True, True, True] - assert labelled['ambiguous'].to_list() == [False, True, False] - assert labelled['exclude'].to_list() == [False, True, False] - assert labelled['match_status'].to_list() == ["matched", "excluded", "matched"] + assert labelled['best_match'].to_list() == [True, True, True, False] + assert labelled['ambiguous'].to_list() == [False, True, False, True] + assert labelled['exclude'].to_list() == [False, True, False, True] + assert labelled['match_status'].to_list() == ["matched", "excluded", "matched", "not_best"] def test_ambiguous_label(small_flipped_scorefile, small_target): """ Test ambiguous variant labels change when they're kept for match candidates with one match per position """ scorefile, target = _cast_cat(small_flipped_scorefile, small_target) - - no_ambiguous = get_all_matches(scorefile, target, skip_flip=True, remove_ambiguous=True, keep_first_match=False) - - assert no_ambiguous['best_match'].to_list() == [True] - assert no_ambiguous['ambiguous'].to_list() == [True] - assert no_ambiguous['exclude'].to_list() == [True] - assert no_ambiguous['match_status'].to_list() == ["excluded"] + no_flip = {'skip_flip': True, 'remove_ambiguous': True, 'remove_multiallelic': False, 'keep_first_match': False} + no_ambiguous: pl.DataFrame = (get_all_matches(scorefile=scorefile, target=target) + .pipe(label_matches, params=no_flip) + .collect()) + + # 2:2:T:A -> refalt -> ambiguous -> excluded (best match but ambiguous) + # 1:1:A:C -> refalt_flip -> not ambiguous -> excluded (best match but skip_flip) + # 2:2:T:A -> refalt_flip -> ambiguous -> not_best (refalt priority so not best and excluded) + # 3:3:T:G -> refalt_flip -> not ambiguous -> excluded (best match but skip_flip) + assert no_ambiguous['best_match'].to_list() == [True, True, False, True] + assert no_ambiguous['ambiguous'].to_list() == [True, False, True, False] + assert no_ambiguous['exclude'].to_list() == [True, True, True, True] + assert no_ambiguous['match_status'].to_list() == ["excluded", "excluded", "not_best", "excluded"] # otherwise, ambiguous variants are kept - labelled = get_all_matches(scorefile, target, skip_flip=True, remove_ambiguous=False, keep_first_match=False) - - assert labelled['best_match'].to_list() == [True] - assert labelled['ambiguous'].to_list() == [True] - assert labelled['exclude'].to_list() == [False] - assert labelled['match_status'].to_list() == ["matched"] - - -def test_duplicate_best_match(duplicated_matches, request): + flip_params = {'skip_flip': True, 'remove_ambiguous': False, 'remove_multiallelic': False, + 'keep_first_match': False} + labelled = (get_all_matches(scorefile=scorefile, target=target) + .pipe(label_matches, params=flip_params) + .collect()) + + # 
2:2:T:A -> refalt -> ambiguous -> matched + # 1:1:A:C -> refalt_flip -> not ambiguous -> excluded (best match but skip_flip) + # 2:2:T:A -> refalt_flip -> ambiguous -> not_best (refalt priority so not best and excluded) + # 3:3:T:G -> refalt_flip -> not ambiguous -> excluded (best match but skip_flip) + assert labelled['best_match'].to_list() == [True, True, False, True] + assert labelled['ambiguous'].to_list() == [True, False, True, False] + assert labelled['exclude'].to_list() == [False, True, True, True] + assert labelled['match_status'].to_list() == ["matched", "excluded", "not_best", "excluded"] + + +def test_duplicate_ID(duplicated_matches, request): # these matches come from different lines in the original scoring file assert duplicated_matches["row_nr"].to_list() == [1, 4] # but they have the same ID! @@ -94,7 +112,7 @@ def test_duplicate_best_match(duplicate_best_match): @pytest.fixture(params=[True, False], ids=["keep_first_match", "delete_both"]) -def duplicated_matches(small_scorefile, small_target, request): +def duplicated_matches(small_scorefile, small_target, request) -> pl.DataFrame: # pgs catalog scorefiles can contain the same variant remapped to multiple rows # this happens after liftover to a different genome build # row_nrs will be different, but other information may be the same @@ -105,21 +123,33 @@ def duplicated_matches(small_scorefile, small_target, request): scorefile, target = _cast_cat(dups, small_target) - return get_all_matches(scorefile, target, skip_flip=False, remove_ambiguous=False, keep_first_match=request.param) + params = {'skip_flip': False, 'remove_ambiguous': False, 'remove_multiallelic': False, + 'keep_first_match': request.param} + return (get_all_matches(scorefile=scorefile, target=target) + .pipe(label_matches, params=params) + .collect()) @pytest.fixture -def multiple_match_types(small_target, small_scorefile): +def multiple_match_types(small_target, small_scorefile) -> pl.DataFrame: # skip flip will return two candidate matches for one target position: refalt + refalt_flip scorefile, target = _cast_cat(small_scorefile, small_target) - return (get_all_matches(scorefile, target, skip_flip=False, remove_ambiguous=False, keep_first_match=False) - .filter(pl.col('chr_name') == 2)) + + params = {'skip_flip': False, 'remove_ambiguous': False, 'remove_multiallelic': False, 'keep_first_match': False} + return (get_all_matches(scorefile=scorefile, target=target) + .pipe(label_matches, params=params) + .filter(pl.col('chr_name') == '2') + .collect()) @pytest.fixture -def duplicate_best_match(small_target, small_scorefile_no_oa): +def duplicate_best_match(small_target, small_scorefile_no_oa) -> pl.DataFrame: # this type of target genome can sometimes occur when the REF is different at the same position odd_target = {'#CHROM': [1, 1], 'POS': [1, 1], 'REF': ['T', 'C'], 'ALT': ['A', 'A'], 'ID': ['1:1:T:C', '1:1:A:A'], 'is_multiallelic': [False, False]} scorefile, target = _cast_cat(small_scorefile_no_oa, pl.DataFrame(odd_target)) - return get_all_matches(scorefile, target, skip_flip=False, remove_ambiguous=False, keep_first_match=False) + + params = {'skip_flip': False, 'remove_ambiguous': False, 'remove_multiallelic': False, 'keep_first_match': False} + return (get_all_matches(scorefile=scorefile, target=target) + .pipe(label_matches, params=params) + .collect()) diff --git a/tests/match/test_match.py b/tests/match/test_match.py index 2c1c8f4..ca509d6 100644 --- a/tests/match/test_match.py +++ b/tests/match/test_match.py @@ -5,7 +5,8 @@ import polars as pl 
import pytest -from pgscatalog_utils.match.match import get_all_matches, _cast_categorical +from pgscatalog_utils.match.label import label_matches +from pgscatalog_utils.match.match import get_all_matches from pgscatalog_utils.match.match_variants import match_variants @@ -38,23 +39,43 @@ def test_match_pass(mini_scorefile, target_path, tmp_path): match_variants() -def _cast_cat(scorefile, target): +def _cast_cat(scorefile, target) -> tuple[pl.LazyFrame, pl.LazyFrame]: with pl.StringCache(): - return _cast_categorical(scorefile, target) + scorefile = scorefile.with_columns([ + pl.col("chr_name").cast(pl.Utf8).cast(pl.Categorical), + pl.col("effect_allele").cast(pl.Categorical), + pl.col("other_allele").cast(pl.Categorical), + pl.col("effect_type").cast(pl.Categorical), + pl.col("effect_allele_FLIP").cast(pl.Categorical), + pl.col("other_allele_FLIP").cast(pl.Categorical), + pl.col("accession").cast(pl.Categorical) + ]) + target = target.with_columns([ + pl.col("#CHROM").cast(pl.Utf8).cast(pl.Categorical), + pl.col("REF").cast(pl.Categorical), + pl.col("ALT").cast(pl.Categorical) + ]) + return scorefile.lazy(), target.lazy() def test_match_strategies(small_scorefile, small_target): scorefile, target = _cast_cat(small_scorefile, small_target) + params = {'skip_flip': True, 'remove_ambiguous': False, 'keep_first_match': False, 'remove_multiallelic': False} # check unambiguous matches - df = (get_all_matches(scorefile, target, skip_flip=True, remove_ambiguous=False, keep_first_match=False) - .filter(pl.col('ambiguous') == False)) + df: pl.DataFrame = (get_all_matches(scorefile, target) + .pipe(label_matches, params=params) + .filter(pl.col('ambiguous') == False) + .collect()) assert set(df['ID'].to_list()).issubset({'3:3:T:G', '1:1:A:C'}) assert set(df['match_type'].to_list()).issubset(['altref', 'refalt']) # when keeping ambiguous and flipping alleles - flip = (get_all_matches(scorefile, target, skip_flip=False, remove_ambiguous=False, keep_first_match=False) - .filter(pl.col('ambiguous') == True)) + flip_params = {'skip_flip': False, 'remove_ambiguous': False, 'keep_first_match': False, 'remove_multiallelic': False} + flip: pl.DataFrame = (get_all_matches(scorefile, target) + .pipe(label_matches, params=flip_params) + .filter(pl.col('ambiguous') == True) + .collect()) assert set(flip['ID'].to_list()).issubset({'2:2:T:A'}) assert set(flip['match_type'].to_list()).issubset({'altref', 'refalt_flip'}) @@ -63,28 +84,42 @@ def test_match_strategies(small_scorefile, small_target): def test_no_oa_match(small_scorefile_no_oa, small_target): scorefile, target = _cast_cat(small_scorefile_no_oa, small_target) - df = (get_all_matches(scorefile, target, skip_flip=True, remove_ambiguous=False, keep_first_match=False) - .filter(pl.col('ambiguous') == False)) + no_ambig = {'skip_flip': True, 'remove_ambiguous': False, 'keep_first_match': False, 'remove_multiallelic': False} + df: pl.DataFrame = (get_all_matches(scorefile, target) + .pipe(label_matches, params=no_ambig) + .filter(pl.col('ambiguous') == False) + .collect()) assert set(df['ID'].to_list()).issubset(['3:3:T:G', '1:1:A:C']) assert set(df['match_type'].to_list()).issubset(['no_oa_alt', 'no_oa_ref']) # check ambiguous matches - flip = (get_all_matches(scorefile, target, skip_flip=False, remove_ambiguous=False, keep_first_match=False) - .filter(pl.col('ambiguous') == True)) + ambig = {'skip_flip': False, 'remove_ambiguous': False, 'keep_first_match': False, 'remove_multiallelic': False} + flip: pl.DataFrame = (get_all_matches(scorefile, target) + 
.pipe(label_matches, ambig) + .filter(pl.col('ambiguous') == True) + .collect()) assert set(flip['ID'].to_list()).issubset({'2:2:T:A'}) assert set(flip['match_type'].to_list()).issubset({'no_oa_alt', 'no_oa_ref_flip'}) def test_flip_match(small_flipped_scorefile, small_target): scorefile, target = _cast_cat(small_flipped_scorefile, small_target) - - df = get_all_matches(scorefile, target, skip_flip=True, remove_ambiguous=False, keep_first_match=False) - assert set(df['ambiguous']) == {True} - assert set(df['match_type']) == {'refalt'} - - flip = (get_all_matches(scorefile, target, skip_flip=False, remove_ambiguous=False, keep_first_match=False) - .filter(pl.col('ambiguous') == False)) + params = {'skip_flip': True, 'remove_ambiguous': False, 'keep_first_match': False, 'remove_multiallelic': False} + df: pl.DataFrame = (get_all_matches(scorefile, target) + .pipe(label_matches, params=params) + .collect()) + + assert df['ambiguous'].to_list() == [True, False, True, False] + assert df['match_type'].to_list() == ['refalt', 'refalt_flip', 'altref_flip', 'altref_flip'] + assert df['match_status'].to_list() == ['matched', 'excluded', 'not_best', 'excluded'] # flipped -> excluded + + no_flip_params = {'skip_flip': False, 'remove_ambiguous': False, 'keep_first_match': False, + 'remove_multiallelic': False} + flip: pl.DataFrame = (get_all_matches(scorefile, target) + .pipe(label_matches, params=no_flip_params) + .filter(pl.col('ambiguous') == False) + .collect()) assert flip['match_type'].str.contains('flip').all() assert set(flip['ID'].to_list()).issubset(['3:3:T:G', '1:1:A:C']) diff --git a/tests/test_validate.py b/tests/test_validate.py new file mode 100644 index 0000000..7459f05 --- /dev/null +++ b/tests/test_validate.py @@ -0,0 +1,149 @@ +import pytest +import numpy as np + +from pgscatalog_utils.validate.formatted.validator import init_validator as formatted_init_validator +from pgscatalog_utils.validate.harmonized_position.validator import init_validator as hmpos_init_validator + + +log_file = 'VALIDATE.log' +test_data_dir = './tests/data' + + +###### Formatted scoring files ###### +def _get_formatted_validator(test_file): + validator = formatted_init_validator(test_file,log_file,None) + return validator + +def _valid_file(test_file): + validator = _get_formatted_validator(test_file) + assert validator.validate_file_extension() + assert validator.validate_headers() + assert validator.validate_data() + assert validator.is_file_valid() + +def _failed_file(test_file): + validator = _get_formatted_validator(test_file) + assert validator.validate_file_extension() + assert validator.validate_headers() + assert not validator.validate_data() + assert not validator.is_file_valid() + +def _failed_header_file(test_file): + validator = _get_formatted_validator(test_file) + assert validator.validate_file_extension() + validator.header = np.delete(validator.header,np.s_[0,1,2]) + assert not validator.validate_headers() + + +# Valid file with rsID, chr_name and chr_position +def test_valid_formatted_file_rsID_and_pos(test_file_1): + _valid_file(test_file_1) + +# Valid file with rsID only +def test_valid_formatted_file_rsID_only(test_file_2): + _valid_file(test_file_2) + +# Valid file with chr_name and chr_position +def test_valid_formatted_file_pos_only(test_file_3): + _valid_file(test_file_3) + +# File made invalid file by removing some mandatory column headers +def test_failed_formatted_file_missing_header(test_file_1): + _failed_header_file(test_file_1) + +# Invalid file with several data content issues 
+def test_failed_formatted_file_data_issues(test_file_4):
+    _failed_file(test_file_4)
+
+
+
+###### Harmonized (Position) scoring files ######
+def _get_hmpos_validator(test_file):
+    validator = hmpos_init_validator(test_file,log_file,None)
+    return validator
+
+def _valid_hmpos_file(test_file):
+    validator = _get_hmpos_validator(test_file)
+    assert validator.validate_file_extension()
+    assert validator.validate_headers()
+    assert validator.validate_data()
+    assert validator.is_file_valid()
+
+def _failed_hmpos_file(test_file):
+    validator = _get_hmpos_validator(test_file)
+    assert validator.validate_file_extension()
+    assert validator.validate_headers()
+    assert not validator.validate_data()
+    assert not validator.is_file_valid()
+
+def _failed_header_hmpos_file(test_file):
+    validator = _get_hmpos_validator(test_file)
+    assert validator.validate_file_extension()
+    validator.header = np.delete(validator.header,np.s_[0,1,2])
+    assert not validator.validate_headers()
+
+
+## GRCh37 ##
+# Valid file with rsID, chr_name and chr_position
+def test_valid_hmpos_file_rsID_and_pos_37(test_hmpos_file_GRCh37_1):
+    _valid_hmpos_file(test_hmpos_file_GRCh37_1)
+# Valid file with rsID only
+def test_valid_hmpos_file_rsID_only_37(test_hmpos_file_GRCh37_2):
+    _valid_hmpos_file(test_hmpos_file_GRCh37_2)
+# Valid file with chr_name and chr_position
+def test_valid_hmpos_file_pos_only_37(test_hmpos_file_GRCh37_3):
+    _valid_hmpos_file(test_hmpos_file_GRCh37_3)
+
+## GRCh38 ##
+# Valid file with rsID, chr_name and chr_position
+def test_valid_hmpos_file_rsID_and_pos_38(test_hmpos_file_GRCh38_1):
+    _valid_hmpos_file(test_hmpos_file_GRCh38_1)
+# Valid file with rsID only
+def test_valid_hmpos_file_rsID_only_38(test_hmpos_file_GRCh38_2):
+    _valid_hmpos_file(test_hmpos_file_GRCh38_2)
+# Valid file with chr_name and chr_position
+def test_valid_hmpos_file_pos_only_38(test_hmpos_file_GRCh38_3):
+    _valid_hmpos_file(test_hmpos_file_GRCh38_3)
+
+
+######################################################
+
+@pytest.fixture
+def test_file_1():
+    return f'{test_data_dir}/test_scoring_file_1.txt.gz'
+
+@pytest.fixture
+def test_file_2():
+    return f'{test_data_dir}/test_scoring_file_2.txt.gz'
+
+@pytest.fixture
+def test_file_3():
+    return f'{test_data_dir}/test_scoring_file_3.txt.gz'
+
+@pytest.fixture
+def test_file_4():
+    return f'{test_data_dir}/test_scoring_file_4.txt.gz'
+
+@pytest.fixture
+def test_hmpos_file_GRCh37_1():
+    return f'{test_data_dir}/test_scoring_file_hmpos_37_1.txt.gz'
+
+@pytest.fixture
+def test_hmpos_file_GRCh38_1():
+    return f'{test_data_dir}/test_scoring_file_hmpos_38_1.txt.gz'
+
+@pytest.fixture
+def test_hmpos_file_GRCh37_2():
+    return f'{test_data_dir}/test_scoring_file_hmpos_37_2.txt.gz'
+
+@pytest.fixture
+def test_hmpos_file_GRCh38_2():
+    return f'{test_data_dir}/test_scoring_file_hmpos_38_2.txt.gz'
+
+@pytest.fixture
+def test_hmpos_file_GRCh37_3():
+    return f'{test_data_dir}/test_scoring_file_hmpos_37_3.txt.gz'
+
+@pytest.fixture
+def test_hmpos_file_GRCh38_3():
+    return f'{test_data_dir}/test_scoring_file_hmpos_38_3.txt.gz'
\ No newline at end of file