From 393ad46aba44269f973374e2540ab34e4ea921f7 Mon Sep 17 00:00:00 2001 From: Yakov Date: Fri, 28 Jun 2024 16:23:10 +0100 Subject: [PATCH] fix(SusieFineMapperStep): adding filtering of NANs in LD (#654) --- src/gentropy/susie_finemapper.py | 139 +++++++++++++++++++++++++++++++ 1 file changed, 139 insertions(+) diff --git a/src/gentropy/susie_finemapper.py b/src/gentropy/susie_finemapper.py index 16c2ef1cc..b0821e2f0 100644 --- a/src/gentropy/susie_finemapper.py +++ b/src/gentropy/susie_finemapper.py @@ -1108,6 +1108,35 @@ def susie_finemapper_one_studylocus_row_v3_dev_ss_gathered( gnomad_ld = GnomADLDMatrix.get_numpy_matrix( gwas_index, gnomad_ancestry=major_population ) + # Module to remove NANs from the LD matrix + if sum(sum(np.isnan(gnomad_ld))) > 0: + gwas_index = gwas_index.toPandas() + + # First round of filtering out the variants with NANs + nan_count = 1 - (sum(np.isnan(gnomad_ld)) / len(gnomad_ld)) + indices = np.where(nan_count >= 0.98) + indices = indices[0] + gnomad_ld = gnomad_ld[indices][:, indices] + + gwas_index = gwas_index.iloc[indices, :] + + if len(gwas_index) == 0: + logging.warning("No overlapping variants in the LD Index") + return None + + # Second round of filtering out the variants with NANs + nan_count = sum(np.isnan(gnomad_ld)) + indices = np.where(nan_count == 0) + indices = indices[0] + + gnomad_ld = gnomad_ld[indices][:, indices] + gwas_index = gwas_index.iloc[indices, :] + + if len(gwas_index) == 0: + logging.warning("No overlapping variants in the LD Index") + return None + + gwas_index = session.spark.createDataFrame(gwas_index) else: gwas_index = gwas_df.join( ld_index.select("variantId", "alleles", "idx"), on="variantId" @@ -1119,6 +1148,45 @@ def susie_finemapper_one_studylocus_row_v3_dev_ss_gathered( gnomad_ld = GnomADLDMatrix.get_numpy_matrix( gwas_index, gnomad_ancestry=major_population ) + # Module to remove NANs from the LD matrix + if sum(sum(np.isnan(gnomad_ld))) > 0: + gwas_index = gwas_index.toPandas() + + # First round of filtering out the variants with NANs + nan_count = 1 - (sum(np.isnan(gnomad_ld)) / len(gnomad_ld)) + indices = np.where(nan_count >= 0.98) + indices = indices[0] + gnomad_ld = gnomad_ld[indices][:, indices] + + gwas_index = gwas_index.iloc[indices, :] + + if len(gwas_index) == 0: + logging.warning("No overlapping variants in the LD Index") + return None + + # Second round of filtering out the variants with NANs + nan_count = sum(np.isnan(gnomad_ld)) + indices = np.where(nan_count == 0) + indices = indices[0] + + gnomad_ld = gnomad_ld[indices][:, indices] + gwas_index = gwas_index.iloc[indices, :] + + if len(gwas_index) == 0: + logging.warning("No overlapping variants in the LD Index") + return None + + gwas_index = session.spark.createDataFrame(gwas_index) + + # sanity filters on LD matrix + np.fill_diagonal(gnomad_ld, 1) + gnomad_ld[gnomad_ld > 1] = 1 + gnomad_ld[gnomad_ld < -1] = -1 + upper_triangle = np.triu(gnomad_ld) + gnomad_ld = ( + upper_triangle + upper_triangle.T - np.diag(upper_triangle.diagonal()) + ) + np.fill_diagonal(gnomad_ld, 1) out = SusieFineMapperStep.susie_finemapper_from_prepared_dataframes( GWAS_df=gwas_df, @@ -1319,6 +1387,37 @@ def susie_finemapper_one_sl_row_v4_ss_gathered_boundaries( gnomad_ld = GnomADLDMatrix.get_numpy_matrix( gwas_index, gnomad_ancestry=major_population ) + + # Module to remove NANs from the LD matrix + if sum(sum(np.isnan(gnomad_ld))) > 0: + gwas_index = gwas_index.toPandas() + + # First round of filtering out the variants with NANs + nan_count = 1 - (sum(np.isnan(gnomad_ld)) / len(gnomad_ld)) + indices = np.where(nan_count >= 0.98) + indices = indices[0] + gnomad_ld = gnomad_ld[indices][:, indices] + + gwas_index = gwas_index.iloc[indices, :] + + if len(gwas_index) == 0: + logging.warning("No overlapping variants in the LD Index") + return None + + # Second round of filtering out the variants with NANs + nan_count = sum(np.isnan(gnomad_ld)) + indices = np.where(nan_count == 0) + indices = indices[0] + + gnomad_ld = gnomad_ld[indices][:, indices] + gwas_index = gwas_index.iloc[indices, :] + + if len(gwas_index) == 0: + logging.warning("No overlapping variants in the LD Index") + return None + + gwas_index = session.spark.createDataFrame(gwas_index) + else: gwas_index = gwas_df.join( ld_index.select("variantId", "alleles", "idx"), on="variantId" @@ -1331,6 +1430,46 @@ def susie_finemapper_one_sl_row_v4_ss_gathered_boundaries( gwas_index, gnomad_ancestry=major_population ) + # Module to remove NANs from the LD matrix + if sum(sum(np.isnan(gnomad_ld))) > 0: + gwas_index = gwas_index.toPandas() + + # First round of filtering out the variants with NANs + nan_count = 1 - (sum(np.isnan(gnomad_ld)) / len(gnomad_ld)) + indices = np.where(nan_count >= 0.98) + indices = indices[0] + gnomad_ld = gnomad_ld[indices][:, indices] + + gwas_index = gwas_index.iloc[indices, :] + + if len(gwas_index) == 0: + logging.warning("No overlapping variants in the LD Index") + return None + + # Second round of filtering out the variants with NANs + nan_count = sum(np.isnan(gnomad_ld)) + indices = np.where(nan_count == 0) + indices = indices[0] + + gnomad_ld = gnomad_ld[indices][:, indices] + gwas_index = gwas_index.iloc[indices, :] + + if len(gwas_index) == 0: + logging.warning("No overlapping variants in the LD Index") + return None + + gwas_index = session.spark.createDataFrame(gwas_index) + + # sanity filters on LD matrix + np.fill_diagonal(gnomad_ld, 1) + gnomad_ld[gnomad_ld > 1] = 1 + gnomad_ld[gnomad_ld < -1] = -1 + upper_triangle = np.triu(gnomad_ld) + gnomad_ld = ( + upper_triangle + upper_triangle.T - np.diag(upper_triangle.diagonal()) + ) + np.fill_diagonal(gnomad_ld, 1) + out = SusieFineMapperStep.susie_finemapper_from_prepared_dataframes( GWAS_df=gwas_df, ld_index=gwas_index,