Skip to content

Commit

Permalink
General usability updates to speed up VIPRS/VIPRSGrid and provide mor…
Browse files Browse the repository at this point in the history
…e control to the user.
  • Loading branch information
shz9 committed Jun 3, 2024
1 parent 3722488 commit 4eee62f
Show file tree
Hide file tree
Showing 13 changed files with 318 additions and 170 deletions.
22 changes: 22 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,28 @@ All notable changes to this project will be documented in this file.
The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/),
and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0.html).

## [0.1.2] - 2024-06-03

### Changed

- Fixed bug in implementation of `.fit` method of VIPRS models. Specifically,
there was an issue with the `continued=True` flag not working because the `OptimizeResult`
object wasn't refreshed.
- Replaced `print` statements with `logging` where appropriate (still needs some more work).
- Updated way we measure peak memory in `viprs_fit`
- Updated `dict_concat` to just return the element if there's a single entry.
- Refactored pars of `VIPRS` to cache some recurring computations.
- Updated `VIPRSBMA` & `VIPRSGridSearch` to only consider models that
successfully converged.

### Added

- Added SNP position to output table from VIPRS objects.
- Added measure of time taken to prepare data in `viprs_fit`.
- Added option to keep long-range LD regions in `viprs_fit`.
- Added convergence check based on parameter values.
- Added separate method for initializing optimization-related objects.

## [0.1.1] - 2024-04-24

### Changed
Expand Down
2 changes: 1 addition & 1 deletion benchmarks/benchmark_e_step.py
Original file line number Diff line number Diff line change
Expand Up @@ -528,7 +528,7 @@ def exec_func():
dfs[d] = di

output_fname = (f"{args.file_prefix}timing_results_imp{args.implementation}_model{args.model}_"
f"lm{args.low_memory}_pr{args.float_precision}_threads{t}.csv")
f"lm{args.low_memory}_dq{args.dequantize_on_the_fly}_pr{args.float_precision}_threads{t}.csv")

# Calculate time per-iteration:
dfs['TimePerIteration'] = dfs['Time'] / dfs['Repeats']
Expand Down
46 changes: 33 additions & 13 deletions bin/viprs_fit
Original file line number Diff line number Diff line change
Expand Up @@ -139,7 +139,7 @@ def init_data(args, verbose=True):
from magenpy.parsers.sumstats_parsers import SumstatsParser

if verbose:
print('\n{:-^62}\n'.format(' Reading input data '))
print('\n{:-^62}\n'.format(' Reading & harmonizing input data '))

# Prepare the summary statistics parsers:
if args.sumstats_format == 'custom':
Expand All @@ -159,6 +159,13 @@ def init_data(args, verbose=True):
temp_dir=args.temp_dir,
verbose=verbose,
threads=args.threads)

# Unless the user explicitly decides to keep them, filter long-range LD regions:
if not args.keep_lrld:
print("> Filtering long-range LD regions...")
for ld in gdl.ld.values():
ld.filter_long_range_ld_regions()

# Read the summary statistics file(s):
gdl.read_summary_statistics(args.sumstats_path,
sumstats_format=ss_format,
Expand Down Expand Up @@ -277,12 +284,14 @@ def prepare_model(args, verbose=True):
p_model = partial(VIPRS,
float_precision=args.float_precision,
low_memory=not args.use_symmetric_ld,
dequantize_on_the_fly=args.dequantize_on_the_fly,
threads=args.threads)
elif args.model == 'VIPRSMix':
p_model = partial(VIPRSMix,
K=args.n_components,
float_precision=args.float_precision,
low_memory=not args.use_symmetric_ld,
dequantize_on_the_fly=args.dequantize_on_the_fly,
threads=args.threads)

elif args.hyp_search in ('BMA', 'GS'):
Expand All @@ -302,13 +311,15 @@ def prepare_model(args, verbose=True):
grid=grid,
float_precision=args.float_precision,
low_memory=not args.use_symmetric_ld,
dequantize_on_the_fly=args.dequantize_on_the_fly,
threads=args.threads)

else:

base_model = partial(VIPRS,
float_precision=args.float_precision,
low_memory=not args.use_symmetric_ld,
dequantize_on_the_fly=args.dequantize_on_the_fly,
threads=args.threads)

p_model = partial(BayesOpt,
Expand All @@ -332,7 +343,6 @@ def fit_model(model, data_dict, args):

import time
import numpy as np
from magenpy.utils.system_utils import get_memory_usage

# Set the random seed:
np.random.seed(args.seed)
Expand All @@ -351,9 +361,6 @@ def fit_model(model, data_dict, args):

# ----------------------------------------------------------

# Get the memory usage before loading the model:
start_mem = get_memory_usage()

# Initialize the model:
load_start_time = time.time()
m = model(data_dict['train'])
Expand All @@ -365,17 +372,13 @@ def fit_model(model, data_dict, args):

# Fit the model to data:
fit_start_time = time.time()
m = m.fit()
m = m.fit(max_iter=args.max_iter)
fit_end_time = time.time()

# ----------------------------------------------------------

# Get memory usage after fitting the model:
end_mem = get_memory_usage()

# Record the profiler metrics:
result_dict['ProfilerMetrics']['Fit_time'] = round(fit_end_time - fit_start_time, 2)
result_dict['ProfilerMetrics']['Memory_usage_MB'] = round(end_mem - start_mem, 2)
result_dict['ProfilerMetrics']['Total_Iterations'] = m.optim_result.iterations

# ----------------------------------------------------------
Expand Down Expand Up @@ -517,9 +520,14 @@ def main():
parser.add_argument('--use-symmetric-ld', dest='use_symmetric_ld', action='store_true',
default=False,
help='Use the symmetric form of the LD matrix when fitting the model.')
parser.add_argument('--dequantize-on-the-fly', dest='dequantize_on_the_fly', action='store_true',
default=False,
help='Dequantize the entries of the LD matrix on-the-fly when fitting the model.')
parser.add_argument('--n-components', dest='n_components', type=int, default=3,
help='The number of non-null Gaussian mixture components to use with the VIPRSMix model '
'(i.e. excluding the spike component).')
parser.add_argument('--max-iter', dest='max_iter', type=int, default=1000,
help='The maximum number of iterations to run the coordinate ascent algorithm.')

# Arguments for Hyperparameter tuning / model initialization:
parser.add_argument('--h2-est', dest='h2_est', type=float,
Expand Down Expand Up @@ -556,6 +564,8 @@ def main():

parser.add_argument('--genomewide', dest='genomewide', action='store_true', default=False,
help='Fit all chromosomes jointly')
parser.add_argument('--keep-lrld', dest='keep_lrld', action='store_true', default=False,
help='Keep the Long Range LD regions during inference (these regions are filtered by default).')
parser.add_argument('--backend', dest='backend', type=str.lower, default='xarray',
choices={'xarray', 'plink'},
help='The backend software used for computations on the genotype matrix.')
Expand All @@ -578,11 +588,14 @@ def main():
import time
from datetime import timedelta
import pandas as pd
import os.path as osp
from magenpy.utils.system_utils import makedir
import numpy as np
from magenpy.utils.system_utils import makedir, get_peak_memory_usage
from joblib import Parallel, delayed
from joblib.externals.loky import get_reusable_executor

# ----------------------------------------------------------
# Print the parsed arguments:

print('{:-^62}\n'.format(' Parsed arguments '))

for key, val in vars(args).items():
Expand All @@ -599,6 +612,8 @@ def main():

# (2) Read the data:
data_loaders = init_data(args)
# Record time for data preparation:
data_prep_time = time.time()

# (3) Prepare the model:
model = prepare_model(args)
Expand All @@ -611,6 +626,9 @@ def main():
for idx, dl in enumerate(data_loaders)
)

# Shut down the parallel executor:
get_reusable_executor().shutdown(wait=True)

# Record end time:
total_end_time = time.time()

Expand All @@ -634,6 +652,8 @@ def main():
profm_table = pd.concat([pd.DataFrame(r['ProfilerMetrics'], index=[0]).assign(Chromosome=r['Chromosome'])
for r in fit_results])
profm_table['Total_WallClockTime'] = round(total_end_time - total_start_time, 2)
profm_table['DataPrep_Time'] = round(data_prep_time - total_start_time, 2)
profm_table['Peak_Memory_MB'] = round(get_peak_memory_usage(include_children=True) or np.nan, 2)

output_prefix = osp.join(args.output_dir, args.output_prefix + args.model + '_' + args.hyp_search)

Expand All @@ -647,7 +667,7 @@ def main():
valid_tables.to_csv(output_prefix + '.validation', sep="\t", index=False)

if args.output_profiler_metrics:
profm_table.to_csv(output_prefix + '.time', sep="\t", index=False)
profm_table.to_csv(output_prefix + '.prof', sep="\t", index=False)

print('>>> Total Runtime:\n', timedelta(seconds=total_end_time - total_start_time))

Expand Down
2 changes: 1 addition & 1 deletion requirements.txt
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
numpy
scipy
pandas
pandas<=2.2.1 # Seen installation issues with newer versions
tqdm
magenpy>=0.1
statsmodels
Expand Down
2 changes: 1 addition & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -239,7 +239,7 @@ def no_cythonize(cy_extensions, **_ignore):

setup(
name="viprs",
version="0.1.1",
version="0.1.2",
author="Shadi Zabad",
author_email="shadi.zabad@mail.mcgill.ca",
description="Variational Inference of Polygenic Risk Scores (VIPRS)",
Expand Down
4 changes: 2 additions & 2 deletions viprs/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,5 +5,5 @@
from .model.gridsearch.HyperparameterGrid import HyperparameterGrid
from .utils.data_utils import *

__version__ = '0.1.1'
__release_date__ = 'April 2024'
__version__ = '0.1.2'
__release_date__ = 'May 2024'
7 changes: 6 additions & 1 deletion viprs/model/BayesPRSModel.py
Original file line number Diff line number Diff line change
Expand Up @@ -237,7 +237,7 @@ def harmonize_data(self, gdl=None, parameter_table=None):

return pip, post_mean_beta, post_var_beta

def to_table(self, col_subset=('CHR', 'SNP', 'A1', 'A2'), per_chromosome=False):
def to_table(self, col_subset=('CHR', 'SNP', 'POS', 'A1', 'A2'), per_chromosome=False):
"""
Output the posterior estimates for the effect sizes to a pandas dataframe.
:param col_subset: The subset of columns to include in the tables (in addition to the effect sizes).
Expand Down Expand Up @@ -329,6 +329,11 @@ def read_inferred_parameters(self, f_names, sep=r"\s+"):
def write_inferred_parameters(self, f_name, per_chromosome=False, sep="\t"):
"""
A convenience method to write the inferred posterior for the effect sizes to file.
TODO:
* Support outputting scoring files compatible with PGS catalog format:
https://www.pgscatalog.org/downloads/#dl_scoring_files
:param f_name: The filename (or directory) where to write the effect sizes
:param per_chromosome: If True, write a file for each chromosome separately.
:param sep: The delimiter for the file (tab by default).
Expand Down
Loading

0 comments on commit 4eee62f

Please sign in to comment.