
Writing an rrule for an AbstractKernel differentiated with Enzyme #665

Closed
renatobellotti opened this issue Aug 19, 2022 · 5 comments

@renatobellotti

Hi,

I'm new to Julia and the auto-diff world. My goal is to write differentiable GPU kernels for which I don't have an analytical solution. Writing the kernels and differentiating them works in Enzyme.jl. Now I'm trying to compose two kernels. Additionally, I would like to implement a loss function that converts the kernel output, which is a vector, to a scalar. I guess if I can manage to compose two kernels, I can interface them with any AD library compatible with ChainRules.

I've started to write an example, but I'm getting type errors for CuArrays:

using ChainRulesCore
using CUDA
using CUDAKernels
using Enzyme
using KernelAbstractions
using KernelGradients
using Zygote


# Two kernels to be called one after the other.
@kernel function example_kernel(x, y, z)
    i = @index(Global)
    if (i == 1)
        z[i] = 2 * x[i] + y[i]
    elseif (i == 2)
        z[i] = 3 * x[i] + y[i]
    elseif (i == 3)
        z[i] = 4 * x[i] + y[i]
    elseif (i == 4)
        z[i] = 5 * x[i] + y[i]
    end
    nothing
end


@kernel function example_kernel2(z, a, result)
    i = @index(Global)
    result[i] = 3 * z[i] + a[i]
    nothing
end


# Function calls to allow easier high-level code.
function call_example_kernel1(x, y)
    z = similar(x)
    fill!(z, 1)

    kernel = example_kernel(CUDADevice())
    event = kernel(x, y, z, ndrange=4)
    wait(event)
    return z
end


function call_example_kernel2(z, a)
    result = similar(z)
    fill!(result, 1)

    kernel = example_kernel2(CUDADevice())
    event = kernel(z, a, result, ndrange=4)
    wait(event)
    return result
end


function call_all(x, y, a)
    z = call_example_kernel1(x, y)
    result = call_example_kernel2(z, a)
    
    return result
end


# rrule for ChainRules.
function ChainRulesCore.rrule(::typeof(call_example_kernel1), x, y)
    # Allocate shadow memory.
    dz_dx = similar(x)
    fill!(dz_dx, 0)
    dz_dy = similar(y)
    fill!(dz_dy, 0)
    
    z = call_example_kernel1(x, y)
    r = similar(x)
    fill!(r, 1)

    dx = Duplicated(x, dz_dx)
    dy = Duplicated(y, dz_dy)
    dz = Duplicated(z, r)
    
    function calculate_z_pullback(z̄)
        gpu_kernel_autodiff = autodiff(example_kernel(CUDADevice()))
        event = gpu_kernel_autodiff(dx, dy, dz, ndrange=4)
        wait(event)
        
        f̄ = NoTangent()
        x̄ = Tangent(dx.dval)
        ȳ = Tangent(dy.dval)
        
        return f̄, x̄, ȳ
    end
    
    return z, calculate_z_pullback
end


# Example input.
x = cu([1., 2, 3, 4])
y = cu([5., 6, 7, 8])
a = cu([9., 10, 11, 12])

# This works:
# call_all(x, y, a)

# This does not work:
Zygote.jacobian(call_example_kernel1, x, y)

Here is the error message:

MethodError: no method matching Tangent(::CuArray{Float32, 1, CUDA.Mem.DeviceBuffer})

Stacktrace:
  [1] (::var"#calculate_z_pullback#2"{Duplicated{CuArray{Float32, 1, CUDA.Mem.DeviceBuffer}}, Duplicated{CuArray{Float32, 1, CUDA.Mem.DeviceBuffer}}, Duplicated{CuArray{Float32, 1, CUDA.Mem.DeviceBuffer}}})(z̄::CuArray{Float32, 1, CUDA.Mem.DeviceBuffer})
    @ Main ./In[2]:86
  [2] ZBack
    @ ~/.julia/packages/Zygote/D7j8v/src/compiler/chainrules.jl:205 [inlined]
  [3] (::Zygote.var"#206#207"{Tuple{Tuple{Nothing, Nothing}}, Zygote.ZBack{var"#calculate_z_pullback#2"{Duplicated{CuArray{Float32, 1, CUDA.Mem.DeviceBuffer}}, Duplicated{CuArray{Float32, 1, CUDA.Mem.DeviceBuffer}}, Duplicated{CuArray{Float32, 1, CUDA.Mem.DeviceBuffer}}}}})(Δ::CuArray{Float32, 1, CUDA.Mem.DeviceBuffer})
    @ Zygote ~/.julia/packages/Zygote/D7j8v/src/lib/lib.jl:205
  [4] (::Zygote.var"#1894#back#208"{Zygote.var"#206#207"{Tuple{Tuple{Nothing, Nothing}}, Zygote.ZBack{var"#calculate_z_pullback#2"{Duplicated{CuArray{Float32, 1, CUDA.Mem.DeviceBuffer}}, Duplicated{CuArray{Float32, 1, CUDA.Mem.DeviceBuffer}}, Duplicated{CuArray{Float32, 1, CUDA.Mem.DeviceBuffer}}}}}})(Δ::CuArray{Float32, 1, CUDA.Mem.DeviceBuffer})
    @ Zygote ~/.julia/packages/ZygoteRules/AIbCs/src/adjoint.jl:67
  [5] Pullback
    @ ./operators.jl:1085 [inlined]
  [6] (::typeof(∂(#_#83)))(Δ::CuArray{Float32, 1, CUDA.Mem.DeviceBuffer})
    @ Zygote ~/.julia/packages/Zygote/D7j8v/src/compiler/interface2.jl:0
  [7] (::Zygote.var"#206#207"{Tuple{Tuple{Nothing, Nothing}, Tuple{Nothing, Nothing}}, typeof(∂(#_#83))})(Δ::CuArray{Float32, 1, CUDA.Mem.DeviceBuffer})
    @ Zygote ~/.julia/packages/Zygote/D7j8v/src/lib/lib.jl:205
  [8] #1894#back
    @ ~/.julia/packages/ZygoteRules/AIbCs/src/adjoint.jl:67 [inlined]
  [9] Pullback
    @ ./operators.jl:1085 [inlined]
 [10] (::typeof(∂(ComposedFunction{typeof(Zygote._jvec), typeof(call_example_kernel1)}(Zygote._jvec, call_example_kernel1))))(Δ::CuArray{Float32, 1, CUDA.Mem.DeviceBuffer})
    @ Zygote ~/.julia/packages/Zygote/D7j8v/src/compiler/interface2.jl:0
 [11] (::Zygote.var"#60#61"{typeof(∂(ComposedFunction{typeof(Zygote._jvec), typeof(call_example_kernel1)}(Zygote._jvec, call_example_kernel1)))})(Δ::CuArray{Float32, 1, CUDA.Mem.DeviceBuffer})
    @ Zygote ~/.julia/packages/Zygote/D7j8v/src/compiler/interface.jl:41
 [12] withjacobian(::Function, ::CuArray{Float32, 1, CUDA.Mem.DeviceBuffer}, ::CuArray{Float32, 1, CUDA.Mem.DeviceBuffer})
    @ Zygote ~/.julia/packages/Zygote/D7j8v/src/lib/grad.jl:150
 [13] jacobian(::Function, ::CuArray{Float32, 1, CUDA.Mem.DeviceBuffer}, ::CuArray{Float32, 1, CUDA.Mem.DeviceBuffer})
    @ Zygote ~/.julia/packages/Zygote/D7j8v/src/lib/grad.jl:128
 [14] top-level scope
    @ In[2]:105
 [15] eval
    @ ./boot.jl:373 [inlined]
 [16] include_string(mapexpr::typeof(REPL.softscope), mod::Module, code::String, filename::String)
    @ Base ./loading.jl:1196

I appreciate every opinion, tip or other constructive feedback!

@renatobellotti (Author)

I got it running after some refinements, at least for the first kernel:

function ChainRulesCore.rrule(::typeof(call_example_kernel1), x, y)
    z = call_example_kernel1(x, y)

    function call_example_kernel1_pullback(z̄)
        # Allocate shadow memory.
        dz_dx = similar(x)
        fill!(dz_dx, 0)
        dz_dy = similar(y)
        fill!(dz_dy, 0)

        # Define differentials.
        dx = Duplicated(x, dz_dx)
        dy = Duplicated(y, dz_dy)
        dz = Duplicated(z, z̄)
    
        # AD call.
        gpu_kernel_autodiff = autodiff(example_kernel(CUDADevice()))
        event = gpu_kernel_autodiff(dx, dy, dz, ndrange=4)
        wait(event)

        # Return differentials of the inputs.
        f̄ = NoTangent()
        x̄ = dx.dval
        ȳ = dy.dval
        
        return f̄, x̄, ȳ
    end
    
    return z, call_example_kernel1_pullback
end

I will post the complete code so that other people can find it using search engines.

@renatobellotti (Author)

As promised, here is the complete example. If I understood correctly, you need to do the following steps:

  1. Write your kernel.
  2. Write a wrapper function that allocates output memory, calls the kernel without AD, and returns the result.
  3. Write an rrule:
    1. Call the wrapper to get the output value.
    2. Write the pullback that constructs the differentials and calls Enzyme's autodiff.
    3. Return the output value and the pullback.

using ChainRulesCore
using CUDA
using CUDAKernels
using Enzyme
using KernelAbstractions
using KernelGradients
using Zygote


# Two kernels to be called one after the other.
@kernel function example_kernel(x, y, z)
    i = @index(Global)
    if (i == 1)
        z[i] = 2 * x[i] + y[i]
    elseif (i == 2)
        z[i] = 3 * x[i] + y[i]
    elseif (i == 3)
        z[i] = 4 * x[i] + y[i]
    elseif (i == 4)
        z[i] = 5 * x[i] + y[i]
    end
    nothing
end


@kernel function example_kernel2(z, a, result)
    i = @index(Global)
    result[i] = 3 * z[i] + a[i]
    nothing
end


# Function calls to allow easier high-level code.
function call_example_kernel1(x, y)
    z = similar(x)
    fill!(z, 1)

    kernel = example_kernel(CUDADevice())
    event = kernel(x, y, z, ndrange=4)
    wait(event)
    return z
end


function call_example_kernel2(z, a)
    result = similar(z)
    fill!(result, 1)

    kernel = example_kernel2(CUDADevice())
    event = kernel(z, a, result, ndrange=4)
    wait(event)
    return result
end


function call_all(x, y, a)
    z = call_example_kernel1(x, y)
    result = call_example_kernel2(z, a)
    
    return result
end


# rrule for ChainRules.
function ChainRulesCore.rrule(::typeof(call_example_kernel1), x, y)
    z = call_example_kernel1(x, y)

    function call_example_kernel1_pullback(z̄)
        # Allocate shadow memory.
        dz_dx = similar(x)
        fill!(dz_dx, 0)
        dz_dy = similar(y)
        fill!(dz_dy, 0)

        # Define differentials.
        dx = Duplicated(x, dz_dx)
        dy = Duplicated(y, dz_dy)
        dz = Duplicated(z, z̄)
    
        # AD call.
        gpu_kernel_autodiff = autodiff(example_kernel(CUDADevice()))
        event = gpu_kernel_autodiff(dx, dy, dz, ndrange=4)
        wait(event)

        # Return differentials of the inputs.
        f̄ = NoTangent()
        x̄ = dx.dval
        ȳ = dy.dval
        
        return f̄, x̄, ȳ
    end
    
    return z, call_example_kernel1_pullback
end


function ChainRulesCore.rrule(::typeof(call_example_kernel2), z, a)
    result = call_example_kernel2(z, a)

    function call_example_kernel2_pullback(result_bar)
        # Allocate shadow memory.
        dresult_dz = similar(z)
        fill!(dresult_dz, 0)
        dresult_da = similar(a)
        fill!(dresult_da, 0)

        # Define differentials.
        dz = Duplicated(z, dresult_dz)
        da = Duplicated(a, dresult_da)
        dresult = Duplicated(result, result_bar)

        # AD call.
        gpu_kernel_autodiff = autodiff(example_kernel2(CUDADevice()))
        event = gpu_kernel_autodiff(dz, da, dresult, ndrange=4)
        wait(event)

        # Return differentials of the inputs.
        f̄ = NoTangent()
        z̄ = dz.dval
        ā = da.dval

        return f̄, z̄, ā
    end

    return result, call_example_kernel2_pullback
end


# Example input.
x = cu([1., 2, 3, 4])
y = cu([5., 6, 7, 8])
a = cu([9., 10, 11, 12])

z = call_example_kernel1(x, y)

# Calculation without gradients:
# call_all(x, y, a)

Jx, Jy = Zygote.jacobian(call_example_kernel1, x, y)

@show Jx;
@show Jy;

Jz, Ja = Zygote.jacobian(call_example_kernel2, z, a)

@show Jz;
@show Ja;

Jx, Jy, Ja = Zygote.jacobian(call_all, x, y, a)

@show Jx;
@show Jy;
@show Ja;
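With these rrules in place, the original goal of a scalar loss on top of the composed kernels should also work end to end. A minimal sketch (the `loss` function here is a made-up example, not part of the thread):

```julia
# Hypothetical scalar loss over the composed kernel output; Zygote
# differentiates it through the rrules defined above.
loss(x, y, a) = sum(abs2, call_all(x, y, a))

gx, gy, ga = Zygote.gradient(loss, x, y, a)
```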

This seems rather inefficient to me in terms of repeated calculations. Is there a more efficient way to do it?

@vchuravy

Is there a more efficient way to do it?

Not currently. Enzyme expects the gradient values to be passed in, while ChainRules expects them as the output, so there needs to be some mapping from the Enzyme convention to the ChainRules convention.
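The convention mismatch is visible even in a minimal CPU example (a sketch using Enzyme's reverse mode on a plain mutating function; the names `f!`, `x`, `y` are illustrative, not from the issue):

```julia
using Enzyme
using ChainRulesCore

# Enzyme convention: gradients are accumulated in place into shadow buffers.
f!(y, x) = (y .= 2 .* x; nothing)

x = [1.0, 2.0]; dx = zeros(2)    # dx is the shadow that Enzyme writes into
y = zeros(2);   dy = [1.0, 1.0]  # the seed ȳ is passed in via y's shadow

Enzyme.autodiff(Reverse, f!, Duplicated(y, dy), Duplicated(x, dx))

# ChainRules convention: the pullback must *return* the tangents instead,
# so an rrule has to copy the shadow buffer out after the Enzyme call.
tangents = (NoTangent(), copy(dx))
```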

@renatobellotti (Author)

Ok, thanks. Then I'm closing it for now.

@vchuravy

This would also be a great addition to the Enzyme documentation: generally, how to use Enzyme to add a ChainRule. Could you add your code above to the Examples section?
