Skip to content

Commit

Permalink
Cleanup (#1872)
Browse files Browse the repository at this point in the history
* Return type config

* cleanup

* Update runtests.jl

---------

Co-authored-by: William Moses <wsmoses@cyclops.juliacomputing.io>
  • Loading branch information
wsmoses and William Moses authored Sep 21, 2024
1 parent f14bd4a commit 00037e7
Show file tree
Hide file tree
Showing 8 changed files with 35 additions and 80 deletions.
2 changes: 1 addition & 1 deletion lib/EnzymeCore/Project.toml
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
name = "EnzymeCore"
uuid = "f151be2c-9106-41f4-ab19-57ee4f262869"
authors = ["William Moses <wmoses@mit.edu>", "Valentin Churavy <vchuravy@mit.edu>"]
version = "0.8.1"
version = "0.8.2"

[compat]
Adapt = "3, 4"
Expand Down
15 changes: 4 additions & 11 deletions lib/EnzymeCore/src/rules.jl
Original file line number Diff line number Diff line change
Expand Up @@ -77,17 +77,21 @@ const RevConfigWidth{Width} = RevConfig{<:Any,<:Any, Width}
@inline runtime_activity(::RevConfig{<:Any, <:Any, <:Any, <:Any, RuntimeActivity}) where RuntimeActivity = RuntimeActivity

"""
primal_type(::FwdConfig, ::Type{<:Annotation{RT}})
primal_type(::RevConfig, ::Type{<:Annotation{RT}})
Compute the exepcted primal return type given a reverse mode config and return activity
"""
@inline primal_type(config::FwdConfig, ::Type{<:Annotation{RT}}) where RT = needs_primal(config) ? RT : Nothing
@inline primal_type(config::RevConfig, ::Type{<:Annotation{RT}}) where RT = needs_primal(config) ? RT : Nothing

"""
shadow_type(::FwdConfig, ::Type{<:Annotation{RT}})
shadow_type(::RevConfig, ::Type{<:Annotation{RT}})
Compute the exepcted shadow return type given a reverse mode config and return activity
"""
@inline shadow_type(config::FwdConfig, ::Type{<:Annotation{RT}}) where RT = needs_shadow(config) ? (width(config) == 1 ? RT : NTuple{width(config), RT}) : Nothing
@inline shadow_type(config::RevConfig, ::Type{<:Annotation{RT}}) where RT = needs_shadow(config) ? (width(config) == 1 ? RT : NTuple{width(config), RT}) : Nothing

"""
Expand Down Expand Up @@ -191,9 +195,6 @@ function isapplicable(@nospecialize(f), @nospecialize(TT);
caller::Union{Nothing,Core.MethodInstance}=nothing)
tt = Base.to_tuple_type(TT)
sig = Base.signature_type(f, tt)
@static if VERSION < v"1.7.0"
return !isempty(Base._methods_by_ftype(sig, -1, world))
end
mt = ccall(:jl_method_table_for, Any, (Any,), sig)
mt isa Core.MethodTable || return false
if method_table === nothing
Expand Down Expand Up @@ -234,14 +235,6 @@ function add_mt_backedge!(caller::Core.MethodInstance, mt::Core.MethodTable, @no
return nothing
end

function issupported()
@static if VERSION < v"1.7.0"
return false
else
return true
end
end

"""
inactive(func::typeof(f), args...)
Expand Down
10 changes: 2 additions & 8 deletions lib/EnzymeTestUtils/src/generate_tangent.jl
Original file line number Diff line number Diff line change
Expand Up @@ -60,14 +60,8 @@ end

# get around the constructors and make the type directly
# Note this is moderately evil accessing julia's internals
if VERSION >= v"1.3"
@generated function _force_construct(T, args...)
return Expr(:splatnew, :T, :args)
end
else
@generated function _force_construct(T, args...)
return Expr(:new, :T, Any[:(args[$i]) for i in 1:length(args)]...)
end
@generated function _force_construct(T, args...)
return Expr(:splatnew, :T, :args)
end

function _construct(T, args...)
Expand Down
7 changes: 1 addition & 6 deletions lib/EnzymeTestUtils/src/test_reverse.jl
Original file line number Diff line number Diff line change
Expand Up @@ -7,12 +7,7 @@ for N in 1:30
eval(quote
function call_with_kwargs(fkwargs::NT, f::FT, $(argexprs...)) where {NT, FT}
Base.@_inline_meta
@static if VERSION v"1.8"
# callsite inline syntax unsupported in <= 1.8
f($(argexprs...); fkwargs...)
else
@inline f($(argexprs...); fkwargs...)
end
@inline f($(argexprs...); fkwargs...)
end
end)
end
Expand Down
53 changes: 16 additions & 37 deletions lib/EnzymeTestUtils/test/test_forward.jl
Original file line number Diff line number Diff line change
Expand Up @@ -87,9 +87,6 @@ end
elseif TT <: NamedTuple
x = (a=randn(T), b=randn(T))
else # TT <: TestStruct
if VERSION <= v"1.8" && Tx == BatchDuplicated
continue
end
x = TestStruct(randn(T, 5), randn(T))
end
atol = rtol = sqrt(eps(real(T)))
Expand Down Expand Up @@ -117,38 +114,26 @@ end
a = randn(T)
atol = rtol = sqrt(eps(real(T)))

if VERSION < v"1.8" && (
Tret <: BatchDuplicated ||
Tx <: BatchDuplicated ||
Ta <: BatchDuplicated
)
@test !fails() do
test_forward(f_multiarg, Tret, (x, Tx), (a, Ta); atol, rtol)
end skip = true
else
@test !fails() do
test_forward(f_multiarg, Tret, (x, Tx), (a, Ta); atol, rtol)
end broken = (
VERSION < v"1.8" && Tx <: Const && !(Ta <: Const) && T <: Complex
)
end
@test !fails() do
test_forward(f_multiarg, Tret, (x, Tx), (a, Ta); atol, rtol)
end
end
end

VERSION >= v"1.8" && @testset "structured array inputs/outputs" begin
@testset for Tret in (Const, Duplicated, BatchDuplicated),
Tx in (Const, Duplicated, BatchDuplicated),
T in (Float32, Float64, ComplexF32, ComplexF64)
@testset "structured array inputs/outputs" begin
@testset for Tret in (Const, Duplicated, BatchDuplicated),
Tx in (Const, Duplicated, BatchDuplicated),
T in (Float32, Float64, ComplexF32, ComplexF64)

# if some are batch, none must be duplicated
are_activities_compatible(Tret, Tx) || continue
# if some are batch, none must be duplicated
are_activities_compatible(Tret, Tx) || continue

x = Hermitian(randn(T, 5, 5))
x = Hermitian(randn(T, 5, 5))

atol = rtol = sqrt(eps(real(T)))
test_forward(f_structured_array, Tret, (x, Tx); atol, rtol)
end
end
atol = rtol = sqrt(eps(real(T)))
test_forward(f_structured_array, Tret, (x, Tx); atol, rtol)
end
end

@testset "equivalent arrays in output" begin
function f(x)
Expand Down Expand Up @@ -197,7 +182,7 @@ end
atol = rtol = sqrt(eps(real(T)))
@test !fails() do
test_forward(f_mut_fwd!, Tret, (y, Ty), (x, Tx), (a, Ta); atol, rtol, runtime_activity=true)
end skip = (VERSION < v"1.8" && T <: Complex)
end
end
end

Expand Down Expand Up @@ -230,13 +215,7 @@ end
atol = rtol = sqrt(eps(real(T)))
@test !fails() do
test_forward((c, Tc), Tret, (y, Ty); atol, rtol)
end skip = (
VERSION < v"1.8" && (
Tret <: BatchDuplicated ||
Tc <: BatchDuplicated ||
Ty <: BatchDuplicated
)
)
end
end
end
end
Expand Down
8 changes: 4 additions & 4 deletions src/rules/customrules.jl
Original file line number Diff line number Diff line change
Expand Up @@ -497,7 +497,7 @@ end
if RT <: Const
if needsPrimal
if RealRt != fwd_RT
emit_error(B, orig, "Enzyme: incorrect return type of const primal-only forward custom rule - "*(string(RT))*" "*string(activity)*" want just return type "*string(RealRt)*" found "*string(fwd_RT))
emit_error(B, orig, "Enzyme: incorrect return type of const primal-only forward custom rule - $C "*(string(RT))*" "*string(activity)*" want just return type "*string(RealRt)*" found "*string(fwd_RT))
return false
end
if get_return_info(RealRt)[2] !== nothing
Expand All @@ -508,7 +508,7 @@ end
end
else
if Nothing != fwd_RT
emit_error(B, orig, "Enzyme: incorrect return type of const no-primal forward custom rule - "*(string(RT))*" "*string(activity)*" want just return type Nothing found "*string(fwd_RT))
emit_error(B, orig, "Enzyme: incorrect return type of const no-primal forward custom rule - $C "*(string(RT))*" "*string(activity)*" want just return type Nothing found "*string(fwd_RT))
return false
end
end
Expand All @@ -519,7 +519,7 @@ end
ST = NTuple{Int(width), ST}
end
if ST != fwd_RT
emit_error(B, orig, "Enzyme: incorrect return type of shadow-only forward custom rule - "*(string(RT))*" "*string(activity)*" want just shadow type "*string(ST)*" found "*string(fwd_RT))
emit_error(B, orig, "Enzyme: incorrect return type of shadow-only forward custom rule - $C "*(string(RT))*" "*string(activity)*" want just shadow type "*string(ST)*" found "*string(fwd_RT))
return false
end
if get_return_info(RealRt)[2] !== nothing
Expand All @@ -539,7 +539,7 @@ end
BatchDuplicated{RealRt, Int(width)}
end
if ST != fwd_RT
emit_error(B, orig, "Enzyme: incorrect return type of prima/shadow forward custom rule - "*(string(RT))*" "*string(activity)*" want just shadow type "*string(ST)*" found "*string(fwd_RT))
emit_error(B, orig, "Enzyme: incorrect return type of prima/shadow forward custom rule - $C "*(string(RT))*" "*string(activity)*" want just shadow type "*string(ST)*" found "*string(fwd_RT))
return false
end
if get_return_info(RealRt)[2] !== nothing
Expand Down
15 changes: 7 additions & 8 deletions test/runtests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -85,14 +85,13 @@ end
include("abi.jl")
include("typetree.jl")

@static if Enzyme.EnzymeRules.issupported()
include("rules.jl")
include("rrules.jl")
include("kwrules.jl")
include("kwrrules.jl")
include("internal_rules.jl")
include("ruleinvalidation.jl")
end
include("rules.jl")
include("rrules.jl")
include("kwrules.jl")
include("kwrrules.jl")
include("internal_rules.jl")
include("ruleinvalidation.jl")

@static if !Sys.iswindows()
include("blas.jl")
end
Expand Down
5 changes: 0 additions & 5 deletions test/threads.jl
Original file line number Diff line number Diff line change
Expand Up @@ -73,14 +73,9 @@ end

out = [1.0, 2.0]
dout = [1.0, 1.0]
@static if VERSION < v"1.8"
# GPUCompiler causes a stack overflow due to https://github.com/JuliaGPU/GPUCompiler.jl/issues/587
# @test_throws AssertionError autodiff(Reverse, f_multi, Const, Duplicated(out, dout), Active(2.0))
else
res = autodiff(Reverse, f_multi, Const, Duplicated(out, dout), Active(2.0))
@test res[1][2] 2.0
end
end

@testset "Closure-less threads $(Threads.nthreads())" begin
function bf(i, x)
Expand Down

2 comments on commit 00037e7

@wsmoses
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@JuliaRegistrator register subdir="lib/EnzymeCore"

@JuliaRegistrator
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Registration pull request created: JuliaRegistries/General/115630

Tip: Release Notes

Did you know you can add release notes too? Just add markdown formatted text underneath the comment after the text
"Release notes:" and it will be added to the registry PR, and if TagBot is installed it will also be added to the
release that TagBot creates. i.e.

@JuliaRegistrator register

Release notes:

## Breaking changes

- blah

To add them here just re-invoke and the PR will be updated.

Tagging

After the above pull request is merged, it is recommended that a tag is created on this repository for the registered package version.

This will be done automatically if the Julia TagBot GitHub Action is installed, or can be done manually through the github interface, or via:

git tag -a EnzymeCore-v0.8.2 -m "<description of version>" 00037e7ff8fb32f36691bbdba5ce8dc251fe2dec
git push origin EnzymeCore-v0.8.2

Please sign in to comment.