Skip to content

Commit

Permalink
[query] lower firth regression (#12816)
Browse files Browse the repository at this point in the history
CHANGELOG: In Query-on-Batch, `hl.logistic_regression('firth', ...)` is
now supported.

Forgive me: I cleaned up and unified the look of the (now three) `fit`
methods.

A few of the sweeping cleanups:
1. num_iter, max_iter, and cur_iter are now n_iterations,
max_iterations, and iteration.
2. Pervasive use of broadcasting functions rather than map.
3. `log_lkhd` only evaluated on the last iteration (in particular, its
not bound before the `case`)
4. `select` as the last step rather than `drop` (we didn't drop all the
unnecessary fields previously).
5. `select_globals` to make sure we only keep the `null_fit`.

Major changes in this PR:
1. Add no_crash to triangular_solve
2. Split the epacts tests by type (now easy to run all firth tests with
`-k firth`).
3. Firth fitting and test. A straight copy from Scala. I honestly don't
understand what this is doing.
  • Loading branch information
danking authored Apr 7, 2023
1 parent 91793d9 commit a980f1c
Show file tree
Hide file tree
Showing 7 changed files with 449 additions and 196 deletions.
4 changes: 4 additions & 0 deletions hail/python/hail/expr/functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -1842,6 +1842,10 @@ def logit(x) -> Float64Expression:
def expit(x) -> Float64Expression:
"""The logistic sigmoid function.
.. math::
\textrm{expit}(x) = \frac{1}{1 + e^{-x}}
Examples
--------
>>> hl.eval(hl.expit(.01))
Expand Down
300 changes: 222 additions & 78 deletions hail/python/hail/methods/statgen.py

Large diffs are not rendered by default.

32 changes: 20 additions & 12 deletions hail/python/hail/nd/nd.py
Original file line number Diff line number Diff line change
Expand Up @@ -269,33 +269,41 @@ def solve(a, b, no_crash=False):
return result


@typecheck(nd_coef=expr_ndarray(), nd_dep=expr_ndarray(), lower=expr_bool)
def solve_triangular(nd_coef, nd_dep, lower=False):
"""Solve a triangular linear system.
@typecheck(A=expr_ndarray(), b=expr_ndarray(), lower=expr_bool, no_crash=bool)
def solve_triangular(A, b, lower=False, no_crash=False):
"""Solve a triangular linear system Ax = b for x.
Parameters
----------
nd_coef : :class:`.NDArrayNumericExpression`, (N, N)
A : :class:`.NDArrayNumericExpression`, (N, N)
Triangular coefficient matrix.
nd_dep : :class:`.NDArrayNumericExpression`, (N,) or (N, K)
b : :class:`.NDArrayNumericExpression`, (N,) or (N, K)
Dependent variables.
lower : `bool`:
If true, nd_coef is interpreted as a lower triangular matrix
If false, nd_coef is interpreted as a upper triangular matrix
If true, A is interpreted as a lower triangular matrix
If false, A is interpreted as a upper triangular matrix
Returns
-------
:class:`.NDArrayNumericExpression`, (N,) or (N, K)
Solution to the triangular system Ax = B. Shape is same as shape of B.
"""
nd_dep_ndim_orig = nd_dep.ndim
nd_coef, nd_dep = solve_helper(nd_coef, nd_dep, nd_dep_ndim_orig)
return_type = hl.tndarray(hl.tfloat64, 2)
nd_dep_ndim_orig = b.ndim
A, b = solve_helper(A, b, nd_dep_ndim_orig)

indices, aggregations = unify_all(A, b)

indices, aggregations = unify_all(nd_coef, nd_dep)
if no_crash:
return_type = hl.tstruct(solution=hl.tndarray(hl.tfloat64, 2), failed=hl.tbool)
ir = Apply("linear_triangular_solve_no_crash", return_type, A._ir, b._ir, lower._ir)
result = construct_expr(ir, return_type, indices, aggregations)
if nd_dep_ndim_orig == 1:
result = result.annotate(solution=result.solution.reshape((-1)))
return result

ir = Apply("linear_triangular_solve", return_type, nd_coef._ir, nd_dep._ir, lower._ir)
return_type = hl.tndarray(hl.tfloat64, 2)
ir = Apply("linear_triangular_solve", return_type, A._ir, b._ir, lower._ir)
result = construct_expr(ir, return_type, indices, aggregations)
if nd_dep_ndim_orig == 1:
result = result.reshape((-1))
Expand Down
4 changes: 2 additions & 2 deletions hail/python/test/hail/methods/test_skat.py
Original file line number Diff line number Diff line change
Expand Up @@ -620,7 +620,7 @@ def test_skat_max_iteration_fails_explodes_in_37_steps():
except FatalError as err:
assert 'Failed to fit logistic regression null model (MLE with covariates only): exploded at Newton iteration 37' in err.args[0]
except HailUserError as err:
assert 'hl._logistic_skat: null model did not converge: {b: null, score: null, fisher: null, mu: null, num_iter: 37, log_lkhd: -0.6931471805599453, converged: false, exploded: true}' in err.args[0]
assert 'hl._logistic_skat: null model did not converge: {b: null, score: null, fisher: null, mu: null, n_iterations: 37, log_lkhd: -0.6931471805599453, converged: false, exploded: true}' in err.args[0]
else:
assert False

Expand Down Expand Up @@ -650,6 +650,6 @@ def test_skat_max_iterations_fails_to_converge_in_fewer_than_36_steps():
except FatalError as err:
assert 'Failed to fit logistic regression null model (MLE with covariates only): Newton iteration failed to converge' in err.args[0]
except HailUserError as err:
assert 'hl._logistic_skat: null model did not converge: {b: null, score: null, fisher: null, mu: null, num_iter: 36, log_lkhd: -0.6931471805599457, converged: false, exploded: false}' in err.args[0]
assert 'hl._logistic_skat: null model did not converge: {b: null, score: null, fisher: null, mu: null, n_iterations: 36, log_lkhd: -0.6931471805599457, converged: false, exploded: false}' in err.args[0]
else:
assert False
Loading

0 comments on commit a980f1c

Please sign in to comment.