diff --git a/Project.toml b/Project.toml index ff04441c83..a02ade766c 100644 --- a/Project.toml +++ b/Project.toml @@ -1,7 +1,7 @@ name = "Enzyme" uuid = "7da242da-08ed-463a-9acd-ee780be4f1d9" authors = ["William Moses ", "Valentin Churavy "] -version = "0.12.29" +version = "0.12.30" [deps] CEnum = "fa961155-64e5-5f13-b03f-caf6b980ea82" diff --git a/ext/EnzymeBFloat16sExt.jl b/ext/EnzymeBFloat16sExt.jl index 0fda13617e..c23797ffff 100644 --- a/ext/EnzymeBFloat16sExt.jl +++ b/ext/EnzymeBFloat16sExt.jl @@ -3,7 +3,7 @@ module EnzymeBFloat16sExt using BFloat16s using Enzyme -function Enzyme.typetree_inner(::Type{Core.BFloat16}, ctx, dl, seen::Enzyme.Compiler.TypeTreeTable) +function Enzyme.typetree_inner(::Type{BFloat16}, ctx, dl, seen::Enzyme.Compiler.TypeTreeTable) return TypeTree(Enzyme.API.DT_BFloat16, -1, ctx) end diff --git a/test/Project.toml b/test/Project.toml index 5c8286d1af..a3f8452712 100644 --- a/test/Project.toml +++ b/test/Project.toml @@ -1,5 +1,6 @@ [deps] Aqua = "4c88cf16-eb10-579e-8560-4a9242c79595" +BFloat16s = "ab4f0b2a-ad5b-11e8-123f-65d77653426b" ChainRules = "082447d4-558c-5d27-93f4-14fc19e9eca2" ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4" Documenter = "e30172f5-a6a5-5a46-863b-614d45cd2de4" diff --git a/test/ext/bfloat16s.jl b/test/ext/bfloat16s.jl new file mode 100644 index 0000000000..0a47f48f03 --- /dev/null +++ b/test/ext/bfloat16s.jl @@ -0,0 +1,7 @@ +using Enzyme +using Test +using BFloat16s + +@test_broken Enzyme.gradient(Reverse, sum, ones(BFloat16, 10)) ≈ ones(BFloat16, 10) + +@test_broken Enzyme.gradient(Forward, sum, ones(BFloat16, 10)) ≈ ones(BFloat16, 10) diff --git a/test/runtests.jl b/test/runtests.jl index 94015cfa4f..ac62137d35 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -3693,6 +3693,10 @@ end include("ext/chainrulescore.jl") end include("ext/logexpfunctions.jl") + + @testset "BFloat16s ext" begin + include("ext/bfloat16s.jl") + end end