From a980f1c332f697d66b44ea506039a60990a675b3 Mon Sep 17 00:00:00 2001 From: Dan King Date: Fri, 7 Apr 2023 14:00:52 -0400 Subject: [PATCH] [query] lower firth regression (#12816) CHANGELOG: In Query-on-Batch, `hl.logistic_regression_rows('firth', ...)` is now supported. Forgive me: I cleaned up and unified the look of the (now three) `fit` methods. A few of the sweeping cleanups: 1. num_iter, max_iter, and cur_iter are now n_iterations, max_iterations, and iteration. 2. Pervasive use of broadcasting functions rather than map. 3. `log_lkhd` is only evaluated on the last iteration (in particular, it's not bound before the `case`). 4. `select` as the last step rather than `drop` (we didn't drop all the unnecessary fields previously). 5. `select_globals` to make sure we only keep the `null_fit`. Major changes in this PR: 1. Add `no_crash` to `solve_triangular`. 2. Split the EPACTS tests by type (now easy to run all Firth tests with `-k firth`). 3. Firth fitting and test. A straight copy from Scala. I honestly don't understand what this is doing. --- hail/python/hail/expr/functions.py | 4 + hail/python/hail/methods/statgen.py | 300 +++++++++++++----- hail/python/hail/nd/nd.py | 32 +- hail/python/test/hail/methods/test_skat.py | 4 +- hail/python/test/hail/methods/test_statgen.py | 283 +++++++++++------ .../is/hail/expr/ir/functions/Functions.scala | 7 +- .../expr/ir/functions/NDArrayFunctions.scala | 15 + 7 files changed, 449 insertions(+), 196 deletions(-) diff --git a/hail/python/hail/expr/functions.py b/hail/python/hail/expr/functions.py index b2a14467d4c..4be0cf23130 100644 --- a/hail/python/hail/expr/functions.py +++ b/hail/python/hail/expr/functions.py @@ -1842,6 +1842,10 @@ def logit(x) -> Float64Expression: def expit(x) -> Float64Expression: """The logistic sigmoid function. + .. math:: + + \textrm{expit}(x) = \frac{1}{1 + e^{-x}} + Examples -------- >>> hl.eval(hl.expit(.01)) diff --git a/hail/python/hail/methods/statgen.py b/hail/python/hail/methods/statgen.py index 2453322839d..f687270e442 100644 --- a/hail/python/hail/methods/statgen.py +++ b/hail/python/hail/methods/statgen.py @@ -42,7 +42,7 @@ score=tvector64, fisher=tmatrix64, mu=tvector64, - num_iter=hl.tint32, + n_iterations=hl.tint32, log_lkhd=hl.tfloat64, converged=hl.tbool, exploded=hl.tbool) @@ -620,7 +620,7 @@ def process_partition(part): covariates=sequenceof(expr_float64), pass_through=sequenceof(oneof(str, Expression)), max_iterations=nullable(int), - tolerance=float) + tolerance=nullable(float)) def logistic_regression_rows(test, y, x, @@ -628,7 +628,7 @@ def logistic_regression_rows(test, pass_through=(), *, max_iterations: Optional[int] = None, - tolerance: float = 1e-6) -> hail.Table: + tolerance: Optional[float] = None) -> hail.Table: r"""For each row, test an input variable for association with a binary response variable using logistic regression. @@ -851,9 +851,9 @@ def logistic_regression_rows(test, Additional row fields to include in the resulting table. max_iterations : :obj:`int` The maximum number of iterations. - tolerance : :obj:`float` - Convergence is defined by a change in the beta vector of less than - `tolerance`. + tolerance : :obj:`float`, optional + The iterative fit of this model is considered "converged" if the change in the estimated + beta is smaller than tolerance. By default the tolerance is 1e-6 (1e-8 in Query-on-Batch).
Returns ------- @@ -863,9 +863,13 @@ def logistic_regression_rows(test, if max_iterations is None: max_iterations = 25 if test != 'firth' else 100 - if not isinstance(Env.backend(), SparkBackend): + if hl.current_backend().requires_lowering: return _logistic_regression_rows_nd( - test, y, x, covariates, pass_through, max_iterations=max_iterations) + test, y, x, covariates, pass_through, max_iterations=max_iterations, tolerance=tolerance) + + if tolerance is None: + tolerance = 1e-6 + assert tolerance > 0.0 if len(covariates) == 0: raise ValueError('logistic regression requires at least one covariate expression') @@ -930,8 +934,17 @@ def nd_max(hl_nd): return hl.max(hl.array(hl_nd.reshape(-1))) -def logreg_fit(X, y, null_fit, max_iter: int, tol: float): - assert max_iter >= 0 +def logreg_fit(X: hl.NDArrayNumericExpression, # (N, K) + y: hl.NDArrayNumericExpression, # (N,) + null_fit: Optional[hl.StructExpression], + max_iterations: int, + tolerance: float + ) -> hl.StructExpression: + """Iteratively reweighted least squares to fit the model y ~ Bernoulli(expit(X \beta)) + + When fitting the null model, K=n_covariates, otherwise K=n_covariates + 1. + """ + assert max_iterations >= 0 assert X.ndim == 2 assert y.ndim == 1 # X is samples by covs. @@ -972,39 +985,37 @@ def logreg_fit(X, y, null_fit, max_iter: int, tol: float): dtype = numerical_regression_fit_dtype blank_struct = hl.struct(**{k: hl.missing(dtype[k]) for k in dtype}) - def search(recur, cur_iter, b, mu, score, fisher): - def cont(exploded, delta_b, max_delta_b, log_lkhd): - def compute_next_iter(cur_iter, b, mu, score, fisher): - cur_iter = cur_iter + 1 - b = b + delta_b - mu = sigmoid(X @ b) - score = X.T @ (y - mu) - fisher = X.T @ (X * (mu * (1 - mu)).reshape(-1, 1)) - return recur(cur_iter, b, mu, score, fisher) + def search(recur, iteration, b, mu, score, fisher): + def cont(exploded, delta_b, max_delta_b): + log_lkhd = hl.log((y * mu) + (1 - y) * (1 - mu)).sum() + + next_b = b + delta_b + next_mu = sigmoid(X @ next_b) + next_score = X.T @ (y - next_mu) + next_fisher = X.T @ (X * (next_mu * (1 - next_mu)).reshape(-1, 1)) return (hl.case() .when(exploded | hl.is_nan(delta_b[0]), - blank_struct.annotate(num_iter=cur_iter, log_lkhd=log_lkhd, converged=False, exploded=True)) - .when(max_delta_b < tol, - hl.struct(b=b, score=score, fisher=fisher, mu=mu, num_iter=cur_iter, log_lkhd=log_lkhd, converged=True, exploded=False)) - .when(cur_iter == max_iter, - blank_struct.annotate(num_iter=cur_iter, log_lkhd=log_lkhd, converged=False, exploded=False)) - .default(compute_next_iter(cur_iter, b, mu, score, fisher))) + blank_struct.annotate(n_iterations=iteration, log_lkhd=log_lkhd, converged=False, exploded=True)) + .when(max_delta_b < tolerance, + hl.struct(b=b, score=score, fisher=fisher, mu=mu, n_iterations=iteration, log_lkhd=log_lkhd, converged=True, exploded=False)) + .when(iteration == max_iterations, + blank_struct.annotate(n_iterations=iteration, log_lkhd=log_lkhd, converged=False, exploded=False)) + .default(recur(iteration + 1, next_b, next_mu, next_score, next_fisher))) delta_b_struct = hl.nd.solve(fisher, score, no_crash=True) exploded = delta_b_struct.failed delta_b = delta_b_struct.solution - max_delta_b = nd_max(delta_b.map(lambda e: hl.abs(e))) - log_lkhd = hl.log((y * mu) + (1 - y) * (1 - mu)).sum() - return hl.bind(cont, exploded, delta_b, max_delta_b, log_lkhd) + max_delta_b = nd_max(hl.abs(delta_b)) + return hl.bind(cont, exploded, delta_b, max_delta_b) - if max_iter == 0: - return blank_struct.annotate(num_iter=0,
log_lkhd=0, converged=False, exploded=False) + if max_iterations == 0: + return blank_struct.annotate(n_iterations=0, log_lkhd=0, converged=False, exploded=False) return hl.experimental.loop(search, numerical_regression_fit_dtype, 1, b, mu, score, fisher) def wald_test(X, fit): - se = hl.nd.diagonal(hl.nd.inv(fit.fisher)).map(lambda e: hl.sqrt(e)) + se = hl.sqrt(hl.nd.diagonal(hl.nd.inv(fit.fisher))) z = fit.b / se p = z.map(lambda e: 2 * hl.pnorm(-hl.abs(e))) return hl.struct( @@ -1012,7 +1023,7 @@ def wald_test(X, fit): standard_error=se[X.shape[1] - 1], z_stat=z[X.shape[1] - 1], p_value=p[X.shape[1] - 1], - fit=hl.struct(n_iterations=fit.num_iter, converged=fit.converged, exploded=fit.exploded)) + fit=fit.select('n_iterations', 'converged', 'exploded')) def lrt_test(X, null_fit, fit): @@ -1023,7 +1034,7 @@ def lrt_test(X, null_fit, fit): beta=fit.b[X.shape[1] - 1], chi_sq_stat=chi_sq, p_value=p, - fit=hl.struct(n_iterations=fit.num_iter, converged=fit.converged, exploded=fit.exploded)) + fit=fit.select('n_iterations', 'converged', 'exploded')) def logistic_score_test(X, y, null_fit): @@ -1034,7 +1045,7 @@ def logistic_score_test(X, y, null_fit): X0 = X[:, 0:m0] X1 = X[:, m0:] - mu = (X @ b).map(lambda e: hl.expit(e)) + mu = hl.expit(X @ b) score_0 = null_fit.score score_1 = X1.T @ (y - mu) @@ -1052,22 +1063,121 @@ def logistic_score_test(X, y, null_fit): solve_attempt = hl.nd.solve(fisher, score, no_crash=True) - chi_sq = hl.if_else(solve_attempt.failed, - hl.missing(hl.tfloat64), - (score * solve_attempt.solution).sum()) + chi_sq = hl.or_missing( + ~solve_attempt.failed, + (score * solve_attempt.solution).sum() + ) p = hl.pchisqtail(chi_sq, m - m0) return hl.struct(chi_sq_stat=chi_sq, p_value=p) +def _firth_fit(b: hl.NDArrayNumericExpression, # (K,) + X: hl.NDArrayNumericExpression, # (N, K) + y: hl.NDArrayNumericExpression, # (N,) + max_iterations: int, + tolerance: float + ) -> hl.StructExpression: + """Iteratively reweighted least squares using Firth's regression to fit the model y ~ Bernoulli(expit(X \beta)) + + When fitting the null model, K=n_covariates, otherwise K=n_covariates + 1.
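+ A sketch of the standard reading of this update (Firth 1993), recorded here as documentation rather than as a claim about the original Scala: Firth's correction maximizes the penalized log-likelihood + + .. math:: + + \ell^*(\beta) = \ell(\beta) + \frac{1}{2} \log \det I(\beta) + + where :math:`I(\beta) = X^T W X` and :math:`W = \textrm{diag}(\mu (1 - \mu))`. With the QR decomposition :math:`QR = \sqrt{W} X`, :math:`\det I(\beta) = \prod_i r_{ii}^2`, so the penalty is the sum of the logs of the absolute values of the diagonal of :math:`R` (the `log_lkhd_right` term below), and the score is adjusted by :math:`h_i (1/2 - \mu_i)`, where :math:`h = \textrm{diag}(Q Q^T)` is computed as the row sums of :math:`Q \odot Q`.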
+ """ + assert max_iterations >= 0 + assert X.ndim == 2 + assert y.ndim == 1 + assert b.ndim == 1 + + dtype = numerical_regression_fit_dtype._drop_fields(['score', 'fisher']) + blank_struct = hl.struct(**{k: hl.missing(dtype[k]) for k in dtype}) + X_bslice = X[:, :b.shape[0]] + + def fit(recur, iteration, b): + def cont(exploded, delta_b, max_delta_b): + log_lkhd_left = hl.log(y * mu + (hl.literal(1.0) - y) * (1 - mu)).sum() + log_lkhd_right = hl.log(hl.abs(hl.nd.diagonal(r))).sum() + log_lkhd = log_lkhd_left + log_lkhd_right + + next_b = b + delta_b + + return (hl.case() + .when(exploded | hl.is_nan(delta_b[0]), + blank_struct.annotate(n_iterations=iteration, log_lkhd=log_lkhd, converged=False, exploded=True)) + .when(max_delta_b < tolerance, + hl.struct(b=b, mu=mu, n_iterations=iteration, log_lkhd=log_lkhd, converged=True, exploded=False)) + .when(iteration == max_iterations, + blank_struct.annotate(n_iterations=iteration, log_lkhd=log_lkhd, converged=False, exploded=False)) + .default(recur(iteration + 1, next_b))) + + m = b.shape[0] # n_covariates or n_covariates + 1, depending on improved null fit vs full fit + mu = sigmoid(X_bslice @ b) + sqrtW = hl.sqrt(mu * (1 - mu)) + q, r = hl.nd.qr(X * sqrtW.T.reshape(-1, 1)) + h = (q * q).sum(1) + coef = r[:m, :m] + residual = y - mu + dep = q[:, :m].T @ ((residual + (h * (0.5 - mu))) / sqrtW) + delta_b_struct = hl.nd.solve_triangular(coef, dep.reshape(-1, 1), no_crash=True) + exploded = delta_b_struct.failed + delta_b = delta_b_struct.solution.reshape(-1) + + max_delta_b = nd_max(hl.abs(delta_b)) + + return hl.bind(cont, exploded, delta_b, max_delta_b) + + if max_iterations == 0: + return blank_struct.annotate(n_iterations=0, log_lkhd=0, converged=False, exploded=False) + return hl.experimental.loop(fit, dtype, 1, b) + + +def _firth_test(null_fit, X, y, max_iterations, tolerance) -> hl.StructExpression: + firth_improved_null_fit = _firth_fit(null_fit.b, X, y, max_iterations=max_iterations, tolerance=tolerance) + dof = 1 # 1 variant + + def cont(firth_improved_null_fit): + initial_b_full_model = hl.nd.hstack([firth_improved_null_fit.b, hl.nd.array([0.0])]) + firth_fit = _firth_fit(initial_b_full_model, X, y, max_iterations=max_iterations, tolerance=tolerance) + + def cont2(firth_fit): + firth_chi_sq = 2 * (firth_fit.log_lkhd - firth_improved_null_fit.log_lkhd) + firth_p = hl.pchisqtail(firth_chi_sq, dof) + + blank_struct = hl.struct( + beta=hl.missing(hl.tfloat64), + chi_sq_stat=hl.missing(hl.tfloat64), + p_value=hl.missing(hl.tfloat64), + firth_null_fit=hl.missing(firth_improved_null_fit.dtype), + fit=hl.missing(firth_fit.dtype) + ) + return (hl.case() + .when(firth_improved_null_fit.converged, + hl.case() + .when(firth_fit.converged, + hl.struct( + beta=firth_fit.b[firth_fit.b.shape[0] - 1], + chi_sq_stat=firth_chi_sq, + p_value=firth_p, + firth_null_fit=firth_improved_null_fit, + fit=firth_fit + )) + .default(blank_struct.annotate( + firth_null_fit=firth_improved_null_fit, + fit=firth_fit + ))) + .default(blank_struct.annotate( + firth_null_fit=firth_improved_null_fit + ))) + return hl.bind(cont2, firth_fit) + return hl.bind(cont, firth_improved_null_fit) + + @typecheck(test=enumeration('wald', 'lrt', 'score', 'firth'), y=oneof(expr_float64, sequenceof(expr_float64)), x=expr_float64, covariates=sequenceof(expr_float64), pass_through=sequenceof(oneof(str, Expression)), max_iterations=nullable(int), - tolerance=float) + tolerance=nullable(float)) def _logistic_regression_rows_nd(test, y, x, @@ -1075,7 +1185,7 @@ def 
_logistic_regression_rows_nd(test, pass_through=(), *, max_iterations: Optional[int] = None, - tolerance: float = 1e-6) -> hail.Table: + tolerance: Optional[float] = None) -> hail.Table: r"""For each row, test an input variable for association with a binary response variable using logistic regression. @@ -1294,6 +1404,10 @@ def _logistic_regression_rows_nd(test, if max_iterations is None: max_iterations = 25 if test != 'firth' else 100 + if tolerance is None: + tolerance = 1e-8 + assert tolerance > 0.0 + if len(covariates) == 0: raise ValueError('logistic regression requires at least one covariate expression') @@ -1347,32 +1461,40 @@ def error_if_not_converged(null_fit): .or_error("Failed to fit logistic regression null model (standard MLE with covariates only): " "Newton iteration failed to converge"))) .or_error(hl.format("Failed to fit logistic regression null model (standard MLE with covariates only): " - "exploded at Newton iteration %d", null_fit.num_iter))) + "exploded at Newton iteration %d", null_fit.n_iterations))) - null_fit = logreg_fit(ht.covmat, yvec, None, max_iter=max_iterations, tol=tolerance) + null_fit = logreg_fit(ht.covmat, yvec, None, max_iterations=max_iterations, tolerance=tolerance) return hl.bind(error_if_not_converged, null_fit) ht = ht.annotate_globals(null_fits=ht.yvecs.map(fit_null)) ht = ht.transmute(x=hl.nd.array(mean_impute(ht.entries[x_field_name]))) - covs_and_x = hl.nd.hstack([ht.covmat, ht.x.reshape((-1, 1))]) + ht = ht.annotate(covs_and_x=hl.nd.hstack([ht.covmat, ht.x.reshape((-1, 1))])) def run_test(yvec, null_fit): if test == 'score': - return logistic_score_test(covs_and_x, yvec, null_fit) + return logistic_score_test(ht.covs_and_x, yvec, null_fit) + if test == 'firth': + return _firth_test(null_fit, ht.covs_and_x, yvec, max_iterations=max_iterations, tolerance=tolerance) - test_fit = logreg_fit(covs_and_x, yvec, null_fit, max_iter=max_iterations, tol=tolerance) + test_fit = logreg_fit(ht.covs_and_x, yvec, null_fit, max_iterations=max_iterations, tolerance=tolerance) if test == 'wald': - return wald_test(covs_and_x, test_fit) - elif test == 'lrt': - return lrt_test(covs_and_x, null_fit, test_fit) - else: - assert test == 'firth' - raise ValueError("firth not yet supported on lowered backends") - ht = ht.annotate(logistic_regression=hl.starmap(run_test, hl.zip(ht.yvecs, ht.null_fits))) + return wald_test(ht.covs_and_x, test_fit) + assert test == 'lrt', test + return lrt_test(ht.covs_and_x, null_fit, test_fit) + ht = ht.select( + logistic_regression=hl.starmap(run_test, hl.zip(ht.yvecs, ht.null_fits)), + **{f: ht[f] for f in row_fields} + ) + assert 'null_fits' not in row_fields + assert 'logistic_regression' not in row_fields if not y_is_list: - ht = ht.transmute(**ht.logistic_regression[0]) - return ht.drop("x") + assert all(f not in row_fields for f in ht.null_fits[0]) + assert all(f not in row_fields for f in ht.logistic_regression[0]) + ht = ht.select_globals(**ht.null_fits[0]) + return ht.transmute(**ht.logistic_regression[0]) + ht = ht.select_globals('null_fits') + return ht @typecheck(test=enumeration('wald', 'lrt', 'score'), @@ -1415,7 +1537,7 @@ def poisson_regression_rows(test, Non-empty list of column-indexed covariate expressions. pass_through : :obj:`list` of :class:`str` or :class:`.Expression` Additional row fields to include in the resulting table. 
- tolerance : :obj:`int`, optional + tolerance : :obj:`float`, optional The iterative fit of this model is considered "converged" if the change in the estimated beta is smaller than tolerance. By default the tolerance is 1e-6. @@ -1429,6 +1551,7 @@ def poisson_regression_rows(test, if tolerance is None: tolerance = 1e-6 + assert tolerance > 0.0 if len(covariates) == 0: raise ValueError('Poisson regression requires at least one covariate expression') @@ -1490,7 +1613,7 @@ def _lowered_poisson_regression_rows(test, if tolerance is None: tolerance = 1e-8 - assert tolerance > 0 + assert tolerance > 0.0 k = len(covariates) if k == 0: @@ -1548,7 +1671,7 @@ def _lowered_poisson_regression_rows(test, mt = mt.annotate_globals( null_fit=hl.case().when(mt.null_fit.converged, mt.null_fit).or_error( hl.format('_lowered_poisson_regression_rows: null model did not converge: %s', - mt.null_fit.select('num_iter', 'log_lkhd', 'converged', 'exploded'))) + mt.null_fit.select('n_iterations', 'log_lkhd', 'converged', 'exploded'))) ) mt = mt.annotate_rows(mean_x=hl.agg.mean(mt.x)) mt = mt.annotate_rows(xvec=hl.nd.array(hl.agg.collect(hl.coalesce(mt.x, mt.mean_x)))) @@ -1566,7 +1689,7 @@ def _lowered_poisson_regression_rows(test, chi_sq_stat=chi_sq, p_value=p, **ht.pass_through - ) + ).select_globals('null_fit') X = hl.nd.hstack([covmat, xvec.T.reshape(-1, 1)]) b = hl.nd.hstack([null_fit.b, hl.nd.array([0.0])]) @@ -1589,55 +1712,76 @@ def _lowered_poisson_regression_rows(test, test_fit=test_fit, **lrt_test(X, null_fit, test_fit), **ht.pass_through - ) + ).select_globals('null_fit') assert test == 'wald' return ht.select( test_fit=test_fit, **wald_test(X, test_fit), **ht.pass_through - ) + ).select_globals('null_fit') -def _poisson_fit(covmat, yvec, b, mu, score, fisher, max_iterations, tolerance): +def _poisson_fit(X: hl.NDArrayNumericExpression, # (N, K) + y: hl.NDArrayNumericExpression, # (N,) + b: hl.NDArrayNumericExpression, # (K,) + mu: hl.NDArrayNumericExpression, # (N,) + score: hl.NDArrayNumericExpression, # (K,) + fisher: hl.NDArrayNumericExpression, # (K, K) + max_iterations: int, + tolerance: float + ) -> hl.StructExpression: + """Iteratively reweighted least squares to fit the model y ~ Poisson(exp(X \beta)) + + When fitting the null model, K=n_covariates, otherwise K=n_covariates + 1. 
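+ The loop below implements what appears to be the standard IRLS update for this model (a sketch for documentation; the code is the source of truth): with :math:`\mu = \exp(X \beta)`, the score is :math:`X^T (y - \mu)` and the Fisher information is :math:`X^T \textrm{diag}(\mu) X` (computed below as ``(next_mu * X.T) @ X``), and each Newton step solves + + .. math:: + + X^T \textrm{diag}(\mu) X \, \delta = X^T (y - \mu), \qquad \beta \leftarrow \beta + \delta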
+ """ + assert max_iterations >= 0 + assert X.ndim == 2 + assert y.ndim == 1 + assert b.ndim == 1 + assert mu.ndim == 1 + assert score.ndim == 1 + assert fisher.ndim == 2 + dtype = numerical_regression_fit_dtype blank_struct = hl.struct(**{k: hl.missing(dtype[k]) for k in dtype}) - def fit(recur, cur_iter, b, mu, score, fisher): - def cont(exploded, delta_b, max_delta_b, log_lkhd): - next_iter = cur_iter + 1 + def fit(recur, iteration, b, mu, score, fisher): + def cont(exploded, delta_b, max_delta_b): + log_lkhd = y @ hl.log(mu) - mu.sum() + next_b = b + delta_b - next_mu = hl.exp(covmat @ next_b) - next_score = covmat.T @ (yvec - next_mu) - next_fisher = (next_mu * covmat.T) @ covmat + next_mu = hl.exp(X @ next_b) + next_score = X.T @ (y - next_mu) + next_fisher = (next_mu * X.T) @ X return (hl.case() .when(exploded | hl.is_nan(delta_b[0]), - blank_struct.annotate(num_iter=cur_iter, log_lkhd=log_lkhd, converged=False, exploded=True)) + blank_struct.annotate(n_iterations=iteration, log_lkhd=log_lkhd, converged=False, exploded=True)) .when(max_delta_b < tolerance, - hl.struct(b=b, score=score, fisher=fisher, mu=mu, num_iter=cur_iter, log_lkhd=log_lkhd, converged=True, exploded=False)) - .when(cur_iter == max_iterations, - blank_struct.annotate(num_iter=cur_iter, log_lkhd=log_lkhd, converged=False, exploded=False)) - .default(recur(next_iter, next_b, next_mu, next_score, next_fisher))) + hl.struct(b=b, score=score, fisher=fisher, mu=mu, n_iterations=iteration, log_lkhd=log_lkhd, converged=True, exploded=False)) + .when(iteration == max_iterations, + blank_struct.annotate(n_iterations=iteration, log_lkhd=log_lkhd, converged=False, exploded=False)) + .default(recur(iteration + 1, next_b, next_mu, next_score, next_fisher))) + delta_b_struct = hl.nd.solve(fisher, score, no_crash=True) exploded = delta_b_struct.failed delta_b = delta_b_struct.solution max_delta_b = nd_max(delta_b.map(lambda e: hl.abs(e))) - log_lkhd = yvec @ hl.log(mu) - mu.sum() - return hl.bind(cont, exploded, delta_b, max_delta_b, log_lkhd) + return hl.bind(cont, exploded, delta_b, max_delta_b) if max_iterations == 0: - return blank_struct.select(num_iter=0, log_lkhd=0, converged=False, exploded=False) + return blank_struct.select(n_iterations=0, log_lkhd=0, converged=False, exploded=False) return hl.experimental.loop(fit, dtype, 1, b, mu, score, fisher) -def _poisson_score_test(null_fit, covmat, yvec, xvec): +def _poisson_score_test(null_fit, covmat, y, xvec): dof = 1 X = hl.nd.hstack([covmat, xvec.T.reshape(-1, 1)]) b = hl.nd.hstack([null_fit.b, hl.nd.array([0.0])]) mu = hl.exp(X @ b) - score = hl.nd.hstack([null_fit.score, hl.nd.array([xvec @ (yvec - mu)])]) + score = hl.nd.hstack([null_fit.score, hl.nd.array([xvec @ (y - mu)])]) fisher00 = null_fit.fisher fisher01 = ((mu * covmat.T) @ xvec).reshape((-1, 1)) @@ -2484,7 +2628,7 @@ def _logistic_skat(group, - mu : :obj:`.tndarray` the expected value under the null model. - - num_iter : :obj:`.tint32` the number of iterations before termination. + - n_iterations : :obj:`.tint32` the number of iterations before termination. - log_lkhd : :obj:`.tfloat64` the log-likelihood of the final iteration. 
@@ -2532,7 +2676,7 @@ def _logistic_skat(group, covmat=hl.nd.array(covmat), n_complete_samples=n ) - null_fit = logreg_fit(mt.covmat, mt.yvec, None, max_iter=null_max_iterations, tol=null_tolerance) + null_fit = logreg_fit(mt.covmat, mt.yvec, None, max_iterations=null_max_iterations, tolerance=null_tolerance) mt = mt.annotate_globals( null_fit=hl.case().when(null_fit.converged, null_fit).or_error( hl.format('hl._logistic_skat: null model did not converge: %s', null_fit)) diff --git a/hail/python/hail/nd/nd.py b/hail/python/hail/nd/nd.py index c63442b1775..a231ca83280 100644 --- a/hail/python/hail/nd/nd.py +++ b/hail/python/hail/nd/nd.py @@ -269,19 +269,19 @@ def solve(a, b, no_crash=False): return result -@typecheck(nd_coef=expr_ndarray(), nd_dep=expr_ndarray(), lower=expr_bool) -def solve_triangular(nd_coef, nd_dep, lower=False): - """Solve a triangular linear system. +@typecheck(A=expr_ndarray(), b=expr_ndarray(), lower=expr_bool, no_crash=bool) +def solve_triangular(A, b, lower=False, no_crash=False): + """Solve a triangular linear system Ax = b for x. Parameters ---------- - nd_coef : :class:`.NDArrayNumericExpression`, (N, N) + A : :class:`.NDArrayNumericExpression`, (N, N) Triangular coefficient matrix. - nd_dep : :class:`.NDArrayNumericExpression`, (N,) or (N, K) + b : :class:`.NDArrayNumericExpression`, (N,) or (N, K) Dependent variables. lower : `bool`: - If true, nd_coef is interpreted as a lower triangular matrix - If false, nd_coef is interpreted as a upper triangular matrix + If true, A is interpreted as a lower triangular matrix + If false, A is interpreted as an upper triangular matrix Returns ------- :class:`.NDArrayNumericExpression` Solution to the triangular system Ax = B. Shape is same as shape of B.
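+ If `no_crash` is true, the result is instead a struct expression with fields `solution` and `failed`. When `failed` is true, the solve failed (for example, the system is singular) and `solution` must not be used. A sketch of the intended use, with illustrative values rather than values from the test suite (solving this upper triangular system yields [1.0, 1.0]): + + >>> A = hl.nd.array([[1.0, 2.0], [0.0, 4.0]]) + >>> b = hl.nd.array([3.0, 4.0]) + >>> result = hl.nd.solve_triangular(A, b, no_crash=True) + >>> hl.eval(hl.or_missing(~result.failed, result.solution)) # doctest: +SKIP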
""" - nd_dep_ndim_orig = nd_dep.ndim - nd_coef, nd_dep = solve_helper(nd_coef, nd_dep, nd_dep_ndim_orig) - return_type = hl.tndarray(hl.tfloat64, 2) + nd_dep_ndim_orig = b.ndim + A, b = solve_helper(A, b, nd_dep_ndim_orig) + + indices, aggregations = unify_all(A, b) - indices, aggregations = unify_all(nd_coef, nd_dep) + if no_crash: + return_type = hl.tstruct(solution=hl.tndarray(hl.tfloat64, 2), failed=hl.tbool) + ir = Apply("linear_triangular_solve_no_crash", return_type, A._ir, b._ir, lower._ir) + result = construct_expr(ir, return_type, indices, aggregations) + if nd_dep_ndim_orig == 1: + result = result.annotate(solution=result.solution.reshape((-1))) + return result - ir = Apply("linear_triangular_solve", return_type, nd_coef._ir, nd_dep._ir, lower._ir) + return_type = hl.tndarray(hl.tfloat64, 2) + ir = Apply("linear_triangular_solve", return_type, A._ir, b._ir, lower._ir) result = construct_expr(ir, return_type, indices, aggregations) if nd_dep_ndim_orig == 1: result = result.reshape((-1)) diff --git a/hail/python/test/hail/methods/test_skat.py b/hail/python/test/hail/methods/test_skat.py index 770e80e920b..d9dd9a74b63 100644 --- a/hail/python/test/hail/methods/test_skat.py +++ b/hail/python/test/hail/methods/test_skat.py @@ -620,7 +620,7 @@ def test_skat_max_iteration_fails_explodes_in_37_steps(): except FatalError as err: assert 'Failed to fit logistic regression null model (MLE with covariates only): exploded at Newton iteration 37' in err.args[0] except HailUserError as err: - assert 'hl._logistic_skat: null model did not converge: {b: null, score: null, fisher: null, mu: null, num_iter: 37, log_lkhd: -0.6931471805599453, converged: false, exploded: true}' in err.args[0] + assert 'hl._logistic_skat: null model did not converge: {b: null, score: null, fisher: null, mu: null, n_iterations: 37, log_lkhd: -0.6931471805599453, converged: false, exploded: true}' in err.args[0] else: assert False @@ -650,6 +650,6 @@ def test_skat_max_iterations_fails_to_converge_in_fewer_than_36_steps(): except FatalError as err: assert 'Failed to fit logistic regression null model (MLE with covariates only): Newton iteration failed to converge' in err.args[0] except HailUserError as err: - assert 'hl._logistic_skat: null model did not converge: {b: null, score: null, fisher: null, mu: null, num_iter: 36, log_lkhd: -0.6931471805599457, converged: false, exploded: false}' in err.args[0] + assert 'hl._logistic_skat: null model did not converge: {b: null, score: null, fisher: null, mu: null, n_iterations: 36, log_lkhd: -0.6931471805599457, converged: false, exploded: false}' in err.args[0] else: assert False diff --git a/hail/python/test/hail/methods/test_statgen.py b/hail/python/test/hail/methods/test_statgen.py index dc4f081d859..ee8462561a7 100644 --- a/hail/python/test/hail/methods/test_statgen.py +++ b/hail/python/test/hail/methods/test_statgen.py @@ -464,7 +464,7 @@ def test_logistic_regression_rows_max_iter_zero(self): covariates=[1], max_iterations=0 ) - ht.null_fits.collect() + ht.globals.collect() # null model is a global except Exception as exc: assert 'Failed to fit logistic regression null model (standard MLE with covariates only): Newton iteration failed to converge' in exc.args[0] else: @@ -490,9 +490,7 @@ def test_logistic_regression_rows_max_iter_explodes(self): assert fit.exploded assert not fit.converged - @fails_local_backend() - @fails_service_backend() - def test_logistic_regression_rows_max_iter_explodes_in_12_steps_for_firth(self): + def 
test_firth_logistic_regression_rows_explodes_in_12_steps(self): import hail as hl mt = hl.utils.range_matrix_table(1, 3) mt = mt.annotate_entries(x=hl.literal([1, 1, 10])) @@ -508,9 +506,7 @@ def test_logistic_regression_rows_max_iter_explodes_in_12_steps_for_firth(self): assert fit.exploded assert not fit.converged - @fails_local_backend() - @fails_service_backend() - def test_logistic_regression_rows_does_not_converge_with_105_iterations(self): + def test_firth_logistic_regression_rows_does_not_converge_with_105_iterations(self): import hail as hl mt = hl.utils.range_matrix_table(1, 3) mt = mt.annotate_entries(x=hl.literal([1, 3, 10])) @@ -526,9 +522,7 @@ def test_logistic_regression_rows_does_not_converge_with_105_iterations(self): assert not fit.exploded assert not fit.converged - @fails_local_backend() - @fails_service_backend() - def test_logistic_regression_rows_does_converge_with_106_iterations(self): + def test_firth_logistic_regression_rows_does_converge_with_more_iterations(self): import hail as hl mt = hl.utils.range_matrix_table(1, 3) mt = mt.annotate_entries(x=hl.literal([1, 3, 10])) @@ -537,15 +531,14 @@ def test_logistic_regression_rows_does_converge_with_106_iterations(self): y=hl.literal([0, 1, 1])[mt.col_idx], x=mt.x[mt.col_idx], covariates=[1], - max_iterations=106 + max_iterations=106, + tolerance=1e-6 ) result = ht.collect()[0] fit = result.fit - actual_beta = result.beta - expected_beta = 0.19699166375172233 - assert abs(actual_beta - expected_beta) < 1e-16 - assert abs(result.chi_sq_stat - 0.6464918007192411) < 1e-15 - assert abs(result.p_value - 0.4213697518249182) < 1e-15 + assert result.beta == pytest.approx(0.19699166375172233, abs=1e-14) + assert result.chi_sq_stat == pytest.approx(0.6464918007192411, abs=1e-14) + assert result.p_value == pytest.approx(0.4213697518249182, abs=1e-14) assert fit.n_iterations == 106 assert not fit.exploded assert fit.converged @@ -999,93 +992,6 @@ def is_constant(r): self.assertTrue(is_constant(results[9])) self.assertTrue(is_constant(results[10])) - @fails_service_backend() - @fails_local_backend() - def test_logistic_regression_epacts(self): - covariates = hl.import_table(resource('regressionLogisticEpacts.cov'), - key='IND_ID', - types={'PC1': hl.tfloat, 'PC2': hl.tfloat}) - fam = hl.import_fam(resource('regressionLogisticEpacts.fam')) - - mt = hl.import_vcf(resource('regressionLogisticEpacts.vcf')) - mt = mt.annotate_cols(**covariates[mt.s], **fam[mt.s]) - - def get_results(table): - return dict(hl.tuple([table.locus.position, table.row]).collect()) - - wald = get_results(hl.logistic_regression_rows( - test='wald', - y=mt.is_case, - x=mt.GT.n_alt_alleles(), - covariates=[1.0, mt.is_female, mt.PC1, mt.PC2])) - lrt = get_results(hl.logistic_regression_rows( - test='lrt', - y=mt.is_case, - x=mt.GT.n_alt_alleles(), - covariates=[1.0, mt.is_female, mt.PC1, mt.PC2])) - score = get_results(hl.logistic_regression_rows( - test='score', - y=mt.is_case, - x=mt.GT.n_alt_alleles(), - covariates=[1.0, mt.is_female, mt.PC1, mt.PC2])) - firth = get_results(hl.logistic_regression_rows( - test='firth', - y=mt.is_case, - x=mt.GT.n_alt_alleles(), - covariates=[1.0, mt.is_female, mt.PC1, mt.PC2])) - - # 2535 samples from 1K Genomes Project - # Locus("22", 16060511) # MAC 623 - # Locus("22", 16115878) # MAC 370 - # Locus("22", 16115882) # MAC 1207 - # Locus("22", 16117940) # MAC 7 - # Locus("22", 16117953) # MAC 21 - - self.assertAlmostEqual(wald[16060511].beta, -0.097476, places=4) - self.assertAlmostEqual(wald[16060511].standard_error, 0.087478, 
places=4) - self.assertAlmostEqual(wald[16060511].z_stat, -1.1143, places=4) - self.assertAlmostEqual(wald[16060511].p_value, 0.26516, places=4) - self.assertAlmostEqual(lrt[16060511].p_value, 0.26475, places=4) - self.assertAlmostEqual(score[16060511].p_value, 0.26499, places=4) - self.assertAlmostEqual(firth[16060511].beta, -0.097079, places=4) - self.assertAlmostEqual(firth[16060511].p_value, 0.26593, places=4) - - self.assertAlmostEqual(wald[16115878].beta, -0.052632, places=4) - self.assertAlmostEqual(wald[16115878].standard_error, 0.11272, places=4) - self.assertAlmostEqual(wald[16115878].z_stat, -0.46691, places=4) - self.assertAlmostEqual(wald[16115878].p_value, 0.64056, places=4) - self.assertAlmostEqual(lrt[16115878].p_value, 0.64046, places=4) - self.assertAlmostEqual(score[16115878].p_value, 0.64054, places=4) - self.assertAlmostEqual(firth[16115878].beta, -0.052301, places=4) - self.assertAlmostEqual(firth[16115878].p_value, 0.64197, places=4) - - self.assertAlmostEqual(wald[16115882].beta, -0.15598, places=4) - self.assertAlmostEqual(wald[16115882].standard_error, 0.079508, places=4) - self.assertAlmostEqual(wald[16115882].z_stat, -1.9619, places=4) - self.assertAlmostEqual(wald[16115882].p_value, 0.049779, places=4) - self.assertAlmostEqual(lrt[16115882].p_value, 0.049675, places=4) - self.assertAlmostEqual(score[16115882].p_value, 0.049675, places=4) - self.assertAlmostEqual(firth[16115882].beta, -0.15567, places=4) - self.assertAlmostEqual(firth[16115882].p_value, 0.04991, places=4) - - self.assertAlmostEqual(wald[16117940].beta, -0.88059, places=4) - self.assertAlmostEqual(wald[16117940].standard_error, 0.83769, places=2) - self.assertAlmostEqual(wald[16117940].z_stat, -1.0512, places=2) - self.assertAlmostEqual(wald[16117940].p_value, 0.29316, places=2) - self.assertAlmostEqual(lrt[16117940].p_value, 0.26984, places=4) - self.assertAlmostEqual(score[16117940].p_value, 0.27828, places=4) - self.assertAlmostEqual(firth[16117940].beta, -0.7524, places=4) - self.assertAlmostEqual(firth[16117940].p_value, 0.30731, places=4) - - self.assertAlmostEqual(wald[16117953].beta, 0.54921, places=4) - self.assertAlmostEqual(wald[16117953].standard_error, 0.4517, places=3) - self.assertAlmostEqual(wald[16117953].z_stat, 1.2159, places=3) - self.assertAlmostEqual(wald[16117953].p_value, 0.22403, places=3) - self.assertAlmostEqual(lrt[16117953].p_value, 0.21692, places=4) - self.assertAlmostEqual(score[16117953].p_value, 0.21849, places=4) - self.assertAlmostEqual(firth[16117953].beta, 0.5258, places=4) - self.assertAlmostEqual(firth[16117953].p_value, 0.22562, places=4) - def test_logreg_pass_through(self): covariates = hl.import_table(resource('regressionLogistic.cov'), key='Sample', @@ -1682,3 +1588,174 @@ def test_regression_field_dependence(self): hl.logistic_regression_rows('wald', y=mt.c1, x=x_expr, covariates=[1]) hl.poisson_regression_rows('wald', y=mt.c1, x=x_expr, covariates=[1]) hl.linear_regression_rows(y=mt.c1, x=x_expr, covariates=[1]) + + +@pytest.fixture +def logistic_epacts_mt(): + # 2535 samples from 1K Genomes Project + # Locus("22", 16060511) # MAC 623 + # Locus("22", 16115878) # MAC 370 + # Locus("22", 16115882) # MAC 1207 + # Locus("22", 16117940) # MAC 7 + # Locus("22", 16117953) # MAC 21 + covariates = hl.import_table(resource('regressionLogisticEpacts.cov'), + key='IND_ID', + types={'PC1': hl.tfloat, 'PC2': hl.tfloat}) + fam = hl.import_fam(resource('regressionLogisticEpacts.fam')) + + mt = hl.import_vcf(resource('regressionLogisticEpacts.vcf')) + mt = 
mt.annotate_cols(**covariates[mt.s], **fam[mt.s]) + return mt + + +def test_logistic_regression_epacts_wald(logistic_epacts_mt): + mt = logistic_epacts_mt + actual = hl.logistic_regression_rows( + test='wald', + y=mt.is_case, + x=mt.GT.n_alt_alleles(), + covariates=[1.0, mt.is_female, mt.PC1, mt.PC2]).collect() + + assert actual[0].locus == hl.Locus("22", 16060511, 'GRCh37') + assert actual[0].beta == pytest.approx(-0.097476, rel=1e-4) + assert actual[0].standard_error == pytest.approx(0.087478, rel=1e-4) + assert actual[0].z_stat == pytest.approx(-1.1143, rel=1e-4) + assert actual[0].p_value == pytest.approx(0.26516, rel=1e-4) + + assert actual[1].locus == hl.Locus("22", 16115878, 'GRCh37') + assert actual[1].beta == pytest.approx(-0.052632, rel=1e-4) + assert actual[1].standard_error == pytest.approx(0.11272, rel=1e-4) + assert actual[1].z_stat == pytest.approx(-0.46691, rel=1e-4) + assert actual[1].p_value == pytest.approx(0.64056, rel=1e-4) + + assert actual[2].locus == hl.Locus("22", 16115882, 'GRCh37') + assert actual[2].beta == pytest.approx(-0.15598, rel=1e-4) + assert actual[2].standard_error == pytest.approx(0.079508, rel=1e-4) + assert actual[2].z_stat == pytest.approx(-1.9619, rel=1e-4) + assert actual[2].p_value == pytest.approx(0.049779, rel=1e-4) + + assert actual[3].locus == hl.Locus("22", 16117940, 'GRCh37') + assert actual[3].beta == pytest.approx(-0.88059, rel=1e-4) + assert actual[3].standard_error == pytest.approx(0.83769, rel=1e-2) + assert actual[3].z_stat == pytest.approx(-1.0512, rel=1e-2) + assert actual[3].p_value == pytest.approx(0.29316, rel=1e-2) + + assert actual[4].locus == hl.Locus("22", 16117953, 'GRCh37') + assert actual[4].beta == pytest.approx(0.54921, rel=1e-4) + assert actual[4].standard_error == pytest.approx(0.4517, rel=1e-3) + assert actual[4].z_stat == pytest.approx(1.2159, rel=1e-3) + assert actual[4].p_value == pytest.approx(0.22403, rel=1e-3) + + +def test_logistic_regression_epacts_lrt(logistic_epacts_mt): + mt = logistic_epacts_mt + actual = hl.logistic_regression_rows( + test='lrt', + y=mt.is_case, + x=mt.GT.n_alt_alleles(), + covariates=[1.0, mt.is_female, mt.PC1, mt.PC2] + ).collect() + + assert actual[0].locus == hl.Locus("22", 16060511, 'GRCh37') + assert actual[0].p_value == pytest.approx(0.26475, rel=1e-4) + + assert actual[1].locus == hl.Locus("22", 16115878, 'GRCh37') + assert actual[1].p_value == pytest.approx(0.64046, rel=1e-4) + + assert actual[2].locus == hl.Locus("22", 16115882, 'GRCh37') + assert actual[2].p_value == pytest.approx(0.049675, rel=1e-4) + + assert actual[3].locus == hl.Locus("22", 16117940, 'GRCh37') + assert actual[3].p_value == pytest.approx(0.26984, rel=1e-4) + + assert actual[4].locus == hl.Locus("22", 16117953, 'GRCh37') + assert actual[4].p_value == pytest.approx(0.21692, rel=1e-4) + + +def test_logistic_regression_epacts_score(logistic_epacts_mt): + # The name of this test suggests it was originally a comparison to EPACTS. The original EPACTS + # values were slightly different from the output of lowered logistic regression. I regenerated + # this test's expected values using R. + # + # 1. Export the data into an R-friendly format: + # + # mt = logistic_epacts_mt() + # mt = mt.select_cols( + # y=hl.int32(mt.is_case), + # c1=1.0, + # c2=hl.int32(mt.is_female), + # c3=mt.PC1, + # c4=mt.PC2, + # x=hl.agg.collect(mt.GT.n_alt_alleles()) + # ) + # mt = mt.transmute_cols(**{ + # f'x{i}': mt.x[i] for i in range(mt.count_rows()) + # }) + # mt.cols().export('phenos.tsv') + # + # 2. 
Run this model repeatedly for each x: + # + # df = read.table(file = 'phenos.tsv', sep = '\t', header = TRUE) + # logfit <- glm(df$y ~ df$c1 + df$c2 + df$c3 + df$c4 + df$x0, family="binomial") + # logfitnull <- glm(df$y ~ df$c1 + df$c2 + df$c3 + df$c4, family="binomial") + # scoretest <- anova(logfitnull, logfit, test="Rao") + # chi2 <- scoretest[["Rao"]][2] + # pval <- scoretest[["Pr(>Chi)"]][2] + # + mt = logistic_epacts_mt + actual = hl.logistic_regression_rows( + test='score', + y=mt.is_case, + x=mt.GT.n_alt_alleles(), + covariates=[1.0, mt.is_female, mt.PC1, mt.PC2] + ).collect() + + assert actual[0].locus == hl.Locus("22", 16060511, 'GRCh37') + assert actual[0].chi_sq_stat == pytest.approx(1.242482, rel=1e-5) + assert actual[0].p_value == pytest.approx(0.2649933, rel=1e-5) + + assert actual[1].locus == hl.Locus("22", 16115878, 'GRCh37') + assert actual[1].chi_sq_stat == pytest.approx(0.218038, rel=1e-5) + assert actual[1].p_value == pytest.approx(0.6405389, rel=1e-5) + + assert actual[2].locus == hl.Locus("22", 16115882, 'GRCh37') + assert actual[2].chi_sq_stat == pytest.approx(3.850985, rel=1e-5) + assert actual[2].p_value == pytest.approx(0.04971679, rel=1e-5) + + assert actual[3].locus == hl.Locus("22", 16117940, 'GRCh37') + assert actual[3].chi_sq_stat == pytest.approx(1.175474, rel=1e-5) + assert actual[3].p_value == pytest.approx(0.2782793, rel=1e-5) + + assert actual[4].locus == hl.Locus("22", 16117953, 'GRCh37') + assert actual[4].chi_sq_stat == pytest.approx(1.514245, rel=1e-5) + assert actual[4].p_value == pytest.approx(0.2184924, rel=1e-5) + + +def test_logistic_regression_epacts_firth(logistic_epacts_mt): + mt = logistic_epacts_mt + actual = hl.logistic_regression_rows( + test='firth', + y=mt.is_case, + x=mt.GT.n_alt_alleles(), + covariates=[1.0, mt.is_female, mt.PC1, mt.PC2] + ).collect() + + assert actual[0].locus == hl.Locus("22", 16060511, 'GRCh37') + assert actual[0].beta == pytest.approx(-0.097079, rel=1e-4) + assert actual[0].p_value == pytest.approx(0.26593, rel=1e-4) + + assert actual[1].locus == hl.Locus("22", 16115878, 'GRCh37') + assert actual[1].beta == pytest.approx(-0.052301, rel=1e-4) + assert actual[1].p_value == pytest.approx(0.64197, rel=1e-4) + + assert actual[2].locus == hl.Locus("22", 16115882, 'GRCh37') + assert actual[2].beta == pytest.approx(-0.15567, rel=1e-4) + assert actual[2].p_value == pytest.approx(0.04991, rel=1e-4) + + assert actual[3].locus == hl.Locus("22", 16117940, 'GRCh37') + assert actual[3].beta == pytest.approx(-0.7524, rel=1e-4) + assert actual[3].p_value == pytest.approx(0.30731, rel=1e-4) + + assert actual[4].locus == hl.Locus("22", 16117953, 'GRCh37') + assert actual[4].beta == pytest.approx(0.5258, rel=1e-4) + assert actual[4].p_value == pytest.approx(0.22562, rel=1e-4) diff --git a/hail/src/main/scala/is/hail/expr/ir/functions/Functions.scala b/hail/src/main/scala/is/hail/expr/ir/functions/Functions.scala index 069138ca3d7..b31f1846314 100644 --- a/hail/src/main/scala/is/hail/expr/ir/functions/Functions.scala +++ b/hail/src/main/scala/is/hail/expr/ir/functions/Functions.scala @@ -639,13 +639,18 @@ abstract class RegistryFunctions { impl(cb, r, rt, errorID, a1, a2) } + def registerIEmitCode3(name: String, mt1: Type, mt2: Type, mt3: Type, rt: Type, pt: (Type, EmitType, EmitType, EmitType) => EmitType) + (impl: (EmitCodeBuilder, Value[Region], SType, Value[Int], EmitCode, EmitCode, EmitCode) => IEmitCode): Unit = + registerIEmitCode(name, Array(mt1, mt2, mt3), rt, unwrappedApply(pt)) { case (cb, r, rt, errorID, Array(a1, a2,
a3)) => + impl(cb, r, rt, errorID, a1, a2, a3) + } + def registerIEmitCode4(name: String, mt1: Type, mt2: Type, mt3: Type, mt4: Type, rt: Type, pt: (Type, EmitType, EmitType, EmitType, EmitType) => EmitType) (impl: (EmitCodeBuilder, Value[Region], SType, Value[Int], EmitCode, EmitCode, EmitCode, EmitCode) => IEmitCode): Unit = registerIEmitCode(name, Array(mt1, mt2, mt3, mt4), rt, unwrappedApply(pt)) { case (cb, r, rt, errorID, Array(a1, a2, a3, a4)) => impl(cb, r, rt, errorID, a1, a2, a3, a4) } - def registerIEmitCode5(name: String, mt1: Type, mt2: Type, mt3: Type, mt4: Type, mt5: Type, rt: Type, pt: (Type, EmitType, EmitType, EmitType, EmitType, EmitType) => EmitType) (impl: (EmitCodeBuilder, Value[Region], SType, Value[Int], EmitCode, EmitCode, EmitCode, EmitCode, EmitCode) => IEmitCode): Unit = registerIEmitCode(name, Array(mt1, mt2, mt3, mt4, mt5), rt, unwrappedApply(pt)) { case (cb, r, rt, errorID, Array(a1, a2, a3, a4, a5)) => diff --git a/hail/src/main/scala/is/hail/expr/ir/functions/NDArrayFunctions.scala b/hail/src/main/scala/is/hail/expr/ir/functions/NDArrayFunctions.scala index e66b7f7f109..eadb61ac21d 100644 --- a/hail/src/main/scala/is/hail/expr/ir/functions/NDArrayFunctions.scala +++ b/hail/src/main/scala/is/hail/expr/ir/functions/NDArrayFunctions.scala @@ -139,6 +139,21 @@ object NDArrayFunctions extends RegistryFunctions { resPCode } + registerIEmitCode3("linear_triangular_solve_no_crash", TNDArray(TFloat64, Nat(2)), TNDArray(TFloat64, Nat(2)), TBoolean, TStruct(("solution", TNDArray(TFloat64, Nat(2))), ("failed", TBoolean)), + { (t, p1, p2, p3) => EmitType(PCanonicalStruct(false, ("solution", PCanonicalNDArray(PFloat64Required, 2, false)), ("failed", PBooleanRequired)).sType, false) }) { + case (cb, region, SBaseStructPointer(outputStructType: PCanonicalStruct), errorID, aec, bec, lowerec) => + aec.toI(cb).flatMap(cb) { apc => + bec.toI(cb).flatMap(cb) { bpc => + lowerec.toI(cb).map(cb) { lowerpc => + val outputNDArrayPType = outputStructType.fieldType("solution") + val (resNDPCode, info) = linear_triangular_solve(apc.asNDArray, bpc.asNDArray, lowerpc.asBoolean, outputNDArrayPType, cb, region, errorID) + val ndEmitCode = EmitCode(Code._empty, info cne 0, resNDPCode) + outputStructType.constructFromFields(cb, region, IndexedSeq[EmitCode](ndEmitCode, EmitCode(Code._empty, false, primitive(cb.memoize(info cne 0)))), false) + } + } + } + } + registerSCode3("linear_triangular_solve", TNDArray(TFloat64, Nat(2)), TNDArray(TFloat64, Nat(2)), TBoolean, TNDArray(TFloat64, Nat(2)), { (t, p1, p2, p3) => PCanonicalNDArray(PFloat64Required, 2, true).sType }) { case (er, cb, SNDArrayPointer(pt), apc, bpc, lower, errorID) =>