diff --git a/src/distrs/beta.jl b/src/distrs/beta.jl index 0276c57..c7450a4 100644 --- a/src/distrs/beta.jl +++ b/src/distrs/beta.jl @@ -20,6 +20,28 @@ betapdf(α::Real, β::Real, x::Real) = exp(betalogpdf(α, β, x)) betalogpdf(α::Real, β::Real, x::Real) = betalogpdf(promote(α, β, x)...) function betalogpdf(α::T, β::T, x::T) where {T<:Real} + # Handle degenerate cases + xf = float(typeof(x)) + if isinf(α) + if isinf(β) + return float(last(promote(α, β, x, + x == .5 ? convert(xf, NaN) : convert(xf, -Inf) + ))) + else + return float(last(promote(α, β, x, + x == 1 ? convert(xf, NaN) : convert(xf, -Inf) + ))) + end + elseif (iszero(α) && β > 0) || isinf(β) + return float(last(promote(α, β, x, + x == 0 ? convert(xf, NaN) : convert(xf, -Inf) + ))) + elseif iszero(β) && α > 0 + return float(last(promote(α, β, x, + x == 1 ? convert(xf, NaN) : convert(xf, -Inf) + ))) + end + # we ensure that `log(x)` and `log1p(-x)` do not error y = clamp(x, 0, 1) val = xlogy(α - 1, y) + xlog1py(β - 1, -y) - logbeta(α, β) @@ -28,7 +50,13 @@ end function betacdf(α::Real, β::Real, x::Real) # Handle degenerate cases - if iszero(α) && β > 0 + if isinf(α) + if isinf(β) + return float(last(promote(α, β, x, x >= 0.5f0))) + else + return float(last(promote(α, β, x, x >= 1))) + end + elseif (iszero(α) && β > 0) || isinf(β) return float(last(promote(α, β, x, x >= 0))) elseif iszero(β) && α > 0 return float(last(promote(α, β, x, x >= 1))) @@ -39,7 +67,13 @@ end function betaccdf(α::Real, β::Real, x::Real) # Handle degenerate cases - if iszero(α) && β > 0 + if isinf(α) + if isinf(β) + return float(last(promote(α, β, x, x < 0.5f0))) + else + return float(last(promote(α, β, x, x < 1))) + end + elseif (iszero(α) && β > 0) || isinf(β) return float(last(promote(α, β, x, x < 0))) elseif iszero(β) && α > 0 return float(last(promote(α, β, x, x < 1))) @@ -52,7 +86,13 @@ end # to an implementation based on the hypergeometric function ₂F₁ to avoid underflow. function betalogcdf(α::T, β::T, x::T) where {T<:Real} # Handle degenerate cases - if iszero(α) && β > 0 + if isinf(α) + if isinf(β) + return log(last(promote(x, x >= 0.5f0))) + else + return log(last(promote(x, x >= 1))) + end + elseif (iszero(α) && β > 0) || isinf(β) return log(last(promote(x, x >= 0))) elseif iszero(β) && α > 0 return log(last(promote(x, x >= 1))) @@ -74,7 +114,13 @@ betalogcdf(α::Real, β::Real, x::Real) = betalogcdf(promote(α, β, x)...) function betalogccdf(α::Real, β::Real, x::Real) # Handle degenerate cases - if iszero(α) && β > 0 + if isinf(α) + if isinf(β) + return log(last(promote(α, β, x, x < 0.5f0))) + else + return log(last(promote(α, β, x, x < 1))) + end + elseif (iszero(α) && β > 0) || isinf(β) return log(last(promote(α, β, x, x < 0))) elseif iszero(β) && α > 0 return log(last(promote(α, β, x, x < 1))) @@ -91,10 +137,16 @@ end function betainvcdf(α::Real, β::Real, p::Real) # Handle degenerate cases if 0 ≤ p ≤ 1 - if iszero(α) && β > 0 - return last(promote(α, β, p, false)) + if isinf(α) + if isinf(β) + return last(promote(α, β, p, convert(float(typeof(p)), 0.5))) + else + return last(promote(α, β, p, 1)) + end + elseif (iszero(α) && β > 0) || isinf(β) + return last(promote(α, β, p, 0)) elseif iszero(β) && α > 0 - return last(promote(α, β, p, p > 0)) + return last(promote(α, β, p, 1)) end end @@ -104,12 +156,18 @@ end function betainvccdf(α::Real, β::Real, p::Real) # Handle degenerate cases if 0 ≤ p ≤ 1 - if iszero(α) && β > 0 - return last(promote(α, β, p, p == 0)) + if isinf(α) + if isinf(β) + return last(promote(α, β, p, convert(float(typeof(p)), 0.5))) + else + return last(promote(α, β, p, 1)) + end + elseif (iszero(α) && β > 0) || isinf(β) + return last(promote(α, β, p, 0)) elseif iszero(β) && α > 0 - return last(promote(α, β, p, true)) + return last(promote(α, β, p, 1)) end end return last(beta_inc_inv(β, α, p)) -end +end \ No newline at end of file diff --git a/test/rmath.jl b/test/rmath.jl index 0276353..9773c71 100644 --- a/test/rmath.jl +++ b/test/rmath.jl @@ -191,28 +191,61 @@ end # Beta(α, 0) is a Dirac distribution at x=1 α = β = 1//2 - for x in 0f0:0.01f0:1f0 + for x in -1f0:0.05f0:1f0 + # Check betapdf + @test @inferred(betapdf(0, β, x)) === (x == 0 ? NaN32 : 0f0) + @test @inferred(betapdf(α, 0, x)) === (x == 1 ? NaN32 : 0f0) + @test @inferred(betapdf(Inf32, β, x)) === (x == 1 ? NaN32 : 0f0) + @test @inferred(betapdf(α, Inf32, x)) === (x == 0 ? NaN32 : 0f0) + @test @inferred(betapdf(Inf32, Inf32, x)) === (x === 0.5f0 ? NaN32 : 0f0) + + # Check betalogpdf + @test @inferred(betalogpdf(0, β, x)) === (x == 0 ? NaN32 : -Inf32) + @test @inferred(betalogpdf(α, 0, x)) === (x == 1 ? NaN32 : -Inf32) + @test @inferred(betalogpdf(Inf32, β, x)) === (x == 1 ? NaN32 : -Inf32) + @test @inferred(betalogpdf(α, Inf32, x)) === (x == 0 ? NaN32 : -Inf32) + @test @inferred(betalogpdf(Inf32, Inf32, x)) === (x === 0.5f0 ? NaN32 : -Inf32) + # Check betacdf - @test @inferred(betacdf(0, β, x)) === 1f0 + @test @inferred(betacdf(0, β, x)) === (x < 0 ? 0f0 : 1f0) @test @inferred(betacdf(α, 0, x)) === (x < 1 ? 0f0 : 1f0) - + @test @inferred(betacdf(Inf32, β, x)) === (x < 1 ? 0f0 : 1f0) + @test @inferred(betacdf(α, Inf32, x)) === (x < 0 ? 0f0 : 1f0) + @test @inferred(betacdf(Inf32, Inf32, x)) === (x < .5 ? 0f0 : 1f0) + # Check betaccdf, betalogcdf, and betalogccdf based on betacdf @test @inferred(betaccdf(0, β, x)) === 1 - betacdf(0, β, x) @test @inferred(betaccdf(α, 0, x)) === 1 - betacdf(α, 0, x) + @test @inferred(betaccdf(Inf32, β, x)) === 1 - betacdf(Inf32, β, x) + @test @inferred(betaccdf(α, Inf32, x)) === 1 - betacdf(α, Inf32, x) + @test @inferred(betaccdf(Inf32, Inf32, x)) === 1 - betacdf(Inf32, Inf32, x) + @test @inferred(betalogcdf(0, β, x)) === log(betacdf(0, β, x)) @test @inferred(betalogcdf(α, 0, x)) === log(betacdf(α, 0, x)) + @test @inferred(betalogcdf(Inf32, β, x)) === log(betacdf(Inf32, β, x)) + @test @inferred(betalogcdf(α, Inf32, x)) === log(betacdf(α, Inf32, x)) + @test @inferred(betalogcdf(Inf32, Inf32, x)) === log(betacdf(Inf32, Inf32, x)) + @test @inferred(betalogccdf(0, β, x)) === log(betaccdf(0, β, x)) @test @inferred(betalogccdf(α, 0, x)) === log(betaccdf(α, 0, x)) + @test @inferred(betalogccdf(Inf32, β, x)) === log(betaccdf(Inf32, β, x)) + @test @inferred(betalogccdf(α, Inf32, x)) === log(betaccdf(α, Inf32, x)) + @test @inferred(betalogccdf(Inf32, Inf32, x)) === log(betaccdf(Inf32, Inf32, x)) end - for p in 0f0:0.01f0:1f0 + for p in 0f0:0.05f0:1f0 # Check betainvcdf @test @inferred(betainvcdf(0, β, p)) === 0f0 - @test @inferred(betainvcdf(α, 0, p)) === (p > 0 ? 1f0 : 0f0) + @test @inferred(betainvcdf(α, 0, p)) === 1f0 + @test @inferred(betainvcdf(Inf32, β, p)) === 1f0 + @test @inferred(betainvcdf(α, Inf32, p)) === 0f0 + @test @inferred(betainvcdf(Inf32, Inf32, p)) === 0.5f0 # Check betainvccdf - @test @inferred(betainvccdf(0, β, p)) === (p > 0 ? 0f0 : 1f0) + @test @inferred(betainvccdf(0, β, p)) === 0f0 @test @inferred(betainvccdf(α, 0, p)) === 1f0 + @test @inferred(betainvccdf(Inf32, β, p)) === 1f0 + @test @inferred(betainvccdf(Inf32, Inf32, p)) === 0.5f0 end end