From 1664852436ea2d58b5a663d780fc9f32f1283348 Mon Sep 17 00:00:00 2001 From: Benjamin Wingfield Date: Wed, 12 Jun 2024 16:00:46 +0100 Subject: [PATCH] Improve aggregation (#23) * export key functions for sorting chromosomes / effect types * use new key functions for sorting * reduce memory usage during aggregation * fix doctest output * make aggregation steps clearer --- pgscatalog.calc/poetry.lock | 41 ++++++---- pgscatalog.calc/pyproject.toml | 2 +- .../src/pgscatalog/calc/cli/__init__.py | 4 +- .../src/pgscatalog/calc/cli/aggregate_cli.py | 50 +++++++++++-- .../src/pgscatalog/calc/lib/polygenicscore.py | 74 ++++++++++--------- pgscatalog.core/poetry.lock | 17 ++++- pgscatalog.core/pyproject.toml | 1 + .../src/pgscatalog/core/__init__.py | 4 + .../src/pgscatalog/core/lib/__init__.py | 3 + .../src/pgscatalog/core/lib/_sortpaths.py | 29 ++++++++ pgscatalog.match/poetry.lock | 30 ++++++-- pgscatalog.match/pyproject.toml | 2 +- .../src/pgscatalog/match/__init__.py | 8 +- .../src/pgscatalog/match/lib/matchresult.py | 1 + .../pgscatalog/match/lib/plinkscorefiles.py | 33 +++++---- pgscatalog.match/tests/test_merge_cli.py | 3 + 16 files changed, 216 insertions(+), 86 deletions(-) create mode 100644 pgscatalog.core/src/pgscatalog/core/lib/_sortpaths.py diff --git a/pgscatalog.calc/poetry.lock b/pgscatalog.calc/poetry.lock index 4404971..9da452f 100644 --- a/pgscatalog.calc/poetry.lock +++ b/pgscatalog.calc/poetry.lock @@ -311,6 +311,21 @@ files = [ {file = "joblib-1.4.2.tar.gz", hash = "sha256:2382c5816b2636fbd20a09e0f4e9dad4736765fdfb7dca582943b9c1366b3f0e"}, ] +[[package]] +name = "natsort" +version = "8.4.0" +description = "Simple yet flexible natural sorting in Python." +optional = false +python-versions = ">=3.7" +files = [ + {file = "natsort-8.4.0-py3-none-any.whl", hash = "sha256:4732914fb471f56b5cce04d7bae6f164a592c7712e1c85f9ef585e197299521c"}, + {file = "natsort-8.4.0.tar.gz", hash = "sha256:45312c4a0e5507593da193dedd04abb1469253b601ecaf63445ad80f0a1ea581"}, +] + +[package.extras] +fast = ["fastnumbers (>=2.0.0)"] +icu = ["PyICU (>=1.0.0)"] + [[package]] name = "numpy" version = "1.26.4" @@ -444,21 +459,21 @@ name = "pgscatalog-core" version = "0.1.2" description = "Core tools for working with polygenic scores (PGS) and the PGS Catalog" optional = false -python-versions = "<4.0,>=3.11" -files = [ - {file = "pgscatalog_core-0.1.2-py3-none-any.whl", hash = "sha256:45bd0f2d807ae47efc4fbe59e49e43cd30be999ccdf893135cb6fbae2aad3228"}, - {file = "pgscatalog_core-0.1.2.tar.gz", hash = "sha256:3a7428a4a78642f87f1a9da140a9e7e47e1b9c98100834c7c1033481f535e3bf"}, -] +python-versions = "^3.11" +files = [] +develop = true [package.dependencies] -httpx = ">=0.26.0,<0.27.0" -pyliftover = ">=0.4,<0.5" -tenacity = ">=8.2.3,<9.0.0" -tqdm = ">=4.66.1,<5.0.0" -xopen = {version = ">=1.8.0,<2.0.0", extras = ["zstd"]} +httpx = "^0.26.0" +natsort = "^8.4.0" +pyliftover = "^0.4" +tenacity = "^8.2.3" +tqdm = "^4.66.1" +xopen = {version = "^1.8.0", extras = ["zstd"]} -[package.extras] -pyarrow = ["pyarrow (>=15.0.0,<16.0.0)"] +[package.source] +type = "directory" +url = "../pgscatalog.core" [[package]] name = "pluggy" @@ -910,4 +925,4 @@ cffi = ["cffi (>=1.11)"] [metadata] lock-version = "2.0" python-versions = "^3.11" -content-hash = "5a862d272f92f569f8c623d773e32bc032a5ce16f3e65de341806fbb0b5767c9" +content-hash = "129e98096e082467bf990f3d02ffa8d9c1a34328d3992672c1d235db1075539c" diff --git a/pgscatalog.calc/pyproject.toml b/pgscatalog.calc/pyproject.toml index c78e8b7..c2cb506 100644 --- 
a/pgscatalog.calc/pyproject.toml +++ b/pgscatalog.calc/pyproject.toml @@ -10,7 +10,7 @@ packages = [
 
 [tool.poetry.dependencies]
 python = "^3.11"
-"pgscatalog.core" = "^0.1.0"
+"pgscatalog.core" = {path = "../pgscatalog.core", develop = true}
 numpy = "^1.26.4"
 pandas = "^2.2.0"
 pyarrow = "^15.0.0"
diff --git a/pgscatalog.calc/src/pgscatalog/calc/cli/__init__.py b/pgscatalog.calc/src/pgscatalog/calc/cli/__init__.py
index 02a90c1..a9a2c5b 100644
--- a/pgscatalog.calc/src/pgscatalog/calc/cli/__init__.py
+++ b/pgscatalog.calc/src/pgscatalog/calc/cli/__init__.py
@@ -1,3 +1 @@
-from .aggregate_cli import run_aggregate
-
-__all__ = ["run_aggregate"]
+__all__ = []
diff --git a/pgscatalog.calc/src/pgscatalog/calc/cli/aggregate_cli.py b/pgscatalog.calc/src/pgscatalog/calc/cli/aggregate_cli.py
index 7db041b..f1fcb54 100644
--- a/pgscatalog.calc/src/pgscatalog/calc/cli/aggregate_cli.py
+++ b/pgscatalog.calc/src/pgscatalog/calc/cli/aggregate_cli.py
@@ -2,10 +2,11 @@
 import logging
 import pathlib
 import textwrap
-import operator
-import functools
+from collections import deque
+from typing import Optional
 
-from ..lib.polygenicscore import PolygenicScore
+from ..lib import PolygenicScore
+from pgscatalog.core import chrom_keyfunc
 
 logger = logging.getLogger(__name__)
 
@@ -21,15 +22,50 @@ def run_aggregate():
     if args.verbose:
         logger.setLevel(logging.INFO)
+        logging.getLogger("pgscatalog.core").setLevel(logging.INFO)
+        logging.getLogger("pgscatalog.calc").setLevel(logging.INFO)
 
     if not (outdir := pathlib.Path(args.outdir)).exists():
         raise FileNotFoundError(f"--outdir {outdir.name} doesn't exist")
 
-    score_paths = [pathlib.Path(x) for x in args.scores]
-    pgs = [PolygenicScore(path=x) for x in score_paths]
-    # call __add__ a lot
-    aggregated = functools.reduce(operator.add, pgs)
+    score_paths = sorted([pathlib.Path(x) for x in args.scores], key=chrom_keyfunc())
+    # dfs are only read into memory after accessing them explicitly e.g. pgs[0].df
+    pgs = deque(PolygenicScore(path=x) for x in score_paths)
+
+    observed_columns = set()
+    aggregated: Optional[PolygenicScore] = None
+
+    # first, use PolygenicScore's __add__ method, which implements df.add(fill_value=0)
+    while pgs:
+        # popleft ensures that dfs are removed from memory after each aggregation
+        score: PolygenicScore = pgs.popleft()
+        if aggregated is None:
+            logger.info(f"Initialising aggregation with {score}")
+            aggregated: PolygenicScore = score
+        else:
+            logger.info(f"Adding {score}")
+            aggregated += score
+        observed_columns.update(set(score.df.columns))
+
+    # check to make sure that every column we saw in the dataframes is in the output
+    if (dfcols := set(aggregated.df.columns)) != observed_columns:
+        raise ValueError(
+            f"Missing columns in aggregated file! "
+            f"Observed: {observed_columns}. "
+            f"In aggregated: {dfcols}"
+        )
+    else:
+        logger.info("Aggregated columns match observed columns")
+
+    # next, melt the plink2 scoring files from wide (many columns) format to long format
+    aggregated.melt()
+
+    # recalculate PGS average using aggregated SUM and DENOM
+    aggregated.average()
+
+    logger.info("Aggregation finished! Writing to a file")
     aggregated.write(outdir=args.outdir, split=args.split)
+    logger.info("all done. 
bye :)") def _description_text() -> str: diff --git a/pgscatalog.calc/src/pgscatalog/calc/lib/polygenicscore.py b/pgscatalog.calc/src/pgscatalog/calc/lib/polygenicscore.py index 62bc6af..1fe22f8 100644 --- a/pgscatalog.calc/src/pgscatalog/calc/lib/polygenicscore.py +++ b/pgscatalog.calc/src/pgscatalog/calc/lib/polygenicscore.py @@ -274,12 +274,17 @@ class PolygenicScore: >>> aggregated_score = pgs1 + pgs2 >>> aggregated_score # doctest: +ELLIPSIS - PolygenicScore(sampleset='test', path=None) + PolygenicScore(sampleset='test', path='(in-memory)') Once a score has been fully aggregated it can be helpful to recalculate an average: - >>> reprlib.repr(aggregated_score.average().to_dict()) # doctest: +ELLIPSIS - "{'DENOM': {('test', 'HG00096'): 3128, ('test', 'HG00097'): 3128, ('test', 'HG00099'): 3128, ('test', 'HG00100'): 3128, ...}, 'PGS001229_22_AVG': {('test', 'HG00096'): 0.0003484782608695652, ('test', 'HG00097'): 0.00043120268542199493, ('test', 'HG00099'): 0.0004074616368286445, ('test', 'HG00100'): 0.0005523938618925831, ...}}" + >>> aggregated_score.average() + >>> aggregated_score.df # doctest: +ELLIPSIS,+NORMALIZE_WHITESPACE + PGS SUM DENOM AVG + sampleset IID + test HG00096 PGS001229_22 1.090040 3128 0.000348 + HG00097 PGS001229_22 1.348802 3128 0.000431 + ... Scores can be written to a TSV file: @@ -321,14 +326,19 @@ def __init__(self, *, path=None, df=None, sampleset=None): if self.sampleset is None: raise TypeError("Missing sampleset") - self._chunksize = 50000 self._df = df + self._melted = False def __repr__(self): - return f"{type(self).__name__}(sampleset={repr(self.sampleset)}, path={repr(self.path)})" + if self.path is None: + path = repr("(in-memory)") + else: + path = repr(self.path) + return f"{type(self).__name__}(sampleset={repr(self.sampleset)}, path={path})" def __add__(self, other): if isinstance(other, PolygenicScore): + logger.info(f"Doing element-wise addition: {self} + {other}") sumdf = self.df.add(other.df, fill_value=0) return PolygenicScore(sampleset=self.sampleset, df=sumdf) else: @@ -361,32 +371,38 @@ def read(self): return df def average(self): - """Recalculate average.""" + """Update the dataframe with a recalculated average.""" + logger.info("Recalculating average") + if not self._melted: + self.melt() + df = self.df - avgs = df.filter(regex="SUM$") - avgs = avgs.divide(df.DENOM, axis=0) - avgs.insert(0, "DENOM", df.DENOM) - avgs.columns = avgs.columns.str.replace("_SUM", "_AVG") - return avgs + df["AVG"] = df.SUM / df.DENOM + self._df = df def melt(self): - """Melt dataframe from wide format to long format""" - sum_df = _melt(self.df, value_name="SUM") - avg_df = _melt(self.average(), value_name="AVG") - df = pd.concat([sum_df, avg_df.AVG], axis=1) + """Update the dataframe with a melted version (wide format to long format)""" + logger.info("Melting dataframe from wide to long format") + df = self.df.melt( + id_vars=["DENOM"], + value_name="SUM", + var_name="PGS", + ignore_index=False, + ) + # e.g. 
PGS000822_SUM -> PGS000822 + df["PGS"] = df["PGS"].str.replace("_SUM", "") # melted chunks need a consistent column order - return df[["PGS", "SUM", "DENOM", "AVG"]] + self._df = df[["PGS", "SUM", "DENOM"]] + self._melted = True - def write(self, outdir, split=False, melt=True): + def write(self, outdir, split=False): """Write PGS to a compressed TSV""" outdir = pathlib.Path(outdir) - if melt: - logger.info("Melting before write to TSV") - df = self.melt() - else: - logger.info("Writing wide format to TSV") - df = self.df + if not self._melted: + self.melt() + + df = self.df if split: logger.info("Writing results split by sampleset") @@ -408,15 +424,3 @@ def _select_agg_cols(cols): for x in cols if (x.endswith("_SUM") and (x != "NAMED_ALLELE_DOSAGE_SUM")) or (x in keep_cols) ] - - -def _melt(df, value_name): - df = df.melt( - id_vars=["DENOM"], - value_name=value_name, - var_name="PGS", - ignore_index=False, - ) - # e.g. PGS000822_SUM -> PGS000822 - df["PGS"] = df["PGS"].str.replace(f"_{value_name}", "") - return df diff --git a/pgscatalog.core/poetry.lock b/pgscatalog.core/poetry.lock index a014aa4..68c277d 100644 --- a/pgscatalog.core/poetry.lock +++ b/pgscatalog.core/poetry.lock @@ -543,6 +543,21 @@ files = [ {file = "MarkupSafe-2.1.5.tar.gz", hash = "sha256:d283d37a890ba4c1ae73ffadf8046435c76e7bc2247bbb63c00bd1a709c6544b"}, ] +[[package]] +name = "natsort" +version = "8.4.0" +description = "Simple yet flexible natural sorting in Python." +optional = false +python-versions = ">=3.7" +files = [ + {file = "natsort-8.4.0-py3-none-any.whl", hash = "sha256:4732914fb471f56b5cce04d7bae6f164a592c7712e1c85f9ef585e197299521c"}, + {file = "natsort-8.4.0.tar.gz", hash = "sha256:45312c4a0e5507593da193dedd04abb1469253b601ecaf63445ad80f0a1ea581"}, +] + +[package.extras] +fast = ["fastnumbers (>=2.0.0)"] +icu = ["PyICU (>=1.0.0)"] + [[package]] name = "packaging" version = "24.0" @@ -1074,4 +1089,4 @@ cffi = ["cffi (>=1.11)"] [metadata] lock-version = "2.0" python-versions = "^3.11" -content-hash = "a27d4c1788784e8fa4eda68d877292c1129f816067a577bcf72f0f1aaed3e02c" +content-hash = "dc7d2236db352d5e2b633c3cd949ce93176a65292802d74c668d9f9b052fe7e2" diff --git a/pgscatalog.core/pyproject.toml b/pgscatalog.core/pyproject.toml index fa2ca1a..8db87d5 100644 --- a/pgscatalog.core/pyproject.toml +++ b/pgscatalog.core/pyproject.toml @@ -21,6 +21,7 @@ tenacity = "^8.2.3" pyliftover = "^0.4" xopen = {version = "^1.8.0", extras = ["zstd"]} tqdm = "^4.66.1" +natsort = "^8.4.0" [tool.poetry.group.dev.dependencies] pytest = "^7.4.4" diff --git a/pgscatalog.core/src/pgscatalog/core/__init__.py b/pgscatalog.core/src/pgscatalog/core/__init__.py index 9444fcc..4192465 100644 --- a/pgscatalog.core/src/pgscatalog/core/__init__.py +++ b/pgscatalog.core/src/pgscatalog/core/__init__.py @@ -34,6 +34,8 @@ RelabelArgs, relabel, relabel_write, + effect_type_keyfunc, + chrom_keyfunc, ) log_fmt = "%(name)s: %(asctime)s %(levelname)-8s %(message)s" @@ -74,6 +76,8 @@ "RelabelArgs", "relabel", "relabel_write", + "effect_type_keyfunc", + "chrom_keyfunc", ] __version__ = "0.1.2" diff --git a/pgscatalog.core/src/pgscatalog/core/lib/__init__.py b/pgscatalog.core/src/pgscatalog/core/lib/__init__.py index 665416d..aa3877e 100644 --- a/pgscatalog.core/src/pgscatalog/core/lib/__init__.py +++ b/pgscatalog.core/src/pgscatalog/core/lib/__init__.py @@ -5,6 +5,7 @@ from .genomebuild import GenomeBuild from .targetvariants import TargetVariants, TargetVariant, TargetType from ._relabel import RelabelArgs, relabel, relabel_write +from ._sortpaths 
import effect_type_keyfunc, chrom_keyfunc from .pgsexceptions import ( BasePGSException, MatchError, @@ -59,4 +60,6 @@ "RelabelArgs", "relabel", "relabel_write", + "effect_type_keyfunc", + "chrom_keyfunc", ] diff --git a/pgscatalog.core/src/pgscatalog/core/lib/_sortpaths.py b/pgscatalog.core/src/pgscatalog/core/lib/_sortpaths.py new file mode 100644 index 0000000..6757bb9 --- /dev/null +++ b/pgscatalog.core/src/pgscatalog/core/lib/_sortpaths.py @@ -0,0 +1,29 @@ +""" This module assumes you're working with paths that follow the format: + +{sampleset}_{chrom}_{effect_type}_{n} +""" +from natsort import natsort_keygen, ns + + +def effect_type_keyfunc(): + """Return a key that sorts by effect type and n. Chromosome order doesn't matter. + + This is useful for things like itertools.groupby which expect sorted input + + >>> import pathlib + >>> paths = [pathlib.Path("ukb_2_dominant_0.txt.gz"), pathlib.Path("ukb_X_additive_0.txt.gz"), pathlib.Path("ukb_X_additive_1.txt.gz"), pathlib.Path("ukb_1_recessive_0.txt.gz")] + >>> sorted(paths, key=effect_type_keyfunc()) + [PosixPath('ukb_X_additive_0.txt.gz'), PosixPath('ukb_X_additive_1.txt.gz'), PosixPath('ukb_2_dominant_0.txt.gz'), PosixPath('ukb_1_recessive_0.txt.gz')] + """ + return natsort_keygen(key=lambda x: x.stem.split("_")[2:], alg=ns.REAL) + + +def chrom_keyfunc(): + """Return a key that sorts by chromosome, including non-integer chromosomes + + >>> import pathlib + >>> paths = [pathlib.Path("ukb_2_additive_0.txt.gz"), pathlib.Path("ukb_X_additive_0.txt.gz"), pathlib.Path("ukb_1_additive_0.txt.gz")] + >>> sorted(paths, key=chrom_keyfunc()) + [PosixPath('ukb_1_additive_0.txt.gz'), PosixPath('ukb_2_additive_0.txt.gz'), PosixPath('ukb_X_additive_0.txt.gz')] + """ + return natsort_keygen(key=lambda x: x.stem.split("_")[1], alg=ns.REAL) diff --git a/pgscatalog.match/poetry.lock b/pgscatalog.match/poetry.lock index 41f365f..5d47435 100644 --- a/pgscatalog.match/poetry.lock +++ b/pgscatalog.match/poetry.lock @@ -22,13 +22,13 @@ trio = ["trio (>=0.23)"] [[package]] name = "certifi" -version = "2024.2.2" +version = "2024.6.2" description = "Python package for providing Mozilla's CA Bundle." optional = false python-versions = ">=3.6" files = [ - {file = "certifi-2024.2.2-py3-none-any.whl", hash = "sha256:dc383c07b76109f368f6106eee2b593b04a011ea4d55f652c6ca24a754d1cdd1"}, - {file = "certifi-2024.2.2.tar.gz", hash = "sha256:0569859f95fc761b18b45ef421b1290a0f65f147e92a1e5eb3e635f9a5e4e66f"}, + {file = "certifi-2024.6.2-py3-none-any.whl", hash = "sha256:ddc6c8ce995e6987e7faf5e3f1b02b302836a0e5d98ece18392cb1a36c72ad56"}, + {file = "certifi-2024.6.2.tar.gz", hash = "sha256:3cd43f1c6fa7dedc5899d69d3ad0398fd018ad1a17fba83ddaf78aa46c747516"}, ] [[package]] @@ -300,6 +300,21 @@ files = [ {file = "isal-1.6.1.tar.gz", hash = "sha256:7b64b75d260b544beea3f59cb25a6f520c04768818ef4ac316ee9a1f2ebf18f5"}, ] +[[package]] +name = "natsort" +version = "8.4.0" +description = "Simple yet flexible natural sorting in Python." 
+optional = false +python-versions = ">=3.7" +files = [ + {file = "natsort-8.4.0-py3-none-any.whl", hash = "sha256:4732914fb471f56b5cce04d7bae6f164a592c7712e1c85f9ef585e197299521c"}, + {file = "natsort-8.4.0.tar.gz", hash = "sha256:45312c4a0e5507593da193dedd04abb1469253b601ecaf63445ad80f0a1ea581"}, +] + +[package.extras] +fast = ["fastnumbers (>=2.0.0)"] +icu = ["PyICU (>=1.0.0)"] + [[package]] name = "numpy" version = "1.26.4" @@ -367,6 +382,7 @@ develop = true [package.dependencies] httpx = "^0.26.0" +natsort = "^8.4.0" pyliftover = "^0.4" tenacity = "^8.2.3" tqdm = "^4.66.1" @@ -502,13 +518,13 @@ files = [ [[package]] name = "pytest" -version = "8.2.1" +version = "8.2.2" description = "pytest: simple powerful testing with Python" optional = false python-versions = ">=3.8" files = [ - {file = "pytest-8.2.1-py3-none-any.whl", hash = "sha256:faccc5d332b8c3719f40283d0d44aa5cf101cec36f88cde9ed8f2bc0538612b1"}, - {file = "pytest-8.2.1.tar.gz", hash = "sha256:5046e5b46d8e4cac199c373041f26be56fdb81eb4e67dc11d4e10811fc3408fd"}, + {file = "pytest-8.2.2-py3-none-any.whl", hash = "sha256:c434598117762e2bd304e526244f67bf66bbd7b5d6cf22138be51ff661980343"}, + {file = "pytest-8.2.2.tar.gz", hash = "sha256:de4bb8104e201939ccdc688b27a89a7be2079b22e2bd2b07f806b6ba71117977"}, ] [package.dependencies] @@ -720,4 +736,4 @@ cffi = ["cffi (>=1.11)"] [metadata] lock-version = "2.0" python-versions = "^3.11" -content-hash = "a117ec9e876ddfc24356d7eb44749bdca6a3d2dda049f0157cb4f1f87a0c6de1" +content-hash = "3e56c3152459fb66ba8883f066806f48504c64a53fc9fbfd061ad51d1fdbedc7" diff --git a/pgscatalog.match/pyproject.toml b/pgscatalog.match/pyproject.toml index 5604ce6..171ede2 100644 --- a/pgscatalog.match/pyproject.toml +++ b/pgscatalog.match/pyproject.toml @@ -10,7 +10,7 @@ packages = [ [tool.poetry.dependencies] python = "^3.11" -polars = "^0.20.5" +polars = "0.20.30" pyarrow = "^15.0.0" "pgscatalog.core" = {path = "../pgscatalog.core", develop = true} diff --git a/pgscatalog.match/src/pgscatalog/match/__init__.py b/pgscatalog.match/src/pgscatalog/match/__init__.py index 7acda8d..f2458bf 100644 --- a/pgscatalog.match/src/pgscatalog/match/__init__.py +++ b/pgscatalog.match/src/pgscatalog/match/__init__.py @@ -1,9 +1,9 @@ import logging -from .lib.variantframe import VariantFrame -from .lib.scoringfileframe import ScoringFileFrame, match_variants -from .lib.matchresult import MatchResult, MatchResults -from .lib.plinkscorefiles import PlinkScoreFiles +from .lib import VariantFrame +from .lib import ScoringFileFrame, match_variants +from .lib import MatchResult, MatchResults +from .lib import PlinkScoreFiles log_fmt = "%(name)s: %(asctime)s %(levelname)-8s %(message)s" diff --git a/pgscatalog.match/src/pgscatalog/match/lib/matchresult.py b/pgscatalog.match/src/pgscatalog/match/lib/matchresult.py index 5504aae..7661d64 100644 --- a/pgscatalog.match/src/pgscatalog/match/lib/matchresult.py +++ b/pgscatalog.match/src/pgscatalog/match/lib/matchresult.py @@ -287,6 +287,7 @@ def write_scorefiles( # double check log count vs scoring file variant count self._log_OK = check_log_count(scorefile=score_df, summary_log=self.summary_log) + # will be empty if no scores pass match threshold, so nothing gets written plink = PlinkFrames.from_matchresult(self.df) outfs = [] for frame in plink: diff --git a/pgscatalog.match/src/pgscatalog/match/lib/plinkscorefiles.py b/pgscatalog.match/src/pgscatalog/match/lib/plinkscorefiles.py index a8baac4..db92540 100644 --- a/pgscatalog.match/src/pgscatalog/match/lib/plinkscorefiles.py +++ 
b/pgscatalog.match/src/pgscatalog/match/lib/plinkscorefiles.py @@ -10,6 +10,7 @@ import polars as pl +from pgscatalog.core import chrom_keyfunc, effect_type_keyfunc logger = logging.getLogger(__name__) @@ -68,28 +69,32 @@ def merge(self, directory): if dataset not in x.stem: raise ValueError(f"Invalid dataset: {dataset} and {x.stem}") - def effect_type_sort(path): - """Sort by effect type and n""" - return path.stem.split("_")[2:] + sorted_paths = sorted(self._elements, key=effect_type_keyfunc()) - def chrom_sort(path): - try: - return int(path.name.split("_")[1]) - except ValueError: - return path.name.split("_")[1] + for k, g in itertools.groupby(sorted_paths, key=effect_type_keyfunc()): + logger.info(f"Writing combined scoring file {k}") - sorted_paths = sorted(self._elements, key=effect_type_sort) + # tidy up keys to create the output file name + # keyfunc returns: + # (('additive',), ('', 0.0, 'scorefile')) + # need: [additive, 0] + keys = (x for x in (itertools.chain(*k)) if x != "") + keys = list(str(int(x)) if isinstance(x, float) else x for x in keys) + keys.pop() # drop scorefile - for k, g in itertools.groupby(sorted_paths, key=effect_type_sort): - logger.info(f"Writing combined scoring file {k}") # multi-chrom -> ALL - fout = "_".join([dataset, "ALL", *k]) + ".gz" - paths = sorted(list(g), key=chrom_sort) + fout = "_".join([dataset, "ALL", *keys]) + ".scorefile.gz" + paths = sorted(list(g), key=chrom_keyfunc()) + # infer_schema_length: read all columns as utf8 to simplify joining dfs = (pl.read_csv(x, separator="\t", infer_schema_length=0) for x in paths) # diagonal concat is important to handle different column sets across dfs df = pl.concat(dfs, how="diagonal").fill_null(value="0") + logger.info("Score files combined successfully") - with gzip.open(pathlib.Path(directory) / fout, "wb") as gcsv: + with gzip.open( + pathlib.Path(directory) / fout, "wb", compresslevel=6 + ) as gcsv: + logger.info(f"Writing out to {fout}") outf = io.TextIOWrapper(gcsv) df.write_csv(outf, separator="\t") diff --git a/pgscatalog.match/tests/test_merge_cli.py b/pgscatalog.match/tests/test_merge_cli.py index 0beadb2..1c7bd0c 100644 --- a/pgscatalog.match/tests/test_merge_cli.py +++ b/pgscatalog.match/tests/test_merge_cli.py @@ -143,3 +143,6 @@ def test_strict_merge(tmp_path_factory, good_scorefile, match_ipc): with pytest.raises(ZeroMatchesError): with patch("sys.argv", flargs): run_merge() + + # don't write any scoring files + assert glob(str(outdir / "*scorefile.gz")) == []
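
The memory improvement in aggregate_cli.py comes from consuming a deque of lazily-read scores with popleft(), so each per-chromosome dataframe can be garbage-collected as soon as it has been folded into the running total. A minimal standalone sketch of that pattern, using a toy Score class and made-up values rather than the real PolygenicScore:

    from collections import deque

    import pandas as pd


    class Score:
        """Toy stand-in for PolygenicScore: wraps a dataframe of SUM columns."""

        def __init__(self, df):
            self.df = df

        def __add__(self, other):
            # fill_value=0 treats columns missing from one frame as zero, so
            # chunks with different score columns still combine cleanly
            return Score(self.df.add(other.df, fill_value=0))


    # pretend each element was read from one per-chromosome file (made-up values)
    chunks = deque(
        Score(pd.DataFrame({"PGS1_SUM": [1.0, 2.0]}, index=["HG00096", "HG00097"]))
        for _ in range(3)
    )

    aggregated = None
    while chunks:
        # popleft drops the deque's reference to each chunk once it is added,
        # so at most two dataframes are alive at any point in the loop
        score = chunks.popleft()
        aggregated = score if aggregated is None else aggregated + score

    print(aggregated.df)  # PGS1_SUM holds the element-wise totals: 3.0 and 6.0

The reduce(operator.add, pgs) version it replaces kept the whole list alive for the duration of the fold; the deque version only ever holds the accumulator plus the chunk being added.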
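PolygenicScore.melt() and average() together reshape the wide plink2 SUM columns to long format and then recompute AVG from the aggregated SUM and DENOM. A small pandas sketch of the same two steps, with hypothetical values in the .sscore-style naming the patch assumes:

    import pandas as pd

    # one aggregated wide-format frame (made-up values)
    wide = pd.DataFrame(
        {"DENOM": [3128, 3128], "PGS001229_22_SUM": [1.090040, 1.348802]},
        index=pd.Index(["HG00096", "HG00097"], name="IID"),
    )

    # melt: one row per (sample, score), keeping DENOM alongside each SUM
    long = wide.melt(
        id_vars=["DENOM"], value_name="SUM", var_name="PGS", ignore_index=False
    )
    long["PGS"] = long["PGS"].str.replace("_SUM", "")  # PGS001229_22_SUM -> PGS001229_22
    # average: recalculated from aggregated SUM / DENOM, as in average()
    long["AVG"] = long["SUM"] / long["DENOM"]
    print(long[["PGS", "SUM", "DENOM", "AVG"]])

Recomputing AVG after aggregation (rather than summing per-file averages) is what keeps the value correct: the denominators only become final once every chromosome has been added.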
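In PlinkScoreFiles.merge(), how="diagonal" is what lets per-chromosome scoring files with different score columns be concatenated: polars unions the column sets, fills the gaps with nulls, and fill_null then backfills them with "0". A sketch with two hypothetical single-row files (column names are illustrative, and everything is utf8 to mirror infer_schema_length=0):

    import polars as pl

    chr1 = pl.DataFrame({"ID": ["1:123:A:G"], "A1": ["A"], "PGS001229_22": ["0.5"]})
    chrX = pl.DataFrame({"ID": ["X:456:T:C"], "A1": ["T"], "PGS000822": ["0.1"]})

    # diagonal concat unions the column sets; missing cells become null
    merged = pl.concat([chr1, chrX], how="diagonal").fill_null(value="0")
    print(merged)  # 2 rows x 4 columns, "0" where a file lacked a score column
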