Incorrect gradients with BFloat16 #1695

Closed
avik-pal opened this issue Aug 2, 2024 · 8 comments
Labels
help wanted Extra attention is needed

Comments

Contributor

avik-pal commented Aug 2, 2024

julia> using BFloat16s, Enzyme

julia> x = rand(BFloat16, 24);

julia> f(x) = sum(x);

julia> Enzyme.gradient(Reverse, f, x)
┌ Warning: Unknown floating point type
│   T = BFloat16
└ @ Enzyme ~/.julia/packages/GPUCompiler/Y4hSX/src/utils.jl:59
┌ Warning: Unknown floating point type
│   T = BFloat16
└ @ Enzyme ~/.julia/packages/GPUCompiler/Y4hSX/src/utils.jl:59
┌ Warning: Unknown floating point type
│   T = BFloat16
└ @ Enzyme ~/.julia/packages/GPUCompiler/Y4hSX/src/utils.jl:59
┌ Warning: Unknown floating point type
│   T = BFloat16
└ @ Enzyme ~/.julia/packages/GPUCompiler/Y4hSX/src/utils.jl:59
┌ Warning: Unknown floating point type
│   T = BFloat16
└ @ Enzyme ~/.julia/packages/GPUCompiler/Y4hSX/src/utils.jl:59
┌ Warning: Unknown floating point type
│   T = BFloat16
└ @ Enzyme ~/.julia/packages/GPUCompiler/Y4hSX/src/utils.jl:59
┌ Warning: Unknown floating point type
│   T = BFloat16
└ @ Enzyme ~/.julia/packages/GPUCompiler/Y4hSX/src/utils.jl:59
24-element Vector{BFloat16}:
 0.0
 0.0
 0.0
 0.0
 0.0
 0.0
 0.0
 0.0
 0.0
 0.0
 0.0
 0.0
 0.0
 0.0
 0.0
 0.0
 0.0
 0.0
 0.0
 0.0
 0.0
 0.0
 0.0
 0.0

julia> using Zygote

julia> Zygote.gradient(f, x)
(Fill(BFloat16(1.0), 24),)

I suspect this is also what crashes my LuxLib tests with an assertion failure 😓, but I will try to minimize that later: https://github.com/LuxDL/LuxLib.jl/actions/runs/10210191155/job/28249425224?pr=115#step:6:963
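
For reference, the expected result can be checked without any AD package. The sketch below uses central finite differences in Float64 (not BFloat16, whose precision is too low for this) to confirm that the gradient of `sum` is a vector of ones, which matches the Zygote output above. The helper name `fd_gradient` and the step size `h` are illustrative assumptions, not part of Enzyme or Zygote.

```julia
# Same function as in the report.
f(x) = sum(x)

# Hypothetical helper: central finite-difference gradient in Float64.
function fd_gradient(f, x::Vector{Float64}; h = 1e-6)
    g = similar(x)
    for i in eachindex(x)
        xp = copy(x); xp[i] += h   # perturb element i upwards
        xm = copy(x); xm[i] -= h   # perturb element i downwards
        g[i] = (f(xp) - f(xm)) / (2h)
    end
    return g
end

x = rand(24)             # Float64 stand-in for the BFloat16 input
g = fd_gradient(f, x)    # every entry should be close to 1.0
```

So a gradient of all zeros, as returned by Enzyme here, cannot be correct for this function.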

vchuravy added the help wanted label Aug 2, 2024
Member

vchuravy commented Aug 2, 2024

I am not even sure we fully support Float16 proper, so this would likely require a bit of work from an intrepid hero across Enzyme proper and Enzyme.jl.

Member

wsmoses commented Aug 4, 2024

Most of Enzyme's internals are agnostic to the floating-point type. I can see a couple of places where bfloat16 handling may be missing, but I don't think it's that bad.

@avik-pal as an immediate step, adding a corresponding conversion for BFloat16 in typetree.jl, instead of emitting that message, would at least fix the zeros (at present the type is otherwise detected as inactive and/or hits an internal error).

Contributor Author

avik-pal commented Aug 4, 2024

Should it map to API.DT_Half, or does the bit representation need to be specified somewhere?

Member

vchuravy commented Aug 4, 2024

I think we would need to add DT_BFLOAT in Enzyme proper; otherwise Enzyme would emit half ops in reverse mode xD
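
To illustrate why reusing the half type would be wrong, here is a minimal sketch in plain Julia (Base only, no BFloat16s or Enzyme). It relies on the standard layouts: IEEE half has 5 exponent and 10 mantissa bits, while bfloat16 is the upper 16 bits of a Float32 (8 exponent and 7 mantissa bits). The variable names are illustrative.

```julia
# IEEE half-precision encoding of 1.0
half_bits = reinterpret(UInt16, Float16(1.0))            # 0x3c00

# bfloat16 encoding of 1.0: take the upper 16 bits of the Float32 encoding
bf16_bits = UInt16(reinterpret(UInt32, 1.0f0) >> 16)     # 0x3f80

# Treating the bfloat16 bit pattern as if it were a half gives a different number
misread = reinterpret(Float16, bf16_bits)                # Float16(1.875)
```

The two formats are both 16 bits wide but encode the same value differently, so arithmetic emitted for half on bfloat16 data would compute on misinterpreted values.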

Member

vchuravy commented Aug 4, 2024

I think you would only need to update CApi.cpp for that, though, and then grep through the Enzyme core source for Half and add the corresponding bfloat cases, like https://github.com/EnzymeAD/Enzyme/blob/9d6a86b46086a2009725f8bfb9af89e7dc7168f6/enzyme/Enzyme/TypeAnalysis/ConcreteType.h#L99

Member

wsmoses commented Aug 4, 2024

Yeah, basically we just need to add something to the C API that shims bfloat: https://github.com/EnzymeAD/Enzyme/blob/9d6a86b46086a2009725f8bfb9af89e7dc7168f6/enzyme/Enzyme/CApi.cpp#L94 and https://github.com/EnzymeAD/Enzyme/blob/9d6a86b46086a2009725f8bfb9af89e7dc7168f6/enzyme/Enzyme/CApi.h#L58

This is only really needed for the C API, as the rest of Enzyme's type trees use actual types (i.e., bfloat directly).

Member

wsmoses commented Aug 6, 2024

x-ref EnzymeAD/Enzyme#2033

Member

wsmoses commented Aug 8, 2024

Hypothetically resolved by #1708.

3 participants