[query] add hl.pgenchisq (hail-is#12605)
* [query] add hl.pgenchisq

CHANGELOG: Add `hl.pgenchisq`, the cumulative distribution function of the generalized chi-squared distribution.

The [Generalized Chi-Squared
Distribution](https://en.wikipedia.org/wiki/Generalized_chi-squared_distribution)
arises from weighted sums of sums of squares of independent normally distributed
variables and is used by `hl.skat` to generate p-values. The simplest
formulation I know for it is this:

    w     : R^n
    k     : Z^n
    lam   : R^n
    mu    : R
    sigma : R

    x   ~ N(mu, sigma^2)
    y_i ~ NonCentralChiSquared(k_i, lam_i)

    Z = x + w y^T
      = x + sum_i{ w_i y_i }
    Z ~ GeneralizedNonCentralChiSquared(w, k, lam, mu, sigma)

The non-central chi-squared distribution arises from a sum of squares of
independent normally distributed variables with non-zero mean and unit variance.
The non-centrality parameter, lambda, is defined as the sum of the squares of the
means of the component normal random variables.
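
As a sanity check on the definitions above, here is a standard-library sketch (not part of this PR) that draws Z by brute force and estimates its CDF by counting. The names `sample_noncentral_chisq`, `sample_z`, and `empirical_cdf` are hypothetical. Placing the whole non-centrality on one component is just one convenient choice of means whose squares sum to `lam`, and the sketch assumes every `k_i >= 1`.

```python
import random


def sample_noncentral_chisq(k, lam, rng):
    """Draw one NonCentralChiSquared(k, lam) value (assumes k >= 1)."""
    # Sum of k squared unit-variance normals. All of the non-centrality is
    # placed on the first component (mean sqrt(lam)), so the sum of the
    # squared means equals lam.
    means = [lam ** 0.5] + [0.0] * (k - 1)
    return sum(rng.gauss(m, 1.0) ** 2 for m in means)


def sample_z(w, k, lam, mu, sigma, rng):
    """Draw one Z = x + sum_i w_i * y_i with x ~ N(mu, sigma^2)."""
    x = rng.gauss(mu, sigma) if sigma > 0 else mu
    return x + sum(
        w_i * sample_noncentral_chisq(k_i, lam_i, rng)
        for w_i, k_i, lam_i in zip(w, k, lam)
    )


def empirical_cdf(x, w, k, lam, mu, sigma, n_samples=20_000, seed=0):
    """Monte Carlo estimate of P(Z <= x)."""
    rng = random.Random(seed)
    hits = sum(sample_z(w, k, lam, mu, sigma, rng) <= x for _ in range(n_samples))
    return hits / n_samples
```

For the first doctest curve, `empirical_cdf(10, [1, 2], [1, 4], [1, 1], 0, 0)` should come out close to the 0.467 shown in the doctests, up to Monte Carlo noise. This is only a cross-check; it is far too slow and imprecise to stand in for Davies' algorithm.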

Although the non-central chi-squared distribution has a closed form
implementation (indeed, Hail implements this CDF: `hl.pchisqtail`), the
generalized chi-squared distribution does not have a closed form. There are at
least four distinct algorithms for evaluating the CDF. To my knowledge, the
oldest one is by Robert Davies:

    Davies, Robert. "The distribution of a linear combination of chi-squared
    random variables." Applied Statistics 29 323-333. 1980.

The [original publication](http://www.robertnz.net/pdf/lc_chisq.pdf) includes a
Fortran implementation. Davies'
[website](http://www.robertnz.net/QF.htm) also includes a C version.

Hail includes a copy of the C version as `davies.cpp`. I suspect this code
contains undefined behavior. Moreover, it is not supported on Apple M1 machines
because we don't ship binaries for that platform.

It seemed to me that the simplest solution was to port this algorithm to
Scala. This PR is that port. I tested against the 39 test cases provided by
Davies with the source code. I also added some doctests based on the CDF plots
from Wikipedia. The same 39 test cases are tested in Scala and in Python.

I am open to suggestions for the name. `pgenchisq` seems to strike a balance
between clarity and brevity.

I believe this is the first CDF which can fail to converge. I included some
relevant debugging information. I think we should standardize on a schema, but I
need more examples before I am certain of the right standard.
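
The result struct carries `converged` and `fault` fields, so callers can check them before trusting `value`. A minimal sketch of such a check follows, assuming the fault codes documented in the docstring in this commit (0 for success, 1 for required accuracy not achieved, 2 for possibly significant round-off error). `PgenchisqResult` and `checked_value` are hypothetical names; the named tuple only stands in for the struct obtained by evaluating `hl.pgenchisq(...)`.

```python
from typing import NamedTuple


class PgenchisqResult(NamedTuple):
    """Hypothetical local mirror of the struct returned by hl.pgenchisq."""
    value: float
    n_iterations: int
    converged: bool
    fault: int


# Fault codes as described in the pgenchisq docstring.
FAULT_MESSAGES = {
    0: "converged",
    1: "required accuracy not achieved",
    2: "round-off error possibly significant",
}


def checked_value(result):
    """Return result.value, or raise ValueError if the integration did not converge."""
    if result.converged and result.fault == 0:
        return result.value
    reason = FAULT_MESSAGES.get(result.fault, f"unknown fault code {result.fault}")
    raise ValueError(
        f"pgenchisq did not converge after {result.n_iterations} iterations: {reason}"
    )
```

Raising is only one policy; a pipeline over many rows might instead keep the struct and filter on `converged` afterwards.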

I am open to critique of `GeneralizedChiSquaredDistribution.scala` but I will
strongly argue against significant refactoring. I worry that we will subtly
break this algorithm.

I directly reached out to Robert Davies to clarify the licensing of this
algorithm. It appears to have been released at least under both GPL2 and MIT by
unaffiliated third parties (who, really, have no right to apply a license to
it). Do not remove WIP until I resolve this.

With this PR in place, `hl.skat` can be implemented entirely in Python.

* clarify license
danking authored Jan 23, 2023
1 parent c022a81 commit 20fc42f
Showing 11 changed files with 1,544 additions and 2 deletions.
2 changes: 2 additions & 0 deletions hail/python/hail/docs/functions/stats.rst
@@ -14,6 +14,7 @@ Statistical functions
hardy_weinberg_test
binom_test
pchisqtail
pgenchisq
pnorm
pT
pF
@@ -32,6 +33,7 @@ Statistical functions
.. autofunction:: hardy_weinberg_test
.. autofunction:: binom_test
.. autofunction:: pchisqtail
.. autofunction:: pgenchisq
.. autofunction:: pnorm
.. autofunction:: pT
.. autofunction:: pF
3 changes: 2 additions & 1 deletion hail/python/hail/expr/__init__.py
@@ -21,7 +21,7 @@
hardy_weinberg_test, parse_locus, parse_variant, variant_str, locus, locus_from_global_position,
interval, locus_interval, parse_locus_interval, call, is_defined, is_missing, is_nan, is_finite,
is_infinite, json, parse_json, log, log10, null, missing, or_else, coalesce, or_missing,
binom_test, pchisqtail, pl_dosage, pl_to_gp, pnorm, pT, pF, ppois, qchisqtail, qnorm, qpois,
binom_test, pchisqtail, pgenchisq, pl_dosage, pl_to_gp, pnorm, pT, pF, ppois, qchisqtail, qnorm, qpois,
range, _stream_range, zeros, rand_bool, rand_norm, rand_norm2d, rand_pois, rand_unif, rand_int32, rand_int64,
rand_beta, rand_gamma, rand_cat, rand_dirichlet, sqrt, corr, str, is_snp, is_mnp, is_transition,
is_transversion, is_insertion, is_deletion, is_indel, is_star, is_complex, is_strand_ambiguous,
@@ -124,6 +124,7 @@
'or_missing',
'binom_test',
'pchisqtail',
'pgenchisq',
'pl_dosage',
'pl_to_gp',
'pnorm',
157 changes: 157 additions & 0 deletions hail/python/hail/expr/functions.py
@@ -2046,6 +2046,7 @@ def pchisqtail(x, df, ncp=None, lower_tail=False, log_p=False) -> Float64Expression
Parameters
----------
x : float or :class:`.Expression` of type :py:data:`.tfloat64`
The value at which to evaluate the CDF.
df : float or :class:`.Expression` of type :py:data:`.tfloat64`
Degrees of freedom.
ncp: float or :class:`.Expression` of type :py:data:`.tfloat64`
@@ -2066,6 +2067,162 @@ def pchisqtail(x, df, ncp=None, lower_tail=False, log_p=False) -> Float64Expression
return _func("pnchisqtail", tfloat64, x, df, ncp, lower_tail, log_p)


PGENCHISQ_RETURN_TYPE = tstruct(value=tfloat64, n_iterations=tint32, converged=tbool, fault=tint32)


@typecheck(x=expr_float64,
w=expr_array(expr_float64),
k=expr_array(expr_int32),
lam=expr_array(expr_float64),
mu=expr_float64,
sigma=expr_float64,
max_iterations=nullable(expr_int32),
min_accuracy=nullable(expr_float64))
def pgenchisq(x, w, k, lam, mu, sigma, *, max_iterations=None, min_accuracy=None) -> StructExpression:
r"""The cumulative probability function of a `generalized chi-squared distribution
<https://en.wikipedia.org/wiki/Generalized_chi-squared_distribution>`__.
The generalized chi-squared distribution has many interpretations. We share here four
interpretations of the values of this distribution:
1. A linear combination of normal variables and squares of normal variables.
2. A weighted sum of sums of squares of normally distributed values plus a normally distributed
value.
3. A weighted sum of chi-squared distributed values plus a normally distributed value.
4. A `"quadratic form" <https://en.wikipedia.org/wiki/Quadratic_form_(statistics)>`__ in a vector
of uncorrelated `standard normal
<https://en.wikipedia.org/wiki/Normal_distribution#Standard_normal_distribution>`__ values.
The parameters of this function correspond to the parameters of the third interpretation.
.. math::
\begin{aligned}
w &: R^n \quad k : Z^n \quad lam : R^n \quad mu : R \quad sigma : R \\
\\
x &\sim N(mu, sigma^2) \\
y_i &\sim \mathrm{NonCentralChiSquared}(k_i, lam_i) \\
\\
Z &= x + w y^T \\
&= x + \sum_i w_i y_i \\
Z &\sim \mathrm{GeneralizedNonCentralChiSquared}(w, k, lam, mu, sigma)
\end{aligned}
The generalized chi-squared distribution often arises when working on linear models with standard
normal noise because the sum of the squares of the residuals should follow a generalized
chi-squared distribution.
Examples
--------
The following plot shows three examples of the generalized chi-squared cumulative distribution
function.
.. image:: https://upload.wikimedia.org/wikipedia/commons/thumb/c/cd/Generalized_chi-square_cumulative_distribution_function.svg/1280px-Generalized_chi-square_cumulative_distribution_function.svg.png
:alt: Plots of examples of the generalized chi-square cumulative distribution function. Created by Dvidby0.
:target: https://commons.wikimedia.org/wiki/File:Generalized_chi-square_cumulative_distribution_function.svg
:width: 640px
The following examples are chosen from the three instances shown above. The curves appear in the
same order as the legend of the plot: blue, red, yellow.
>>> hl.eval(hl.pgenchisq(-80, w=[1, 2], k=[1, 4], lam=[1, 1], mu=0, sigma=0).value)
0.0
>>> hl.eval(hl.pgenchisq(-20, w=[1, 2], k=[1, 4], lam=[1, 1], mu=0, sigma=0).value)
0.0
>>> hl.eval(hl.pgenchisq(10 , w=[1, 2], k=[1, 4], lam=[1, 1], mu=0, sigma=0).value)
0.4670012373599629
>>> hl.eval(hl.pgenchisq(40 , w=[1, 2], k=[1, 4], lam=[1, 1], mu=0, sigma=0).value)
0.9958803111156718
>>> hl.eval(hl.pgenchisq(-80, w=[-2, -1], k=[5, 2], lam=[3, 1], mu=-3, sigma=0).value)
9.227056966837344e-05
>>> hl.eval(hl.pgenchisq(-20, w=[-2, -1], k=[5, 2], lam=[3, 1], mu=-3, sigma=0).value)
0.516439358616939
>>> hl.eval(hl.pgenchisq(10 , w=[-2, -1], k=[5, 2], lam=[3, 1], mu=-3, sigma=0).value)
1.0
>>> hl.eval(hl.pgenchisq(40 , w=[-2, -1], k=[5, 2], lam=[3, 1], mu=-3, sigma=0).value)
1.0
>>> hl.eval(hl.pgenchisq(-80, w=[1, -10, 2], k=[1, 2, 3], lam=[2, 3, 7], mu=-10, sigma=0).value)
0.14284718767288906
>>> hl.eval(hl.pgenchisq(-20, w=[1, -10, 2], k=[1, 2, 3], lam=[2, 3, 7], mu=-10, sigma=0).value)
0.5950150356303258
>>> hl.eval(hl.pgenchisq(10 , w=[1, -10, 2], k=[1, 2, 3], lam=[2, 3, 7], mu=-10, sigma=0).value)
0.923219534175858
>>> hl.eval(hl.pgenchisq(40 , w=[1, -10, 2], k=[1, 2, 3], lam=[2, 3, 7], mu=-10, sigma=0).value)
0.9971746768781656
Notes
-----
We follow Wikipedia's notational conventions. Some texts refer to the weight vector (our `w`) as
:math:`\lambda` or `lb` and the non-centrality vector (our `lam`) as `nc`.
We use Davies' algorithm, which was published as: `Davies, Robert. "The distribution of a
linear combination of chi-squared random variables." Applied Statistics 29
323-333. 1980. <http://www.robertnz.net/pdf/lc_chisq.pdf>`__ Davies included Fortran source code
in the original publication. Davies also released a `C language port
<http://www.robertnz.net/QF.htm>`__. Hail's implementation is a fairly direct port of the C
implementation to Scala. Davies provides 39 test cases with the source code. The Hail tests
include all 39 test cases as well as a few additional tests.
Davies' website cautions:
The method works well in most situations if you want only modest accuracy, say 0.0001. But
problems may arise if the sum is dominated by one or two terms with a total of only one or
two degrees of freedom and x is small.
Parameters
----------
x : :obj:`float` or :class:`.Expression` of type :py:data:`.tfloat64`
The value at which to evaluate the cumulative distribution function (CDF).
w : :obj:`list` of :obj:`float` or :class:`.Expression` of type :py:class:`.tarray` of :py:data:`.tfloat64`
A weight for each non-central chi-square term.
k : :obj:`list` of :obj:`int` or :class:`.Expression` of type :py:class:`.tarray` of :py:data:`.tint32`
A degrees of freedom parameter for each non-central chi-square term.
lam : :obj:`list` of :obj:`float` or :class:`.Expression` of type :py:class:`.tarray` of :py:data:`.tfloat64`
A non-centrality parameter for each non-central chi-square term. We use `lam` instead
of `lambda` because the latter is a reserved word in Python.
mu : :obj:`float` or :class:`.Expression` of type :py:data:`.tfloat64`
The mean of the normal term.
sigma : :obj:`float` or :class:`.Expression` of type :py:data:`.tfloat64`
The standard deviation of the normal term.
max_iterations : :obj:`int` or :class:`.Expression` of type :py:data:`.tint32`
The maximum number of iterations of the numerical integration before raising an error.
min_accuracy : :obj:`float` or :class:`.Expression` of type :py:data:`.tfloat64`
The minimum accuracy of the returned value. If the minimum accuracy is not achieved, this
function will raise an error.
Returns
-------
:class:`.StructExpression`
This method returns a structure with the value as well as information about the numerical
integration.
- value : :class:`.Float64Expression`. If converged is true, the value of the CDF evaluated
at `x`. Otherwise, this is the last value the integration evaluated before aborting.
- n_iterations : :class:`.Int32Expression`. The number of iterations before stopping.
- converged : :class:`.BooleanExpression`. True if the `min_accuracy` was achieved and round
off error is not likely significant.
- fault : :class:`.Int32Expression`. If converged is true, fault is zero. If converged is
false, fault is either one or two. One indicates that the required accuracy was not
achieved. Two indicates the round-off error is possibly significant.
"""
if max_iterations is None:
max_iterations = hl.literal(10_000)
if min_accuracy is None:
min_accuracy = hl.literal(0.00001)
return _func("pgenchisq", PGENCHISQ_RETURN_TYPE, x - mu, w, k, lam, sigma, max_iterations, min_accuracy)


@typecheck(x=expr_float64, mu=expr_float64, sigma=expr_float64, lower_tail=expr_bool, log_p=expr_bool)
def pnorm(x, mu=0, sigma=1, lower_tail=True, log_p=False) -> Float64Expression:
"""The cumulative probability function of a normal distribution with mean
29 changes: 29 additions & 0 deletions hail/python/test/hail/expr/test_functions.py
@@ -1,6 +1,7 @@
import hail as hl
import scipy.stats as spst
import pytest
from ..helpers import resource


def test_deprecated_binom_test():
@@ -35,3 +36,31 @@ def right_tail_from_scipy(x, df, ncp):

def test_shuffle():
assert set(hl.eval(hl.shuffle(hl.range(5)))) == set(range(5))


def test_pgenchisq():
ht = hl.import_table(
resource('davies-genchisq-tests.tsv'),
types={
'c': hl.tfloat64,
'weights': hl.tarray(hl.tfloat64),
'k': hl.tarray(hl.tint32),
'lam': hl.tarray(hl.tfloat64),
'sigma': hl.tfloat64,
'lim': hl.tint32,
'acc': hl.tfloat64,
'expected': hl.tfloat64,
'expected_n_iterations': hl.tint32
}
)
ht = ht.add_index('line_number')
ht = ht.annotate(line_number = ht.line_number + 1)
ht = ht.annotate(genchisq_result = hl.pgenchisq(
ht.c, ht.weights, ht.k, ht.lam, 0.0, ht.sigma, max_iterations=ht.lim, min_accuracy=ht.acc
))
tests = ht.collect()
for test in tests:
assert abs(test.genchisq_result.value - test.expected) < 0.0000005, str(test)
assert test.genchisq_result.fault == 0, str(test)
assert test.genchisq_result.converged == True, str(test)
assert test.genchisq_result.n_iterations == test.expected_n_iterations, str(test)
12 changes: 12 additions & 0 deletions hail/src/main/scala/is/hail/expr/ir/functions/Functions.scala
@@ -581,6 +581,18 @@ abstract class RegistryFunctions {
case (r, cb, _, rt, Array(a1, a2, a3, a4, a5), errorID) => impl(r, cb, rt, a1, a2, a3, a4, a5, errorID)
}

def registerSCode6(name: String, mt1: Type, mt2: Type, mt3: Type, mt4: Type, mt5: Type, mt6: Type, rt: Type, pt: (Type, SType, SType, SType, SType, SType, SType) => SType)
(impl: (EmitRegion, EmitCodeBuilder, SType, SValue, SValue, SValue, SValue, SValue, SValue, Value[Int]) => SValue): Unit =
registerSCode(name, Array(mt1, mt2, mt3, mt4, mt5, mt6), rt, unwrappedApply(pt)) {
case (r, cb, _, rt, Array(a1, a2, a3, a4, a5, a6), errorID) => impl(r, cb, rt, a1, a2, a3, a4, a5, a6, errorID)
}

def registerSCode7(name: String, mt1: Type, mt2: Type, mt3: Type, mt4: Type, mt5: Type, mt6: Type, mt7: Type, rt: Type, pt: (Type, SType, SType, SType, SType, SType, SType, SType) => SType)
(impl: (EmitRegion, EmitCodeBuilder, SType, SValue, SValue, SValue, SValue, SValue, SValue, SValue, Value[Int]) => SValue): Unit =
registerSCode(name, Array(mt1, mt2, mt3, mt4, mt5, mt6, mt7), rt, unwrappedApply(pt)) {
case (r, cb, _, rt, Array(a1, a2, a3, a4, a5, a6, a7), errorID) => impl(r, cb, rt, a1, a2, a3, a4, a5, a6, a7, errorID)
}

def registerCode1(name: String, mt1: Type, rt: Type, pt: (Type, SType) => SType)(impl: (EmitCodeBuilder, EmitRegion, SType, SValue) => Value[_]): Unit =
registerCode(name, Array(mt1), rt, unwrappedApply(pt)) {
case (r, cb, rt, _, Array(a1)) => impl(cb, r, rt, a1)
@@ -4,9 +4,10 @@ import is.hail.asm4s.Code
import is.hail.expr.ir._
import is.hail.stats._
import is.hail.types.physical.stypes._
import is.hail.types.physical.stypes.concrete._
import is.hail.types.physical.stypes.interfaces.primitive
import is.hail.types.physical.stypes.primitives._
import is.hail.types.physical.{PBoolean, PFloat32, PFloat64, PInt32, PInt64, PType}
import is.hail.types.physical.{PCanonicalArray, PBoolean, PFloat32, PFloat64, PInt32, PInt64, PType}
import is.hail.types.virtual._
import is.hail.utils._
import org.apache.commons.math3.special.Gamma
@@ -186,6 +187,53 @@ object MathFunctions extends RegistryFunctions {
registerScalaFunction("qnchisqtail", Array(TFloat64, TFloat64, TFloat64), TFloat64, null)(statsPackageClass, "qnchisqtail")
registerScalaFunction("qnchisqtail", Array(TFloat64, TFloat64, TFloat64, TBoolean, TBoolean), TFloat64, null)(statsPackageClass, "qnchisqtail")

registerSCode7(
"pgenchisq",
TFloat64,
TArray(TFloat64),
TArray(TInt32),
TArray(TFloat64),
TFloat64,
TInt32,
TFloat64,
DaviesAlgorithm.pType.virtualType,
(_, _, _, _, _, _, _, _) => DaviesAlgorithm.pType.sType
) {
case (r, cb, rt,
x: SFloat64Value,
_w: SIndexablePointerValue,
_k: SIndexablePointerValue,
_lam: SIndexablePointerValue,
sigma: SFloat64Value,
maxIterations: SInt32Value,
minAccuracy: SFloat64Value,
_) =>

val w = _w.castToArray(cb)
val k = _k.castToArray(cb)
val lam = _lam.castToArray(cb)

val res = cb.newLocal[DaviesResultForPython]("pgenchisq_result",
Code.invokeScalaObject7[
Double, IndexedSeq[Double], IndexedSeq[Int], IndexedSeq[Double], Double, Int, Double, DaviesResultForPython
](statsPackageClass, "pgenchisq",
x.value,
Code.checkcast[IndexedSeq[Double]](svalueToJavaValue(cb, r.region, w)),
Code.checkcast[IndexedSeq[Int]](svalueToJavaValue(cb, r.region, k)),
Code.checkcast[IndexedSeq[Double]](svalueToJavaValue(cb, r.region, lam)),
sigma.value,
maxIterations.value,
minAccuracy.value)
)

DaviesAlgorithm.pType.constructFromFields(cb, r.region, FastIndexedSeq(
EmitValue.present(primitive(cb.memoize(res.invoke[Double]("value")))),
EmitValue.present(primitive(cb.memoize(res.invoke[Int]("nIterations")))),
EmitValue.present(primitive(cb.memoize(res.invoke[Boolean]("converged")))),
EmitValue.present(primitive(cb.memoize(res.invoke[Int]("fault"))))
), deepCopy = false)
}

registerScalaFunction("floor", Array(TFloat32), TFloat32, null)(thisClass, "floor")
registerScalaFunction("floor", Array(TFloat64), TFloat64, null)(thisClass, "floor")
