diff --git a/hail/python/hail/docs/functions/index.rst b/hail/python/hail/docs/functions/index.rst index c316db86969..4d3f7a1ab7b 100644 --- a/hail/python/hail/docs/functions/index.rst +++ b/hail/python/hail/docs/functions/index.rst @@ -164,6 +164,7 @@ These functions are exposed at the top level of the module, e.g. ``hl.case``. chi_squared_test fisher_exact_test contingency_table_test + cochran_mantel_haenszel_test dbeta dpois hardy_weinberg_test diff --git a/hail/python/hail/docs/functions/stats.rst b/hail/python/hail/docs/functions/stats.rst index 44f8e059b52..4ab7728af39 100644 --- a/hail/python/hail/docs/functions/stats.rst +++ b/hail/python/hail/docs/functions/stats.rst @@ -7,6 +7,7 @@ Statistical functions chi_squared_test fisher_exact_test contingency_table_test + cochran_mantel_haenszel_test dbeta dchisq dnorm @@ -26,6 +27,7 @@ Statistical functions .. autofunction:: chi_squared_test .. autofunction:: fisher_exact_test .. autofunction:: contingency_table_test +.. autofunction:: cochran_mantel_haenszel_test .. autofunction:: dbeta .. autofunction:: dchisq .. autofunction:: dnorm diff --git a/hail/python/hail/expr/__init__.py b/hail/python/hail/expr/__init__.py index 23739d77270..c3e9c31e11b 100644 --- a/hail/python/hail/expr/__init__.py +++ b/hail/python/hail/expr/__init__.py @@ -110,6 +110,7 @@ bind, rbind, contingency_table_test, + cochran_mantel_haenszel_test, dbeta, dict, dpois, @@ -334,6 +335,7 @@ 'bind', 'rbind', 'contingency_table_test', + 'cochran_mantel_haenszel_test', 'dbeta', 'dict', 'dpois', diff --git a/hail/python/hail/expr/functions.py b/hail/python/hail/expr/functions.py index 198e2534fa2..042e25de00d 100644 --- a/hail/python/hail/expr/functions.py +++ b/hail/python/hail/expr/functions.py @@ -824,6 +824,84 @@ def contingency_table_test(c1, c2, c3, c4, min_cell_count) -> StructExpression: return _func("contingency_table_test", ret_type, c1, c2, c3, c4, min_cell_count) +# We use 64-bit integers. +# It is relatively easy to encounter an integer overflow bug with 32-bit integers. +@typecheck(a=expr_array(expr_int64), b=expr_array(expr_int64), c=expr_array(expr_int64), d=expr_array(expr_int64)) +def cochran_mantel_haenszel_test( + a: Union[tarray, list], b: Union[tarray, list], c: Union[tarray, list], d: Union[tarray, list] +) -> StructExpression: + """Perform the Cochran-Mantel-Haenszel test for association. + + Examples + -------- + >>> a = [56, 61, 73, 71] + >>> b = [69, 257, 65, 48] + >>> c = [40, 57, 71, 55] + >>> d = [77, 301, 79, 48] + >>> hl.eval(hl.cochran_mantel_haenszel_test(a, b, c, d)) + Struct(test_statistic=5.0496881823306765, p_value=0.024630370456863417) + + >>> mt = ds.filter_rows(mt.locus == hl.Locus(20, 10633237)) + >>> mt.count_rows() + 1 + >>> a, b, c, d = mt.aggregate_entries( + ... hl.tuple([ + ... hl.array([hl.agg.count_where(mt.GT.is_non_ref() & mt.pheno.is_case & mt.pheno.is_female), hl.agg.count_where(mt.GT.is_non_ref() & mt.pheno.is_case & ~mt.pheno.is_female)]), + ... hl.array([hl.agg.count_where(mt.GT.is_non_ref() & ~mt.pheno.is_case & mt.pheno.is_female), hl.agg.count_where(mt.GT.is_non_ref() & ~mt.pheno.is_case & ~mt.pheno.is_female)]), + ... hl.array([hl.agg.count_where(~mt.GT.is_non_ref() & mt.pheno.is_case & mt.pheno.is_female), hl.agg.count_where(~mt.GT.is_non_ref() & mt.pheno.is_case & ~mt.pheno.is_female)]), + ... hl.array([hl.agg.count_where(~mt.GT.is_non_ref() & ~mt.pheno.is_case & mt.pheno.is_female), hl.agg.count_where(~mt.GT.is_non_ref() & ~mt.pheno.is_case & ~mt.pheno.is_female)]) + ... ]) + ... ) + >>> hl.eval(hl.cochran_mantel_haenszel_test(a, b, c, d)) + Struct(test_statistic=0.2188830334629822, p_value=0.6398923118508772) + + Notes + ----- + See the `Wikipedia article `_ + for more details. + + Parameters + ---------- + a : :class:`.ArrayExpression` of type :py:data:`.tint64` + Values for the upper-left cell in the contingency tables. + b : :class:`.ArrayExpression` of type :py:data:`.tint64` + Values for the upper-right cell in the contingency tables. + c : :class:`.ArrayExpression` of type :py:data:`.tint64` + Values for the lower-left cell in the contingency tables. + d : :class:`.ArrayExpression` of type :py:data:`.tint64` + Values for the lower-right cell in the contingency tables. + + Returns + ------- + :class:`.StructExpression` + A :class:`.tstruct` expression with two fields, `test_statistic` + (:py:data:`.tfloat64`) and `p_value` (:py:data:`.tfloat64`). + """ + # The variable names below correspond to the notation used in the Wikipedia article. + # https://en.m.wikipedia.org/wiki/Cochran%E2%80%93Mantel%E2%80%93Haenszel_statistics + n1 = hl.zip(a, b).map(lambda ab: ab[0] + ab[1]) + n2 = hl.zip(c, d).map(lambda cd: cd[0] + cd[1]) + m1 = hl.zip(a, c).map(lambda ac: ac[0] + ac[1]) + m2 = hl.zip(b, d).map(lambda bd: bd[0] + bd[1]) + t = hl.zip(n1, n2).map(lambda nn: nn[0] + nn[1]) + + def numerator_term(a, n1, m1, t): + return a - n1 * m1 / t + + # The numerator comes from the link below, not from the Wikipedia article. + # https://www.biostathandbook.com/cmh.html + numerator = (hl.abs(hl.sum(hl.zip(a, n1, m1, t).map(lambda tup: numerator_term(*tup)))) - 0.5) ** 2 + + def denominator_term(n1, n2, m1, m2, t): + return n1 * n2 * m1 * m2 / (t**3 - t**2) + + denominator = hl.sum(hl.zip(n1, n2, m1, m2, t).map(lambda tup: denominator_term(*tup))) + + test_statistic = numerator / denominator + p_value = pchisqtail(test_statistic, 1) + return struct(test_statistic=test_statistic, p_value=p_value) + + @typecheck( collection=expr_oneof( expr_dict(), expr_set(expr_tuple([expr_any, expr_any])), expr_array(expr_tuple([expr_any, expr_any])) diff --git a/hail/python/test/hail/expr/test_expr.py b/hail/python/test/hail/expr/test_expr.py index 10f20ecbf1f..1977f2d8834 100644 --- a/hail/python/test/hail/expr/test_expr.py +++ b/hail/python/test/hail/expr/test_expr.py @@ -3600,6 +3600,48 @@ def test_contingency_table_test(self): self.assertAlmostEqual(res['p_value'] / 2.1565e-7, 1.0, places=4) self.assertAlmostEqual(res['odds_ratio'], 4.91805817) + def test_cochran_mantel_haenszel_test(self): + # https://cran.r-project.org/web/packages/samplesizeCMH/vignettes/samplesizeCMH-introduction.html + a = [118, 154, 422, 670] + b = [62, 25, 88, 192] + c = [4, 13, 106, 3] + d = [141, 93, 90, 20] + + result = hl.eval(hl.cochran_mantel_haenszel_test(a, b, c, d)) + self.assertEqual(360.3311519725744, result['test_statistic']) + self.assertEqual(2.384935629406975e-80, result['p_value']) + + # https://www.biostathandbook.com/cmh.html + a = [708, 136, 106, 109, 801, 159, 151, 950] + b = [50, 24, 32, 22, 102, 27, 51, 173] + c = [169, 73, 17, 16, 180, 18, 28, 218] + d = [13, 14, 4, 26, 25, 13, 15, 33] + + expr = hl.cochran_mantel_haenszel_test(a, b, c, d) + result = hl.eval(expr) + self.assertEqual(6.07023412667767, result['test_statistic']) + self.assertEqual(0.013747873638119005, result['p_value']) + + a = [56, 61, 73, 71] + b = [69, 257, 65, 48] + c = [40, 57, 71, 55] + d = [77, 301, 79, 48] + + expr = hl.cochran_mantel_haenszel_test(a, b, c, d) + result = hl.eval(expr) + self.assertEqual(5.0496881823306765, result['test_statistic']) + self.assertEqual(0.024630370456863417, result['p_value']) + + a = hl.array([2, 4, 1, 1, 2]) + b = hl.array([46, 67, 86, 37, 92]) + c = hl.array([11, 12, 4, 6, 1]) + d = hl.array([41, 60, 76, 32, 93]) + + expr = hl.cochran_mantel_haenszel_test(a, b, c, d) + result = hl.eval(expr) + self.assertEqual(12.74572269532737, result['test_statistic']) + self.assertEqual(0.0003568242404514306, result['p_value']) + def test_hardy_weinberg_test(self): two_sided_res = hl.eval(hl.hardy_weinberg_test(1, 2, 1, one_sided=False)) self.assertAlmostEqual(two_sided_res['p_value'], 0.65714285) @@ -3835,6 +3877,9 @@ def test_array_fold_and_scan(self): hl.array_scan(lambda x, y: x + y, 0, [1.0, 2.0, 3.0]), [0.0, 1.0, 3.0, 6.0], tarray(tfloat64) ) + def test_sum(self): + self.assertValueEqual(hl.sum([1, 2, 3, 4]), 10, tint32) + def test_cumulative_sum(self): self.assertValueEqual(hl.cumulative_sum([1, 2, 3, 4]), [1, 3, 6, 10], tarray(tint32)) self.assertValueEqual(hl.cumulative_sum([1.0, 2.0, 3.0, 4.0]), [1.0, 3.0, 6.0, 10.0], tarray(tfloat64))