Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add AutoMooncake type #89

Merged
merged 13 commits into from
Sep 25, 2024
2 changes: 1 addition & 1 deletion Project.toml
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
name = "ADTypes"
uuid = "47edcb42-4c32-4615-8424-f2b9edc5f35b"
authors = ["Vaibhav Dixit <vaibhavyashdixit@gmail.com>, Guillaume Dalle and contributors"]
version = "1.8.1"
version = "1.9.0"

[deps]
ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"
Expand Down
8 changes: 7 additions & 1 deletion docs/src/index.md
Original file line number Diff line number Diff line change
Expand Up @@ -39,8 +39,8 @@ AutoGTPSA
### Reverse mode

```@docs
AutoMooncake
AutoReverseDiff
AutoTapir
AutoTracker
AutoZygote
```
Expand Down Expand Up @@ -106,3 +106,9 @@ ADTypes.SymbolicMode
```@docs
ADTypes.Auto
```

## Deprecated

```@docs
AutoTapir
```
1 change: 1 addition & 0 deletions src/ADTypes.jl
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,7 @@ export AutoChainRules,
AutoForwardDiff,
AutoGTPSA,
AutoModelingToolkit,
AutoMooncake,
AutoPolyesterForwardDiff,
AutoReverseDiff,
AutoSymbolics,
Expand Down
39 changes: 29 additions & 10 deletions src/dense.jl
Original file line number Diff line number Diff line change
Expand Up @@ -227,6 +227,27 @@ function Base.show(io::IO, backend::AutoGTPSA{D}) where {D}
print(io, ")")
end

"""
AutoMooncake

Struct used to select the [Mooncake.jl](https://github.com/compintell/Mooncake.jl) backend for automatic differentiation.

Defined by [ADTypes.jl](https://github.com/SciML/ADTypes.jl).

# Constructors

AutoMooncake(; config)

# Fields

- `config`: either `nothing` or an instance of `Mooncake.Config` -- see the docstring of `Mooncake.Config` for more information. `AutoMooncake(; config=nothing)` is equivalent to `AutoMooncake(; config=Mooncake.Config())`, i.e. the default configuration.
"""
Base.@kwdef struct AutoMooncake{Tconfig} <: AbstractADType
config::Tconfig
end

mode(::AutoMooncake) = ReverseMode()

"""
AutoPolyesterForwardDiff{chunksize,T}

Expand Down Expand Up @@ -323,7 +344,11 @@ mode(::AutoSymbolics) = SymbolicMode()
"""
AutoTapir

Struct used to select the [Tapir.jl](https://github.com/withbayes/Tapir.jl) backend for automatic differentiation.
!!! danger

`AutoTapir` is deprecated following a package renaming, please use [`AutoMooncake`](@ref) instead.

Struct used to select the [Tapir.jl](https://github.com/compintell/Tapir.jl) backend for automatic differentiation.

Defined by [ADTypes.jl](https://github.com/SciML/ADTypes.jl).

Expand All @@ -333,16 +358,10 @@ Defined by [ADTypes.jl](https://github.com/SciML/ADTypes.jl).

# Fields

- `safe_mode::Bool`: whether to run additional checks to catch errors early. While this is
on by default to ensure that users are aware of this option, you should generally turn
it off for actual use, as it has substantial performance implications.
If you encounter a problem with using Tapir (it fails to differentiate a function, or
something truly nasty like a segfault occurs), then you should try switching `safe_mode`
on and look at what happens. Often errors are caught earlier and the error messages are
more useful.
- `safe_mode::Bool`: whether to run additional checks to catch errors early.
"""
Base.@kwdef struct AutoTapir <: AbstractADType
safe_mode::Bool = true
struct AutoTapir <: AbstractADType
safe_mode::Bool
end

mode(::AutoTapir) = ReverseMode()
Expand Down
7 changes: 7 additions & 0 deletions src/legacy.jl
Original file line number Diff line number Diff line change
Expand Up @@ -36,3 +36,10 @@ function AutoModelingToolkit(; obj_sparse::Bool = false, cons_sparse::Bool = fal
:AutoModelingToolkit; force = false)
return mtk_to_symbolics(obj_sparse, cons_sparse)
end

function AutoTapir(; safe_mode=true)
Base.depwarn(
"`AutoTapir` is deprecated in favour of `AutoMooncake`.", :AutoTapir; force=false
)
return AutoTapir(safe_mode)
end
8 changes: 8 additions & 0 deletions test/dense.jl
Original file line number Diff line number Diff line change
Expand Up @@ -119,6 +119,14 @@ end
@test ad.descriptor == Val(:descriptor)
end

@testset "AutoMooncake" begin
ad = AutoMooncake(; config=nothing)
@test ad isa AbstractADType
@test ad isa AutoMooncake
@test mode(ad) isa ReverseMode
@test ad.config === nothing
end

@testset "AutoPolyesterForwardDiff" begin
ad = AutoPolyesterForwardDiff()
@test ad isa AbstractADType
Expand Down
5 changes: 5 additions & 0 deletions test/legacy.jl
Original file line number Diff line number Diff line change
Expand Up @@ -68,3 +68,8 @@ end
ad = @test_deprecated AutoReverseDiff(true)
@test ad.compile
end

@testset "AutoTapir" begin
@test_deprecated AutoTapir()
gdalle marked this conversation as resolved.
Show resolved Hide resolved
@test_deprecated AutoTapir(; safe_mode=false)
end
Loading