[query] lower poisson regression (#12793)
cc @tpoterba 

My apologies. I made several changes to lowered logistic regression as
well.

All the generalized linear model methods share the same fit result. I
abstracted this into one datatype at the top of `statgen.py`:
`numerical_regression_fit_dtype`.
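As a rough illustration of what such a shared fit result carries, here is a plain-Python sketch. This is not the actual Hail dtype (the real `numerical_regression_fit_dtype` is a Hail struct type), and the field names are illustrative guesses:

```python
from dataclasses import dataclass

# Hypothetical plain-Python analogue of a shared fit-result record.
# Field names are guesses, not Hail's actual schema.
@dataclass
class NumericalRegressionFit:
    b: list[float]        # fitted coefficients
    score: list[float]    # gradient of the log likelihood at b
    n_iterations: int     # Newton/gradient steps actually taken
    log_lkhd: float       # log likelihood at b
    converged: bool       # did we meet the tolerance before max_iter?
    exploded: bool        # did the fit diverge numerically?
```

Every GLM fitter then returns the same shape, so downstream test code never needs to know which model produced the fit.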

---

You'll notice I moved the cases so that we check for convergence
*before* checking whether we are at the maximum iteration. It seemed to me
that:
- `max_iter == 0` means do not even attempt to fit.
- `max_iter == 1` means take one gradient step; if you've converged,
return successfully, otherwise fail.
- etc.

The `main` branch currently always fails if you set `max_iter == 1`,
even if the first step lands on the true maximum likelihood fit.
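A minimal sketch of that iteration order (with hypothetical `step`/`converged` callbacks, not Hail's actual lowered code):

```python
def fit_loop(step, converged, theta, max_iter):
    # Take up to max_iter steps, but test convergence after each step and
    # *before* giving up: max_iter == 1 can succeed if one step already
    # converges, and max_iter == 0 never even attempts a step.
    for n in range(1, max_iter + 1):
        theta = step(theta)
        if converged(theta):
            return theta, n, True
    return theta, max_iter, False
```

With `max_iter == 0` the loop returns immediately without attempting a step, matching the first bullet above.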

I substantially refactored logistic regression. There were dead code
paths (e.g. the covariates array is known to be non-empty). I also found
all the function currying and commingling of fitting and testing really
confusing. To be fair, the Scala code does this too (and it's really
confusing). I think the current structure is easier to follow:

1. Fit the null model.
2. If this is the score test, assume the beta for the genotypes is zero and use the rest
of the parameters from the null model fit to compute the score (i.e. the
gradient of the likelihood). Recall from calculus: a gradient near zero means we
are near the maximum. Return: this is the test.
3. Otherwise, fit the full model starting at the null fit parameters.
4. Test the "goodness" of this new & full fit.
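To make step 2 concrete, here is a self-contained, pure-Python sketch (entirely hypothetical, not Hail's code) of a score test for a single variant against an intercept-only logistic null model: the gradient with respect to the genotype beta, evaluated at zero, standardized by its variance:

```python
def logistic_score_test(y, g):
    # Null model: intercept only, so mu_i = mean(y) for every sample.
    n = len(y)
    mu = sum(y) / n
    w = mu * (1 - mu)  # logistic variance weight, constant under this null
    # Score: gradient of the log likelihood w.r.t. the genotype beta, at 0.
    u = sum(gi * (yi - mu) for gi, yi in zip(g, y))
    # Variance of the score, with the intercept projected out.
    v = w * (sum(gi * gi for gi in g) - sum(g) ** 2 / n)
    return u * u / v  # ~ chi-squared with 1 df under the null
```

For `y = [0, 1, 1, 0, 1, 1]` and `g = [0, 1, 2, 0, 1, 2]` this works out to a statistic of 4.5. Note the whole test uses only the null fit; no full-model iteration is needed.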

---

Poisson regression is similar but with a different likelihood function
and gradient thereof. Notice that I call `key_cols_by()` to indicate to Hail
that the order of the cols is irrelevant (the result is a locus-keyed
table, after all). This is necessary at least until #12753 merges, but I
think it's generally a good idea anyway: telling Hail that the ordering of
the columns is irrelevant is potentially useful information for the
optimizer!
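For reference, the Poisson pieces mentioned above, log likelihood and its gradient, look like this in a pure-Python sketch (nothing here mirrors Hail's actual lowered code):

```python
import math

def poisson_log_likelihood(beta, X, y):
    # log L = sum_i [ y_i * eta_i - exp(eta_i) - log(y_i!) ],  eta_i = x_i . beta
    total = 0.0
    for xi, yi in zip(X, y):
        eta = sum(b * v for b, v in zip(beta, xi))
        total += yi * eta - math.exp(eta) - math.lgamma(yi + 1)
    return total

def poisson_score(beta, X, y):
    # Gradient: d logL / d beta_j = sum_i (y_i - mu_i) * x_ij,  mu_i = exp(eta_i)
    grad = [0.0] * len(beta)
    for xi, yi in zip(X, y):
        mu = math.exp(sum(b * v for b, v in zip(beta, xi)))
        for j, v in enumerate(xi):
            grad[j] += (yi - mu) * v
    return grad
```

The gradient has the same `X^T (y - mu)` shape as the logistic case; only the link (`exp` instead of the sigmoid) and the `log(y!)` normalizer differ, which is why the two methods can share a fit loop and fit-result dtype.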

---

Both logistic and Poisson regression can benefit from BLAS3 by running
at least the score test for multiple variants at once.
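As a sketch of what that batching could look like (illustrative numpy, assuming numpy is available; this is not Hail code), scoring m variants at once turns m separate vector passes into a couple of matrix products:

```python
import numpy as np

def batched_score_stats(G, X, y, mu):
    # G: (n, m) genotypes for m variants; X: (n, k) covariates;
    # mu: (n,) fitted null means. One gemm-shaped pass replaces m
    # separate per-variant passes.
    w = mu * (1 - mu)                        # logistic weights; Poisson would use w = mu
    U = G.T @ (y - mu)                       # (m,) score numerators
    Q = G.T @ (X * w[:, None])               # (m, k) cross terms -- a BLAS3 gemm
    A = np.linalg.inv(X.T @ (X * w[:, None]))
    # Per-variant score variance with the covariates projected out.
    V = (G * G * w[:, None]).sum(axis=0) - np.einsum('mk,kl,ml->m', Q, A, Q)
    return U * U / V                         # per-variant chi-squared statistics
```

The `G.T @ ...` products are where the BLAS3 win lives: the work per variant is unchanged, but it is expressed as one large matrix multiply instead of m small ones.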

---

I'll attach an image in the comments, but I spent ~6 seconds compiling
this trivial model and ~140ms testing it.

```python3
import hail as hl

mt = hl.utils.range_matrix_table(1, 3)
mt = mt.annotate_entries(x=hl.literal([1, 3, 10, 5]))
ht = hl.poisson_regression_rows(
    'wald',
    y=hl.literal([0, 1, 1, 0])[mt.col_idx],
    x=mt.x[mt.col_idx],
    covariates=[1],
    max_iterations=2,
)
ht.collect()
```

I grabbed some [sample code from
scikit-learn](https://scikit-learn.org/stable/modules/generated/sklearn.linear_model.PoissonRegressor.html)
for Poisson regression (doing a score test rather than a Wald test) and
timed it. It takes ~8ms. So we're 3 orders of magnitude off including
the compiler, and ~1.2 orders of magnitude off without the compiler.
Digging in a bit:
- ~65ms for class loading.
- ~15ms for region allocation.
- ~20ms in various little spots.

That leaves about 40ms strictly executing generated code, which is about
5x and starting to feel reasonable.
danking authored Mar 21, 2023
1 parent 2bf84f3 commit 2eab792
Showing 4 changed files with 295 additions and 120 deletions.
17 changes: 11 additions & 6 deletions hail/python/hail/expr/functions.py
```diff
@@ -60,7 +60,7 @@ def _seeded_func(name, ret_type, seed, *args):
 def ndarray_broadcasting(func):
     def broadcast_or_not(x):
         if isinstance(x.dtype, tndarray):
-            return x.map(lambda term: func(term))
+            return x.map(func)
         else:
             return func(x)
     return broadcast_or_not
@@ -1748,7 +1748,7 @@ def parse_json(x, dtype):
     return _func("parse_json", ttuple(dtype), x, type_args=(dtype,))[0]
 
 
-@typecheck(x=expr_float64, base=nullable(expr_float64))
+@typecheck(x=oneof(expr_float64, expr_ndarray(expr_float64)), base=nullable(expr_float64))
 def log(x, base=None) -> Float64Expression:
     """Take the logarithm of the `x` with base `base`.
@@ -1777,11 +1777,16 @@ def log(x, base=None) -> Float64Expression:
     -------
     :class:`.Expression` of type :py:data:`.tfloat64`
     """
+    def scalar_log(x):
+        if base is not None:
+            return _func("log", tfloat64, x, to_expr(base))
+        else:
+            return _func("log", tfloat64, x)
+
     x = to_expr(x)
-    if base is not None:
-        return _func("log", tfloat64, x, to_expr(base))
-    else:
-        return _func("log", tfloat64, x)
+    if isinstance(x.dtype, tndarray):
+        return x.map(scalar_log)
+    return scalar_log(x)
 
 
 @typecheck(x=oneof(expr_float64, expr_ndarray(expr_float64)))
```
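The `ndarray_broadcasting` decorator in the diff above follows a simple pattern: wrap a scalar function so it maps itself over array-like inputs. A plain-Python analogue (using `list` as a stand-in for an ndarray expression; purely illustrative):

```python
# Hypothetical plain-Python analogue of the ndarray_broadcasting decorator.
def broadcasting(func):
    def broadcast_or_not(x):
        if isinstance(x, list):  # stand-in for "is an ndarray expression"
            return [func(v) for v in x]
        return func(x)
    return broadcast_or_not

@broadcasting
def square(x):
    return x * x
```

The `log` change in the hunk above does the same dispatch by hand, because `log` takes an optional `base` argument and so can't use the bare decorator.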
