Skip to content

Commit

Permalink
Tapir -> Mooncake (???)
Browse files Browse the repository at this point in the history
  • Loading branch information
penelopeysm committed Oct 24, 2024
1 parent a962770 commit dd05dc2
Show file tree
Hide file tree
Showing 6 changed files with 30 additions and 31 deletions.
4 changes: 1 addition & 3 deletions .github/workflows/AD.yml
Original file line number Diff line number Diff line change
Expand Up @@ -23,13 +23,11 @@ jobs:
AD:
- Enzyme
- ForwardDiff
- Tapir
- Mooncake
- Tracker
- ReverseDiff
- Zygote
exclude:
- version: 1.6
AD: Tapir
# TODO(mhauru) Hopefully can enable Enzyme on older versions at some point, see
# discussion in https://github.com/TuringLang/Bijectors.jl/pull.
- version: 1.6
Expand Down
10 changes: 5 additions & 5 deletions Project.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
name = "Bijectors"
uuid = "76274a88-744f-5084-9051-94815aaf08c4"
version = "0.13.19"
version = "0.14.0"

[deps]
ArgCheck = "dce04be8-c92d-5529-be00-80e4d2c0e197"
Expand Down Expand Up @@ -29,8 +29,8 @@ Enzyme = "7da242da-08ed-463a-9acd-ee780be4f1d9"
ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210"
LazyArrays = "5078a376-72f3-5289-bfd5-ec5146d43c02"
ReverseDiff = "37e2e3b7-166d-5795-8a7a-e32c996b4267"
Mooncake = "da2b9cff-9c12-43a0-ae48-6db2b0edb7d6"
Tracker = "9f7883ad-71c0-57eb-9f7f-b5c9e6d3789c"
Tapir = "07d77754-e150-4737-8c94-cd238a1fb45b"
Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f"

[extensions]
Expand All @@ -39,8 +39,8 @@ BijectorsEnzymeExt = "Enzyme"
BijectorsForwardDiffExt = "ForwardDiff"
BijectorsLazyArraysExt = "LazyArrays"
BijectorsReverseDiffExt = "ReverseDiff"
BijectorsMooncakeExt = "Mooncake"
BijectorsTrackerExt = "Tracker"
BijectorsTapirExt = "Tapir"
BijectorsZygoteExt = "Zygote"

[compat]
Expand All @@ -65,7 +65,7 @@ Requires = "0.5, 1"
ReverseDiff = "1"
Roots = "1.3.4, 2"
Statistics = "1"
Tapir = "0.2.23"
Mooncake = "0.4.19"
Tracker = "0.2"
Zygote = "0.6.63"
julia = "1.6"
Expand All @@ -76,6 +76,6 @@ Enzyme = "7da242da-08ed-463a-9acd-ee780be4f1d9"
ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210"
LazyArrays = "5078a376-72f3-5289-bfd5-ec5146d43c02"
ReverseDiff = "37e2e3b7-166d-5795-8a7a-e32c996b4267"
Tapir = "07d77754-e150-4737-8c94-cd238a1fb45b"
Mooncake = "da2b9cff-9c12-43a0-ae48-6db2b0edb7d6"
Tracker = "9f7883ad-71c0-57eb-9f7f-b5c9e6d3789c"
Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f"
15 changes: 8 additions & 7 deletions ext/BijectorsTapirExt.jl → ext/BijectorsMooncakeExt.jl
Original file line number Diff line number Diff line change
@@ -1,10 +1,11 @@
module BijectorsTapirExt
module BijectorsMooncakeExt

if isdefined(Base, :get_extension)
using Tapir: @is_primitive, MinimalCtx, Tapir, CoDual, primal, tangent_type, @from_rrule
using Mooncake:
@is_primitive, MinimalCtx, Mooncake, CoDual, primal, tangent_type, @from_rrule
using Bijectors: find_alpha, ChainRulesCore
else
using ..Tapir: @is_primitive, MinimalCtx, Tapir, primal, tangent_type, @from_rrule
using ..Mooncake: @is_primitive, MinimalCtx, Mooncake, primal, tangent_type, @from_rrule
using ..Bijectors: find_alpha, ChainRulesCore
end

Expand All @@ -19,20 +20,20 @@ end
# unusual Integer type is encountered.
@is_primitive(MinimalCtx, Tuple{typeof(find_alpha),P,P,Integer} where {P<:Base.IEEEFloat})

function Tapir.rrule!!(
function Mooncake.rrule!!(
::CoDual{typeof(find_alpha)}, x::CoDual{P}, y::CoDual{P}, z::CoDual{I}
) where {P<:Base.IEEEFloat,I<:Integer}
# Require that the integer is non-differentiable.
if tangent_type(I) != Tapir.NoTangent
if tangent_type(I) != Mooncake.NoTangent
msg = "Integer argument has tangent type $(tangent_type(I)), should be NoTangent."
throw(ArgumentError(msg))
end
out, pb = ChainRulesCore.rrule(find_alpha, primal(x), primal(y), primal(z))
function find_alpha_pb(dout::P)
_, dx, dy, _ = pb(dout)
return Tapir.NoRData(), P(dx), P(dy), Tapir.NoRData()
return Mooncake.NoRData(), P(dx), P(dy), Mooncake.NoRData()
end
return Tapir.zero_fcodual(out), find_alpha_pb
return Mooncake.zero_fcodual(out), find_alpha_pb
end

end
14 changes: 7 additions & 7 deletions test/ad/chainrules.jl
Original file line number Diff line number Diff line change
Expand Up @@ -27,37 +27,37 @@ end
test_frule(Bijectors.find_alpha, x, y, z)
test_rrule(Bijectors.find_alpha, x, y, z)

if @isdefined Tapir
if @isdefined Mooncake
rng = Xoshiro(123456)
Tapir.TestUtils.test_rule(
Mooncake.TestUtils.test_rule(
rng,
Bijectors.find_alpha,
x,
y,
z;
is_primitive=true,
perf_flag=:none,
interp=Tapir.TapirInterpreter(),
interp=Mooncake.MooncakeInterpreter(),
)
Tapir.TestUtils.test_rule(
Mooncake.TestUtils.test_rule(
rng,
Bijectors.find_alpha,
x,
y,
3;
is_primitive=true,
perf_flag=:none,
interp=Tapir.TapirInterpreter(),
interp=Mooncake.MooncakeInterpreter(),
)
Tapir.TestUtils.test_rule(
Mooncake.TestUtils.test_rule(
rng,
Bijectors.find_alpha,
x,
y,
UInt32(3);
is_primitive=true,
perf_flag=:none,
interp=Tapir.TapirInterpreter(),
interp=Mooncake.MooncakeInterpreter(),
)
end

Expand Down
10 changes: 5 additions & 5 deletions test/ad/utils.jl
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@ function test_ad(f, x, broken=(); rtol=1e-6, atol=1e-6)
b in (
:ForwardDiff,
:Zygote,
:Tapir,
:Mooncake,
:ReverseDiff,
:Enzyme,
:EnzymeForward,
Expand Down Expand Up @@ -78,12 +78,12 @@ function test_ad(f, x, broken=(); rtol=1e-6, atol=1e-6)
end
end

if (AD == "All" || AD == "Tapir") && VERSION >= v"1.10"
rule = Tapir.build_rrule(f, x; safety_on=false)
if (AD == "All" || AD == "Mooncake") && VERSION >= v"1.10"
rule = Mooncake.build_rrule(f, x; safety_on=false)
if :tapir in broken
@test_broken(
isapprox(
Tapir.value_and_gradient!!(rule, f, x)[2][2],
Mooncake.value_and_gradient!!(rule, f, x)[2][2],
finitediff;
rtol=rtol,
atol=atol,
Expand All @@ -92,7 +92,7 @@ function test_ad(f, x, broken=(); rtol=1e-6, atol=1e-6)
else
@test(
isapprox(
Tapir.value_and_gradient!!(rule, f, x)[2][2],
Mooncake.value_and_gradient!!(rule, f, x)[2][2],
finitediff;
rtol=rtol,
atol=atol,
Expand Down
8 changes: 4 additions & 4 deletions test/runtests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -35,12 +35,12 @@ if VERSION < v"1.9"
using Compat: stack
end

# Sadly, Tapir.jl cannot be installed on version 1.6, so we have to add it if we're testing
# on at least version 1.10.
# Mooncake.jl cannot be installed on version 1.6, so we have to add it if we're
# testing on at least version 1.10.
if VERSION >= v"1.10"
using Pkg
Pkg.add("Tapir")
using Tapir
Pkg.add("Mooncake")
using Mooncake
end

const GROUP = get(ENV, "GROUP", "All")
Expand Down

0 comments on commit dd05dc2

Please sign in to comment.