Merge branch 'main' into hatchling
maresb authored Jan 31, 2023
2 parents 888c52c + d789a5f commit 051652b
Showing 6 changed files with 445 additions and 3 deletions.
177 changes: 177 additions & 0 deletions aesara/scalar/math.py
@@ -1481,3 +1481,180 @@ def c_code(self, *args, **kwargs):


betainc_der = BetaIncDer(upgrade_to_float_no_complex, name="betainc_der")


class Hyp2F1(ScalarOp):
"""
Gaussian hypergeometric function ``2F1(a, b; c; z)``.
"""

nin = 4
nfunc_spec = ("scipy.special.hyp2f1", 4, 1)

@staticmethod
def st_impl(a, b, c, z):
return scipy.special.hyp2f1(a, b, c, z)

def impl(self, a, b, c, z):
return Hyp2F1.st_impl(a, b, c, z)

def grad(self, inputs, grads):
a, b, c, z = inputs
(gz,) = grads
return [
gz * hyp2f1_der(a, b, c, z, wrt=0),
gz * hyp2f1_der(a, b, c, z, wrt=1),
gz * hyp2f1_der(a, b, c, z, wrt=2),
gz * ((a * b) / c) * hyp2f1(a + 1, b + 1, c + 1, z),
]

def c_code(self, *args, **kwargs):
raise NotImplementedError()


hyp2f1 = Hyp2F1(upgrade_to_float, name="hyp2f1")
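
The z entry of the gradient above uses the classical identity
d/dz 2F1(a, b; c; z) = (a*b/c) * 2F1(a+1, b+1; c+1; z). A quick numerical
check of that identity using SciPy alone (a standalone sketch, not part of
the diff; the test point is an arbitrary choice):

import numpy as np
import scipy.special

a, b, c, z = 1.5, 2.0, 3.5, 0.3  # arbitrary point inside |z| < 1

# Closed-form derivative used for the z input in Hyp2F1.grad.
analytic = (a * b / c) * scipy.special.hyp2f1(a + 1, b + 1, c + 1, z)

# Central finite difference of 2F1 with respect to z.
eps = 1e-6
numeric = (
    scipy.special.hyp2f1(a, b, c, z + eps)
    - scipy.special.hyp2f1(a, b, c, z - eps)
) / (2 * eps)

np.testing.assert_allclose(analytic, numeric, rtol=1e-6)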


class Hyp2F1Der(ScalarOp):
"""Derivatives of the Gaussian hypergeometric function :math:`2_F_1(a, b; c; z)`.
This is only implemented for one of the first three inputs.
Adapted from https://github.com/stan-dev/math/blob/develop/stan/math/prim/fun/grad_2F1.hpp
"""

nin = 5

def impl(self, a, b, c, z, wrt):
def check_2f1_converges(a, b, c, z) -> bool:
num_terms = 0
is_polynomial = False

def is_nonpositive_integer(x):
return x <= 0 and x.is_integer()

if is_nonpositive_integer(a) and abs(a) >= num_terms:
is_polynomial = True
num_terms = int(np.floor(abs(a)))
if is_nonpositive_integer(b) and abs(b) >= num_terms:
is_polynomial = True
num_terms = int(np.floor(abs(b)))

is_undefined = is_nonpositive_integer(c) and abs(c) <= num_terms

return not is_undefined and (
is_polynomial or np.abs(z) < 1 or (np.abs(z) == 1 and c > (a + b))
)

def compute_grad_2f1(a, b, c, z, wrt):
r"""
Notes
-----
The algorithm can be derived by looking at the ratio of two successive terms in the series:
.. math::
\beta_{k+1} / \beta_{k} = A(k) / B(k) \\
\beta_{k+1} = A(k) / B(k) \beta_{k} \\
d[\beta_{k+1}] = d[A(k) / B(k)] \beta_{k} + A(k) / B(k) d[\beta_{k}]
via the product rule.
            In :math:`{}_2F_1`, :math:`A(k) / B(k)` corresponds to
            :math:`((a + k)(b + k) / ((c + k)(1 + k))) z`. The partials
            :math:`d[A(k)/B(k)]` with respect to the first three inputs
            follow from differentiating the logarithm of the ratio
            :math:`A(k)/B(k)`:
            .. math::
                d/da[A(k) / B(k)] = A(k) / B(k) / (a + k) \\
                d/db[A(k) / B(k)] = A(k) / B(k) / (b + k) \\
                d/dc[A(k) / B(k)] = -A(k) / B(k) / (c + k)
The algorithm is implemented in the log scale, which adds the
complexity of working with absolute terms and tracking their signs.
"""

wrt_a = wrt_b = False
if wrt == 0:
wrt_a = True
elif wrt == 1:
wrt_b = True
elif wrt != 2:
raise ValueError(f"wrt must be 0, 1, or 2; got {wrt}")

min_steps = 10 # https://github.com/stan-dev/math/issues/2857
max_steps = int(1e6)
precision = 1e-14

res = 0

if z == 0:
return res

log_g_old = -np.inf
log_t_old = 0.0
log_t_new = 0.0
sign_z = np.sign(z)
log_z = np.log(np.abs(z))

log_g_old_sign = 1
log_t_old_sign = 1
log_t_new_sign = 1
sign_zk = sign_z

for k in range(max_steps):
p = (a + k) * (b + k) / ((c + k) * (k + 1))
if p == 0:
return res
log_t_new += np.log(np.abs(p)) + log_z
log_t_new_sign = np.sign(p) * log_t_new_sign

term = log_g_old_sign * log_t_old_sign * np.exp(log_g_old - log_t_old)
if wrt_a:
term += np.reciprocal(a + k)
elif wrt_b:
term += np.reciprocal(b + k)
else:
term -= np.reciprocal(c + k)

log_g_old = log_t_new + np.log(np.abs(term))
log_g_old_sign = np.sign(term) * log_t_new_sign
g_current = log_g_old_sign * np.exp(log_g_old) * sign_zk
res += g_current

log_t_old = log_t_new
log_t_old_sign = log_t_new_sign
sign_zk *= sign_z

if k >= min_steps and np.abs(g_current) <= precision:
return res

warnings.warn(
f"hyp2f1_der did not converge after {k} iterations",
RuntimeWarning,
)
return np.nan

# TODO: We could implement the Euler transform to expand supported domain, as Stan does
if not check_2f1_converges(a, b, c, z):
warnings.warn(
f"Hyp2F1 does not meet convergence conditions with given arguments a={a}, b={b}, c={c}, z={z}",
RuntimeWarning,
)
return np.nan

return compute_grad_2f1(a, b, c, z, wrt=wrt)

def __call__(self, a, b, c, z, wrt):
# This allows wrt to be a keyword argument
return super().__call__(a, b, c, z, wrt)

def c_code(self, *args, **kwargs):
raise NotImplementedError()


hyp2f1_der = Hyp2F1Der(upgrade_to_float, name="hyp2f1_der")
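
A minimal sketch of exercising the new scalar Op directly through its
Python implementation, cross-checking the wrt=0 derivative (with respect
to a) against a finite difference of scipy.special.hyp2f1. This assumes a
build of aesara that includes this commit; the test point is an arbitrary
choice inside the convergence region |z| < 1:

import numpy as np
import scipy.special

from aesara.scalar.math import hyp2f1_der

a, b, c, z = 1.2, 3.1, 4.5, 0.7  # arbitrary point with |z| < 1

# Log-scale series recurrence from compute_grad_2f1, via the Op's impl.
grad_a = hyp2f1_der.impl(a, b, c, z, 0)

# Finite-difference reference for d/da 2F1(a, b; c; z).
eps = 1e-6
fd = (
    scipy.special.hyp2f1(a + eps, b, c, z)
    - scipy.special.hyp2f1(a - eps, b, c, z)
) / (2 * eps)

np.testing.assert_allclose(grad_a, fd, rtol=1e-5)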
5 changes: 5 additions & 0 deletions aesara/tensor/inplace.py
@@ -392,6 +392,11 @@ def conj_inplace(a):
"""elementwise conjugate (inplace on `a`)"""


@scalar_elemwise
def hyp2f1_inplace(a, b, c, z):
"""gaussian hypergeometric function"""


pprint.assign(add_inplace, printing.OperatorPrinter("+=", -2, "either"))
pprint.assign(mul_inplace, printing.OperatorPrinter("*=", -1, "either"))
pprint.assign(sub_inplace, printing.OperatorPrinter("-=", -2, "left"))
12 changes: 12 additions & 0 deletions aesara/tensor/math.py
@@ -1386,6 +1386,16 @@ def gammal(k, x):
"""Lower incomplete gamma function."""


@scalar_elemwise
def hyp2f1(a, b, c, z):
"""Gaussian hypergeometric function."""


@scalar_elemwise
def hyp2f1_der(a, b, c, z):
"""Derivatives for Gaussian hypergeometric function."""


@scalar_elemwise
def j0(x):
"""Bessel function of the first kind of order 0."""
@@ -3128,6 +3138,8 @@ def matmul(x1: "ArrayLike", x2: "ArrayLike", dtype: Optional["DTypeLike"] = None
"power",
"logaddexp",
"logsumexp",
"hyp2f1",
"hyp2f1_der",
]

DEPRECATED_NAMES = [
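
With the new elemwise wrappers exported from aesara.tensor.math, hyp2f1
participates in ordinary aesara graphs, and aesara.grad dispatches to the
gradient rules defined on the scalar Op. A small usage sketch (again
assuming a build of aesara that includes this commit):

import aesara
import aesara.tensor as at

from aesara.tensor.math import hyp2f1

a = at.dscalar("a")
b = at.dscalar("b")
c = at.dscalar("c")
z = at.dscalar("z")

out = hyp2f1(a, b, c, z)
# The z gradient uses the closed-form rule; gradients with respect to
# a, b, or c route through the series-based Hyp2F1Der op.
dz = aesara.grad(out, z)

f = aesara.function([a, b, c, z], [out, dz])
val, dval = f(1.5, 2.0, 3.5, 0.3)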
14 changes: 13 additions & 1 deletion aesara/tensor/special.py
@@ -7,7 +7,7 @@
from aesara.graph.basic import Apply
from aesara.link.c.op import COp
from aesara.tensor.basic import as_tensor_variable
from aesara.tensor.math import neg, sum
from aesara.tensor.math import gamma, neg, sum


class SoftmaxGrad(COp):
@@ -768,7 +768,19 @@ def log_softmax(c, axis=UNSET_AXIS):
return LogSoftmax(axis=axis)(c)


def poch(z, m):
"""Compute the Pochhammer/rising factorial."""
return gamma(z + m) / gamma(z)


def factorial(n):
"""Compute the factorial."""
return gamma(n + 1)


__all__ = [
"softmax",
"log_softmax",
"poch",
"factorial",
]
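
The two helpers mirror the standard identities poch(z, m) = gamma(z + m) /
gamma(z) and n! = gamma(n + 1). A quick cross-check against SciPy's
reference implementations (a standalone sketch, independent of aesara):

import numpy as np
import scipy.special

z, m = 3.0, 4.0
# Rising factorial: 3 * 4 * 5 * 6 = 360.
assert np.isclose(scipy.special.gamma(z + m) / scipy.special.gamma(z),
                  scipy.special.poch(z, m))

n = 5.0
# 5! = 120, recovered through the gamma function.
assert np.isclose(scipy.special.gamma(n + 1), scipy.special.factorial(n))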
