Skip to content

Commit

Permalink
fix: BFloat16 extension and tests
Browse files Browse the repository at this point in the history
  • Loading branch information
avik-pal committed Aug 10, 2024
1 parent 02c5855 commit 48e4f0c
Show file tree
Hide file tree
Showing 5 changed files with 14 additions and 2 deletions.
2 changes: 1 addition & 1 deletion Project.toml
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
name = "Enzyme"
uuid = "7da242da-08ed-463a-9acd-ee780be4f1d9"
authors = ["William Moses <wmoses@mit.edu>", "Valentin Churavy <vchuravy@mit.edu>"]
version = "0.12.29"
version = "0.12.30"

[deps]
CEnum = "fa961155-64e5-5f13-b03f-caf6b980ea82"
Expand Down
2 changes: 1 addition & 1 deletion ext/EnzymeBFloat16sExt.jl
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@ module EnzymeBFloat16sExt
using BFloat16s
using Enzyme

function Enzyme.typetree_inner(::Type{Core.BFloat16}, ctx, dl, seen::Enzyme.Compiler.TypeTreeTable)
function Enzyme.typetree_inner(::Type{BFloat16}, ctx, dl, seen::Enzyme.Compiler.TypeTreeTable)
return TypeTree(Enzyme.API.DT_BFloat16, -1, ctx)
end

Expand Down
1 change: 1 addition & 0 deletions test/Project.toml
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
[deps]
Aqua = "4c88cf16-eb10-579e-8560-4a9242c79595"
BFloat16s = "ab4f0b2a-ad5b-11e8-123f-65d77653426b"
ChainRules = "082447d4-558c-5d27-93f4-14fc19e9eca2"
ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"
Documenter = "e30172f5-a6a5-5a46-863b-614d45cd2de4"
Expand Down
7 changes: 7 additions & 0 deletions test/ext/bfloat16s.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
using Enzyme
using Test
using BFloat16s

@test_broken Enzyme.gradient(Reverse, sum, ones(BFloat16, 10)) ones(BFloat16, 10)

@test_broken Enzyme.gradient(Forward, sum, ones(BFloat16, 10)) ones(BFloat16, 10)
4 changes: 4 additions & 0 deletions test/runtests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -3693,6 +3693,10 @@ end
include("ext/chainrulescore.jl")
end
include("ext/logexpfunctions.jl")

@testset "BFloat16s ext" begin
include("ext/bfloat16s.jl")
end
end


Expand Down

0 comments on commit 48e4f0c

Please sign in to comment.