Skip to content

Commit

Permalink
Merge pull request #26259 from Qazalbash:scipy-expon
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 723576962
  • Loading branch information
Google-ML-Automation committed Feb 5, 2025
2 parents c46b021 + 7fc605f commit 1eda5e2
Show file tree
Hide file tree
Showing 4 changed files with 282 additions and 1 deletion.
5 changes: 5 additions & 0 deletions docs/jax.scipy.rst
Original file line number Diff line number Diff line change
Expand Up @@ -303,6 +303,11 @@ jax.scipy.stats.expon

logpdf
pdf
logcdf
cdf
logsf
sf
ppf

jax.scipy.stats.gamma
~~~~~~~~~~~~~~~~~~~~~
Expand Down
193 changes: 192 additions & 1 deletion jax/_src/scipy/stats/expon.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,8 +12,9 @@
# See the License for the specific language governing permissions and
# limitations under the License.

from jax import lax

import jax.numpy as jnp
from jax import lax
from jax._src.numpy.util import promote_args_inexact
from jax._src.typing import Array, ArrayLike

Expand Down Expand Up @@ -41,7 +42,13 @@ def logpdf(x: ArrayLike, loc: ArrayLike = 0, scale: ArrayLike = 1) -> Array:
array of logpdf values.
See Also:
:func:`jax.scipy.stats.expon.cdf`
:func:`jax.scipy.stats.expon.pdf`
:func:`jax.scipy.stats.expon.ppf`
:func:`jax.scipy.stats.expon.sf`
:func:`jax.scipy.stats.expon.logcdf`
:func:`jax.scipy.stats.expon.logpdf`
:func:`jax.scipy.stats.expon.logsf`
"""
x, loc, scale = promote_args_inexact("expon.logpdf", x, loc, scale)
log_scale = lax.log(scale)
Expand Down Expand Up @@ -73,6 +80,190 @@ def pdf(x: ArrayLike, loc: ArrayLike = 0, scale: ArrayLike = 1) -> Array:
array of pdf values.
See Also:
:func:`jax.scipy.stats.expon.cdf`
:func:`jax.scipy.stats.expon.pdf`
:func:`jax.scipy.stats.expon.ppf`
:func:`jax.scipy.stats.expon.sf`
:func:`jax.scipy.stats.expon.logcdf`
:func:`jax.scipy.stats.expon.logpdf`
:func:`jax.scipy.stats.expon.logsf`
"""
return lax.exp(logpdf(x, loc, scale))


def cdf(x: ArrayLike, loc: ArrayLike = 0, scale: ArrayLike = 1) -> Array:
r"""Exponential cumulative density function.
JAX implementation of :obj:`scipy.stats.expon` ``cdf``.
The cdf is defined as
.. math::
f_{cdf}(x) = \int_{-\infty}^x f_{pdf}(y)\mathrm{d}y
where :math:`f_{pdf}` is the exponential distribution probability density function,
:func:`jax.scipy.stats.expon.pdf`.
Args:
x: arraylike, value at which to evaluate the PDF
loc: arraylike, distribution offset parameter
scale: arraylike, distribution scale parameter
Returns:
array of pdf values.
See Also:
:func:`jax.scipy.stats.expon.cdf`
:func:`jax.scipy.stats.expon.pdf`
:func:`jax.scipy.stats.expon.ppf`
:func:`jax.scipy.stats.expon.sf`
:func:`jax.scipy.stats.expon.logcdf`
:func:`jax.scipy.stats.expon.logpdf`
:func:`jax.scipy.stats.expon.logsf`
"""
x, loc, scale = promote_args_inexact("expon.cdf", x, loc, scale)
neg_scaled_x = lax.div(lax.sub(loc, x), scale)
return jnp.where(
lax.lt(x, loc),
jnp.zeros_like(neg_scaled_x),
lax.neg(lax.expm1(neg_scaled_x)),
)


def logcdf(x: ArrayLike, loc: ArrayLike = 0, scale: ArrayLike = 1) -> Array:
r"""Exponential log cumulative density function.
JAX implementation of :obj:`scipy.stats.expon` ``logcdf``.
The cdf is defined as
.. math::
f_{cdf}(x) = \int_{-\infty}^x f_{pdf}(y)\mathrm{d}y
where :math:`f_{pdf}` is the exponential distribution probability density function,
:func:`jax.scipy.stats.expon.pdf`.
Args:
x: arraylike, value at which to evaluate the PDF
loc: arraylike, distribution offset parameter
scale: arraylike, distribution scale parameter
Returns:
array of pdf values.
See Also:
:func:`jax.scipy.stats.expon.cdf`
:func:`jax.scipy.stats.expon.pdf`
:func:`jax.scipy.stats.expon.ppf`
:func:`jax.scipy.stats.expon.sf`
:func:`jax.scipy.stats.expon.logcdf`
:func:`jax.scipy.stats.expon.logpdf`
:func:`jax.scipy.stats.expon.logsf`
"""
return lax.log1p(lax.neg(sf(x, loc, scale)))


def logsf(x: ArrayLike, loc: ArrayLike = 0, scale: ArrayLike = 1) -> Array:
r"""Exponential log survival function.
JAX implementation of :obj:`scipy.stats.expon` ``logsf``.
The survival function is defined as
.. math::
f_{sf}(x) = 1 - f_{cdf}(x)
where :math:`f_{cdf}(x)` is the exponential cumulative distribution function,
:func:`jax.scipy.stats.expon.cdf`.
Args:
x: arraylike, value at which to evaluate the PDF
loc: arraylike, distribution offset parameter
scale: arraylike, distribution scale parameter
Returns:
array of pdf values.
See Also:
:func:`jax.scipy.stats.expon.cdf`
:func:`jax.scipy.stats.expon.pdf`
:func:`jax.scipy.stats.expon.ppf`
:func:`jax.scipy.stats.expon.sf`
:func:`jax.scipy.stats.expon.logcdf`
:func:`jax.scipy.stats.expon.logpdf`
:func:`jax.scipy.stats.expon.logsf`
"""
x, loc, scale = promote_args_inexact("expon.sf", x, loc, scale)
neg_scaled_x = lax.div(lax.sub(loc, x), scale)
return jnp.where(lax.lt(x, loc), jnp.zeros_like(neg_scaled_x), neg_scaled_x)


def sf(x: ArrayLike, loc: ArrayLike = 0, scale: ArrayLike = 1) -> Array:
r"""Exponential survival function.
JAX implementation of :obj:`scipy.stats.expon` ``sf``.
The survival function is defined as
.. math::
f_{sf}(x) = 1 - f_{cdf}(x)
where :math:`f_{cdf}(x)` is the exponential cumulative distribution function,
:func:`jax.scipy.stats.expon.cdf`.
Args:
x: arraylike, value at which to evaluate the PDF
loc: arraylike, distribution offset parameter
scale: arraylike, distribution scale parameter
Returns:
array of pdf values.
See Also:
:func:`jax.scipy.stats.expon.cdf`
:func:`jax.scipy.stats.expon.pdf`
:func:`jax.scipy.stats.expon.ppf`
:func:`jax.scipy.stats.expon.sf`
:func:`jax.scipy.stats.expon.logcdf`
:func:`jax.scipy.stats.expon.logpdf`
:func:`jax.scipy.stats.expon.logsf`
"""
return lax.exp(logsf(x, loc, scale))


def ppf(q: ArrayLike, loc: ArrayLike = 0, scale: ArrayLike = 1) -> Array:
r"""Exponential survival function.
JAX implementation of :obj:`scipy.stats.expon` ``ppf``.
The percent point function is defined as the inverse of the
cumulative distribution function, :func:`jax.scipy.stats.expon.cdf`.
Args:
x: arraylike, value at which to evaluate the PDF
loc: arraylike, distribution offset parameter
scale: arraylike, distribution scale parameter
Returns:
array of pdf values.
See Also:
:func:`jax.scipy.stats.expon.cdf`
:func:`jax.scipy.stats.expon.pdf`
:func:`jax.scipy.stats.expon.ppf`
:func:`jax.scipy.stats.expon.sf`
:func:`jax.scipy.stats.expon.logcdf`
:func:`jax.scipy.stats.expon.logpdf`
:func:`jax.scipy.stats.expon.logsf`
"""
q, loc, scale = promote_args_inexact("expon.ppf", q, loc, scale)
neg_scaled_q = lax.div(lax.sub(loc, q), scale)
return jnp.where(
jnp.isnan(q) | (q < 0) | (q > 1),
jnp.nan,
lax.neg(lax.log1p(neg_scaled_q)),
)
5 changes: 5 additions & 0 deletions jax/scipy/stats/expon.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,11 @@
# See PEP 484 & https://github.com/jax-ml/jax/issues/7570

from jax._src.scipy.stats.expon import (
cdf as cdf,
logcdf as logcdf,
logpdf as logpdf,
logsf as logsf,
pdf as pdf,
ppf as ppf,
sf as sf,
)
80 changes: 80 additions & 0 deletions tests/scipy_stats_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -523,6 +523,86 @@ def args_maker():
tol=1e-4)
self._CompileAndCheck(lax_fun, args_maker)

@genNamedParametersNArgs(3)
def testExponLogCdf(self, shapes, dtypes):
rng = jtu.rand_positive(self.rng())
scipy_fun = osp_stats.expon.logcdf
lax_fun = lsp_stats.expon.logcdf

def args_maker():
x, loc, scale = map(rng, shapes, dtypes)
return [x, loc, scale]

with jtu.strict_promotion_if_dtypes_match(dtypes):
self._CheckAgainstNumpy(
scipy_fun, lax_fun, args_maker, check_dtypes=False, tol=5e-4
)
self._CompileAndCheck(lax_fun, args_maker)

@genNamedParametersNArgs(3)
def testExponCdf(self, shapes, dtypes):
rng = jtu.rand_positive(self.rng())
scipy_fun = osp_stats.expon.cdf
lax_fun = lsp_stats.expon.cdf

def args_maker():
x, loc, scale = map(rng, shapes, dtypes)
return [x, loc, scale]

with jtu.strict_promotion_if_dtypes_match(dtypes):
self._CheckAgainstNumpy(
scipy_fun, lax_fun, args_maker, check_dtypes=False, tol=5e-4
)
self._CompileAndCheck(lax_fun, args_maker)

@genNamedParametersNArgs(3)
def testExponSf(self, shapes, dtypes):
rng = jtu.rand_positive(self.rng())
scipy_fun = osp_stats.expon.sf
lax_fun = lsp_stats.expon.sf

def args_maker():
x, loc, scale = map(rng, shapes, dtypes)
return [x, loc, scale]

with jtu.strict_promotion_if_dtypes_match(dtypes):
self._CheckAgainstNumpy(
scipy_fun, lax_fun, args_maker, check_dtypes=False, tol=5e-4
)
self._CompileAndCheck(lax_fun, args_maker)

@genNamedParametersNArgs(3)
def testExponLogSf(self, shapes, dtypes):
rng = jtu.rand_positive(self.rng())
scipy_fun = osp_stats.expon.logsf
lax_fun = lsp_stats.expon.logsf

def args_maker():
x, loc, scale = map(rng, shapes, dtypes)
return [x, loc, scale]

with jtu.strict_promotion_if_dtypes_match(dtypes):
self._CheckAgainstNumpy(
scipy_fun, lax_fun, args_maker, check_dtypes=False, tol=5e-4
)
self._CompileAndCheck(lax_fun, args_maker)

@genNamedParametersNArgs(3)
def testExponPpf(self, shapes, dtypes):
rng = jtu.rand_positive(self.rng())
scipy_fun = osp_stats.expon.ppf
lax_fun = lsp_stats.expon.ppf

def args_maker():
q, loc, scale = map(rng, shapes, dtypes)
return [q, loc, scale]

with jtu.strict_promotion_if_dtypes_match(dtypes):
self._CheckAgainstNumpy(
scipy_fun, lax_fun, args_maker, check_dtypes=False, tol=5e-4
)
self._CompileAndCheck(lax_fun, args_maker)

@genNamedParametersNArgs(4)
def testGammaLogPdf(self, shapes, dtypes):
rng = jtu.rand_positive(self.rng())
Expand Down

0 comments on commit 1eda5e2

Please sign in to comment.