-
Notifications
You must be signed in to change notification settings - Fork 4
/
Copy pathcore_mama.py
executable file
·206 lines (153 loc) · 8.11 KB
/
core_mama.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
#!/usr/bin/env python3
"""
Python functions that implement the core MAMA processing
"""
import gc
from typing import Tuple
import numpy as np
# Functions ##################################
#################################
def create_omega_matrix(ldscores: np.ndarray, reg_ldscore_coefs: np.ndarray) -> np.ndarray:
    """
    Build the Omega matrix for every SNP from LD scores and LD score regression coefficients.

    The result is the element-wise product of the two inputs, so NumPy broadcasting lets a
    single PxP coefficient matrix scale every PxP slice of an MxPxP stack of LD scores. Both
    inputs must list the ancestries in the same order.

    :param ldscores: (Mx)PxP symmetric matrices of LD scores (one PxP matrix per SNP)
    :param reg_ldscore_coefs: PxP symmetric matrix of LD score regression coefficients
    :return: (Mx)PxP Omega matrices as defined in the MAMA paper
    """
    omega = np.multiply(reg_ldscore_coefs, ldscores)
    return omega
#################################
def tweak_omega(omega_slice: np.ndarray) -> np.ndarray:
    """
    Nudge a non-positive-semi-definite Omega matrix until it is positive semi-definite.

    Two steps are applied:
      1. every entry is capped from above at sqrt(omega_ii * omega_jj), and
      2. the whole matrix is repeatedly multiplied by 0.99, with the original diagonal written
         back after each pass, until np.linalg.eigvalsh reports no negative eigenvalue.
    Callers must check ahead of time that this will converge (e.g. all diagonal elements
    must be positive). The input array is not modified; a new array is returned.

    :param omega_slice: PxP symmetric Omega matrix
    :return np.ndarray: Adjusted PxP Omega matrix that is positive semi-definite
    """
    saved_diag = np.diag(omega_slice).copy()
    root_diag = np.sqrt(saved_diag)

    # Cap entries from above by sqrt(omega_ii * omega_jj); np.minimum allocates a new array.
    # NOTE(review): this is a one-sided cap - entries below -sqrt(omega_ii * omega_jj) are not
    # clipped here and are only pulled in by the shrink loop below; confirm this is intended.
    cap = np.outer(root_diag, root_diag)
    adjusted = np.minimum(cap, omega_slice)

    diag_positions = np.diag_indices_from(adjusted)
    shrink_factor = 0.99

    def _still_indefinite(matrix: np.ndarray) -> bool:
        # Any strictly negative eigenvalue means the matrix is not yet PSD
        return bool((np.linalg.eigvalsh(matrix) < 0.0).any())

    # Shrink geometrically; restoring the diagonal each pass means only off-diagonals shrink
    while _still_indefinite(adjusted):
        adjusted *= shrink_factor
        adjusted[diag_positions] = saved_diag
    return adjusted
#################################
def qc_omega(omega: np.ndarray) -> Tuple[np.ndarray, np.ndarray]:
    """
    Check each per-SNP Omega matrix for positive semi-definiteness, repairing it where possible.

    A slice whose eigenvalues (np.linalg.eigvalsh) are all >= 0 is accepted as is. A slice that
    fails that test but has no diagonal entry <= 0 is replaced IN PLACE in ``omega`` by the
    output of ``tweak_omega`` and then counted as positive semi-definite. Any other slice is left
    untouched and flagged as not positive semi-definite.

    :param omega: MxPxP stack of Omega matrices (M = number of SNPs); modified in place for
                  slices that get tweaked
    :return: Tuple containing:
             1) Boolean array of length M, True where the (possibly tweaked) slice is positive
                semi-definite and False otherwise
             2) Boolean array of length M, True where the slice had to be tweaked to become
                positive semi-definite (False otherwise)
    """
    n_snps = omega.shape[0]
    is_psd = np.zeros(n_snps, dtype=bool)
    was_tweaked = np.zeros(n_snps, dtype=bool)

    for snp, current in enumerate(omega):
        if (np.linalg.eigvalsh(current) >= 0.0).all():
            is_psd[snp] = True
        elif not (np.diag(current) <= 0.0).any():
            # A NaN diagonal entry is not <= 0, so such slices are also routed to tweak_omega
            omega[snp, :, :] = tweak_omega(current)
            is_psd[snp] = True
            was_tweaked[snp] = True
    return is_psd, was_tweaked
#################################
def create_sigma_matrix(sumstat_ses: np.ndarray, reg_se2_coefs: np.ndarray,
                        reg_const_coefs: np.ndarray) -> np.ndarray:
    """
    Creates the Sigma matrix for each SNP from the summary-statistic standard errors and the
    SE^2 / constant-term regression coefficients (all products element-wise):

        Sigma_j = reg_se2_coefs * outer(se_j, se_j) + reg_const_coefs

    Assumes the columns of sumstat_ses and the rows/columns of both PxP coefficient matrices
    use the same ordering of ancestries.

    :param sumstat_ses: Standard errors for the SNPs for each population (M x P matrix). Any
                        array-like is accepted; integer/boolean input is promoted to float64.
    :param reg_se2_coefs: PxP symmetric matrix containing SE^2 regression coefficients
    :param reg_const_coefs: PxP symmetric matrix containing Constant term regression coefficients
    :return: The Sigma matrices as indicated in the MAMA paper (PxP per SNP) = MxPxP
    """
    ses = np.asarray(sumstat_ses)
    # The in-place updates below avoid allocating extra MxPxP temporaries, but in-place ufuncs
    # cannot upcast: an integer array would raise when multiplied by float coefficients.
    if not np.issubdtype(ses.dtype, np.inexact):
        ses = ses.astype(np.float64)

    # Per-SNP outer product se_j se_j^T, shape MxPxP
    result_matrix = ses[:, :, np.newaxis] * ses[:, np.newaxis, :]
    # Incorporate the regression coefficients (PxP, broadcast over the SNP axis)
    result_matrix *= reg_se2_coefs
    result_matrix += reg_const_coefs
    return result_matrix
#################################
def qc_sigma(sigma: np.ndarray) -> np.ndarray:
    """
    Flag which per-SNP Sigma matrices are positive definite.

    A slice counts as positive definite exactly when NumPy's Cholesky factorisation of it
    succeeds; a LinAlgError from the factorisation marks the slice as not positive definite.

    :param sigma: MxPxP stack of Sigma matrices (M = number of SNPs)
    :return np.ndarray: Boolean array of length M; True where the slice is positive definite,
                        False otherwise
    """

    def _factorises(matrix: np.ndarray) -> bool:
        # np.linalg.cholesky raises LinAlgError when the matrix is not positive definite
        try:
            np.linalg.cholesky(matrix)
        except np.linalg.LinAlgError:
            return False
        return True

    return np.array([_factorises(sigma[snp]) for snp in range(sigma.shape[0])], dtype=bool)
#################################
def run_mama_method(betas: np.ndarray, omega: np.ndarray,
                    sigma: np.ndarray) -> Tuple[np.ndarray, np.ndarray]:
    """
    Runs the core MAMA method to combine results and generate final, combined summary statistics.

    For every SNP j and target ancestry p, with w = Omega_j[p, :] / Omega_j[p, p] and
    C = Omega_j + Sigma_j - outer(w, Omega_j[p, :]), the outputs are
        beta_p = (w^T C^-1 beta_j) / (w^T C^-1 w)
        se_p   = sqrt(1 / (w^T C^-1 w))
    All M * P inversions are done in one batched np.linalg.inv call on an MxPxPxP array.

    :param betas: MxP matrix of beta values (M = # of SNPs, P = # of ancestries)
    :param omega: MxPxP matrix of omega values (M = # of SNPs, P = # of ancestries)
    :param sigma: MxPxP matrix of sigma values (M = # of SNPs, P = # of ancestries)
    :return: Tuple containing:
             1) Result ndarray of betas (MxP) where M = SNPs and P = populations
             2) Result ndarray of beta standard errors (MxP) where M = SNPs and P = populations
    """
    # Only the SNP and ancestry counts are needed to shape the intermediates
    M, P = omega.shape[:2]

    # Diagonal of every omega slice as an (M, P, 1) column so it divides row-wise.
    # NOTE(review): a zero diagonal entry yields inf/nan here; presumably callers filter such
    # SNPs beforehand (e.g. via qc_omega) - confirm.
    d_indices = np.arange(P)
    omega_diag = omega[:, d_indices, d_indices][:, :, np.newaxis]
    omega_pp_scaled = np.divide(omega, omega_diag)  # Slice rows are Omega'_pjj / omega_pp,j

    # Center matrix per (SNP, ancestry p): Omega + Sigma - outer(w_p, Omega'_p), shape MxPxPxP
    center_matrix_inv = -omega_pp_scaled[:, :, :, np.newaxis] * omega[:, :, np.newaxis, :]
    center_matrix_inv += omega[:, np.newaxis, :, :] + sigma[:, np.newaxis, :, :]  # Broadcast add
    center_matrix = np.linalg.inv(center_matrix_inv)  # Inverts each PxP slice separately
    del center_matrix_inv  # Free the largest intermediate before the next allocation
    gc.collect()

    # Row vectors w_p as (M, P, 1, P); left_product = w_p^T C^-1, also (M, P, 1, P)
    scaled_rows = omega_pp_scaled[:, :, np.newaxis, :]
    left_product = np.matmul(scaled_rows, center_matrix)
    del center_matrix  # Clean up the center matrix to free space
    gc.collect()

    # Denominator w_p^T C^-1 w_p: (M, P, 1, 1) -> reciprocal reshaped to (M, P)
    denom = np.matmul(left_product, np.transpose(scaled_rows, (0, 1, 3, 2)))
    denom_recip = np.reciprocal(denom).reshape(M, P)

    # Numerator w_p^T C^-1 beta_j: (M, P, P) @ (M, P, 1) -> (M, P)
    numer = np.matmul(left_product.reshape(M, P, P), betas[:, :, np.newaxis]).reshape(M, P)

    # Calculate result betas and standard errors
    new_betas = denom_recip * numer
    new_beta_ses = np.sqrt(denom_recip)
    return new_betas, new_beta_ses