Skip to content

Commit

Permalink
More bfloat fix (#1728)
Browse files Browse the repository at this point in the history
* More bfloat fix

* Update bfloat16s.jl

* Update bfloat16s.jl
  • Loading branch information
wsmoses authored Aug 12, 2024
1 parent cf619b3 commit 0cf47c1
Show file tree
Hide file tree
Showing 2 changed files with 5 additions and 6 deletions.
2 changes: 1 addition & 1 deletion ext/EnzymeBFloat16sExt.jl
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@ using BFloat16s
using Enzyme

function Enzyme.typetree_inner(::Type{BFloat16}, ctx, dl, seen::Enzyme.Compiler.TypeTreeTable)
return TypeTree(Enzyme.API.DT_BFloat16, -1, ctx)
return Enzyme.TypeTree(Enzyme.API.DT_BFloat16, -1, ctx)
end

end
9 changes: 4 additions & 5 deletions src/typetree.jl
Original file line number Diff line number Diff line change
Expand Up @@ -121,11 +121,6 @@ function typetree_inner(::Type{BigFloat}, ctx, dl, seen::TypeTreeTable)
return TypeTree()
end

function typetree_inner(::Type{T}, ctx, dl, seen::TypeTreeTable) where {T<:AbstractFloat}
GPUCompiler.@safe_warn "Unknown floating point type" T
return TypeTree()
end

function typetree_inner(::Type{<:DataType}, ctx, dl, seen::TypeTreeTable)
return TypeTree()
end
Expand Down Expand Up @@ -225,6 +220,10 @@ function typetree_inner(@nospecialize(T), ctx, dl, seen::TypeTreeTable)
end
end

if T <: AbstractFloat
throw(AssertionError("Unknown floating point type $T"))
end

try
fieldcount(T)
catch
Expand Down

0 comments on commit 0cf47c1

Please sign in to comment.