Skip to content

Commit

Permalink
move vegasmc to ext
Browse files Browse the repository at this point in the history
  • Loading branch information
lxvm committed Dec 30, 2023
1 parent b7270f6 commit cd61b8a
Show file tree
Hide file tree
Showing 5 changed files with 71 additions and 65 deletions.
6 changes: 4 additions & 2 deletions Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,6 @@ version = "4.1.0"
CommonSolve = "38540f10-b2f7-11e9-35d8-d573e4eb0ff2"
HCubature = "19dc6840-f33b-545b-b366-655c7e3ffd49"
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
MCIntegration = "ea1e2de9-7db7-4b42-91ee-0cd1bf6df167"
MonteCarloIntegration = "4886b29c-78c9-11e9-0a6e-41e1f4161f7b"
QuadGK = "1fd47b50-473d-5c70-9696-f719f8f3bcdc"
Reexport = "189a3867-3050-52da-a836-e630ba90ab69"
Expand All @@ -20,6 +19,7 @@ Cuba = "8a292aeb-7a57-582c-b821-06e4c11590b1"
Cubature = "667455a9-e2ce-5579-9412-b964f529a492"
FastGaussQuadrature = "442a2c76-b920-505d-bb47-c5924d526838"
ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210"
MCIntegration = "ea1e2de9-7db7-4b42-91ee-0cd1bf6df167"
Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f"

[extensions]
Expand All @@ -28,6 +28,7 @@ IntegralsCubaExt = "Cuba"
IntegralsCubatureExt = "Cubature"
IntegralsFastGaussQuadratureExt = "FastGaussQuadrature"
IntegralsForwardDiffExt = "ForwardDiff"
IntegralsMCIntegrationExt = "MCIntegration"
IntegralsZygoteExt = ["Zygote", "ChainRulesCore"]

[compat]
Expand Down Expand Up @@ -66,6 +67,7 @@ Distributions = "31c24e10-a181-5473-b8eb-7969acd0382f"
FastGaussQuadrature = "442a2c76-b920-505d-bb47-c5924d526838"
FiniteDiff = "6a86dc24-6348-571c-b903-95158fe2bd41"
ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210"
MCIntegration = "ea1e2de9-7db7-4b42-91ee-0cd1bf6df167"
Pkg = "44cfe95a-1eb2-52ea-b672-e2afdf69b78f"
SafeTestsets = "1bc83da4-3b8d-516f-aca4-4fe02f6d838f"
SciMLSensitivity = "1ed8b502-d754-442c-8d5d-10ac956f44a1"
Expand All @@ -74,4 +76,4 @@ Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f"

[targets]
test = ["Aqua", "Arblib", "SciMLSensitivity", "StaticArrays", "FiniteDiff", "Pkg", "SafeTestsets", "Test", "Distributions", "ForwardDiff", "Zygote", "ChainRulesCore", "FastGaussQuadrature", "Cuba", "Cubature"]
test = ["Aqua", "Arblib", "SciMLSensitivity", "StaticArrays", "FiniteDiff", "Pkg", "SafeTestsets", "Test", "Distributions", "ForwardDiff", "Zygote", "ChainRulesCore", "FastGaussQuadrature", "Cuba", "Cubature", "MCIntegration"]
48 changes: 48 additions & 0 deletions ext/IntegralsMCIntegrationExt.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,48 @@
module IntegralsMCIntegrationExt

using MCIntegration, Integrals

function Integrals.__solvebp_call(prob::IntegralProblem, alg::VEGASMC, sensealg, domain, p;
reltol = nothing, abstol = nothing, maxiters = 1000)
lb, ub = domain
mid = vec(collect((lb + ub) / 2))
vars = Continuous(vec([tuple(a,b) for (a,b) in zip(lb, ub)]))

if prob.f isa BatchIntegralFunction
error("VEGASMC doesn't support batching. See https://github.com/numericalEFT/MCIntegration.jl/issues/29")
else
if isinplace(prob)
f0 = similar(prob.f.integrand_prototype)
f_ = (x, f, c) -> begin
n = 0
for v in x
mid[n+=1] = first(v)
end
prob.f(f0, mid, p)
f .= vec(f0)
end
else
f0 = prob.f(mid, p)
f_ = (x, c) -> begin
n = 0
for v in x
mid[n+=1] = first(v)
end
fx = prob.f(mid, p)
fx isa AbstractArray ? vec(fx) : fx
end
end
dof = ones(Int, length(f0)) # each composite Continuous var gets 1 dof
res = integrate(f_, inplace=isinplace(prob), var=vars, dof=dof, solver=:vegasmc,
neval=alg.neval, niter=min(maxiters,alg.niter), block=alg.block, adapt=alg.adapt,
gamma=alg.gamma, verbose=alg.verbose, debug=alg.debug, type=eltype(f0), print=-2)
out, err, chi = if f0 isa Number
map(only, (res.mean, res.stdev, res.chi2))
else
map(a -> reshape(a, size(f0)), (res.mean, res.stdev, res.chi2))
end
SciMLBase.build_solution(prob, VEGASMC(), out, err, chi=chi, retcode = ReturnCode.Success)
end
end

end
45 changes: 1 addition & 44 deletions src/Integrals.jl
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@ if !isdefined(Base, :get_extension)
using Requires
end

using Reexport, MonteCarloIntegration, QuadGK, HCubature, MCIntegration
using Reexport, MonteCarloIntegration, QuadGK, HCubature
@reexport using SciMLBase
using LinearAlgebra

Expand Down Expand Up @@ -218,49 +218,6 @@ function __solvebp_call(prob::IntegralProblem, alg::VEGAS, sensealg, domain, p;
end


function __solvebp_call(prob::IntegralProblem, alg::VEGASMC, sensealg, domain, p;
reltol = nothing, abstol = nothing, maxiters = 1000)
lb, ub = domain
mid = collect((lb + ub) / 2)
vars = Continuous(vec([tuple(a,b) for (a,b) in zip(lb, ub)]))

if prob.f isa BatchIntegralFunction
error("VEGASMC doesn't support batching. See https://github.com/numericalEFT/MCIntegration.jl/issues/29")
else
if isinplace(prob)
f0 = similar(prob.f.integrand_prototype)
f_ = (x, f, c) -> begin
n = 0
for v in x
mid[n+=1] = first(v)
end
prob.f(f0, mid, p)
f .= vec(f0)
end
else
f0 = prob.f(mid, p)
f_ = (x, c) -> begin
n = 0
for v in x
mid[n+=1] = first(v)
end
prob.f(mid, p)
end
end
dof = f0 isa Number ? 1 : ones(Int, length(f0))
res = integrate(f_, inplace=isinplace(prob), var=vars, dof=dof, solver=:vegasmc,
neval=alg.neval, niter=min(maxiters,alg.niter), block=alg.block, adapt=alg.adapt,
gamma=alg.gamma, verbose=alg.verbose, debug=alg.debug)
out, err, chi = if f0 isa Number
map(only, (res.mean, res.stdev, res.chi2))
else
map(a -> reshape(a, size(f0)), (res.mean, res.stdev, res.chi2))
end
SciMLBase.build_solution(prob, VEGASMC(), out, err, chi=chi, retcode = ReturnCode.Success)
end
end


export QuadGKJL, HCubatureJL, VEGAS, VEGASMC, GaussLegendre, QuadratureRule, TrapezoidalRule
export CubaVegas, CubaSUAVE, CubaDivonne, CubaCuhre
export CubatureJLh, CubatureJLp
Expand Down
35 changes: 17 additions & 18 deletions src/algorithms.jl
Original file line number Diff line number Diff line change
Expand Up @@ -87,24 +87,6 @@ struct VEGAS <: SciMLBase.AbstractIntegralAlgorithm
end
VEGAS(; nbins = 100, ncalls = 1000, debug = false) = VEGAS(nbins, ncalls, debug)


"""
VEGASMC()
Markov-chain based Vegas algorithm from MCIntegration.jl
"""
struct VEGASMC <: SciMLBase.AbstractIntegralAlgorithm
neval::Int
niter::Int
block::Int
adapt::Bool
gamma::Float64
verbose::Int
debug::Bool
end
VEGASMC(; neval=10^4, niter=20, block=16, adapt=true, gamma=1.0, verbose=-2, debug=false) =
VEGASMC(neval, niter, block, adapt, gamma, verbose, debug)

"""
GaussLegendre{C, N, W}
Expand Down Expand Up @@ -388,3 +370,20 @@ end
function ArblibJL(; check_analytic=false, take_prec=false, warn_on_no_convergence=false, opts=C_NULL)
return ArblibJL(check_analytic, take_prec, warn_on_no_convergence, opts)
end

"""
VEGASMC(; neval=10^4, niter=20, block=16, adapt=true, gamma=1.0, verbose=-2, debug=false)
Markov-chain based Vegas algorithm from MCIntegration.jl
"""
struct VEGASMC <: SciMLBase.AbstractIntegralAlgorithm
neval::Int
niter::Int
block::Int
adapt::Bool
gamma::Float64
verbose::Int
debug::Bool
end
VEGASMC(; neval=10^4, niter=20, block=16, adapt=true, gamma=1.0, verbose=-2, debug=false) =

Check warning on line 388 in src/algorithms.jl

View check run for this annotation

Codecov / codecov/patch

src/algorithms.jl#L388

Added line #L388 was not covered by tests
VEGASMC(neval, niter, block, adapt, gamma, verbose, debug)
2 changes: 1 addition & 1 deletion test/interface_tests.jl
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
using Integrals
using Cuba, Cubature, Arblib
using Cuba, Cubature, Arblib, MCIntegration
using Test

max_dim_test = 2
Expand Down

0 comments on commit cd61b8a

Please sign in to comment.