Writing an rrule for an AbstractKernel differentiated with Enzyme #665
I got it running after some refinements, at least for the first kernel:

```julia
function ChainRulesCore.rrule(::typeof(call_example_kernel1), x, y)
    z = call_example_kernel1(x, y)

    function call_example_kernel1_pullback(z̄)
        # Allocate shadow memory.
        dz_dx = similar(x)
        fill!(dz_dx, 0)
        dz_dy = similar(y)
        fill!(dz_dy, 0)

        # Define differentials.
        dx = Duplicated(x, dz_dx)
        dy = Duplicated(y, dz_dy)
        dz = Duplicated(z, z̄)

        # AD call.
        gpu_kernel_autodiff = autodiff(example_kernel(CUDADevice()))
        event = gpu_kernel_autodiff(dx, dy, dz, ndrange=4)
        wait(event)  # make sure the gradient kernel has finished before the shadows are read

        # Return differentials of input.
        f̄ = NoTangent()
        x̄ = dx.dval
        ȳ = dy.dval
        return f̄, x̄, ȳ
    end

    return z, call_example_kernel1_pullback
end
```

I will post the complete code so that other people can find it using search engines.
As promised, here is the complete example. If I understood correctly, you need the following pieces: the kernels themselves, plain wrapper functions for the forward calls, and an rrule for each wrapper.
```julia
using ChainRulesCore
using CUDA
using CUDAKernels
using Enzyme
using KernelAbstractions
using KernelGradients
using Zygote

# Two kernels to be called one after the other.
@kernel function example_kernel(x, y, z)
    i = @index(Global)
    if i == 1
        z[i] = 2 * x[i] + y[i]
    elseif i == 2
        z[i] = 3 * x[i] + y[i]
    elseif i == 3
        z[i] = 4 * x[i] + y[i]
    elseif i == 4
        z[i] = 5 * x[i] + y[i]
    end
    nothing
end

@kernel function example_kernel2(z, a, result)
    i = @index(Global)
    result[i] = 3 * z[i] + a[i]
    nothing
end

# Function calls to allow easier high-level code.
function call_example_kernel1(x, y)
    z = similar(x)
    fill!(z, 1)
    kernel = example_kernel(CUDADevice())
    event = kernel(x, y, z, ndrange=4)
    wait(event)
    return z
end

function call_example_kernel2(z, a)
    result = similar(z)
    fill!(result, 1)
    kernel = example_kernel2(CUDADevice())
    event = kernel(z, a, result, ndrange=4)
    wait(event)
    return result
end

function call_all(x, y, a)
    z = call_example_kernel1(x, y)
    result = call_example_kernel2(z, a)
    return result
end

# rrules for ChainRules.
function ChainRulesCore.rrule(::typeof(call_example_kernel1), x, y)
    z = call_example_kernel1(x, y)

    function call_example_kernel1_pullback(z̄)
        # Allocate shadow memory.
        dz_dx = similar(x)
        fill!(dz_dx, 0)
        dz_dy = similar(y)
        fill!(dz_dy, 0)

        # Define differentials.
        dx = Duplicated(x, dz_dx)
        dy = Duplicated(y, dz_dy)
        dz = Duplicated(z, z̄)

        # AD call.
        gpu_kernel_autodiff = autodiff(example_kernel(CUDADevice()))
        event = gpu_kernel_autodiff(dx, dy, dz, ndrange=4)
        wait(event)  # make sure the gradient kernel has finished before the shadows are read

        # Return differentials of input.
        f̄ = NoTangent()
        x̄ = dx.dval
        ȳ = dy.dval
        return f̄, x̄, ȳ
    end

    return z, call_example_kernel1_pullback
end

function ChainRulesCore.rrule(::typeof(call_example_kernel2), z, a)
    result = call_example_kernel2(z, a)

    function call_example_kernel2_pullback(result_bar)
        # Allocate shadow memory.
        dresult_dz = similar(z)
        fill!(dresult_dz, 0)
        dresult_da = similar(a)
        fill!(dresult_da, 0)

        # Define differentials.
        dz = Duplicated(z, dresult_dz)
        da = Duplicated(a, dresult_da)
        dresult = Duplicated(result, result_bar)

        # AD call.
        gpu_kernel_autodiff = autodiff(example_kernel2(CUDADevice()))
        event = gpu_kernel_autodiff(dz, da, dresult, ndrange=4)
        wait(event)  # make sure the gradient kernel has finished before the shadows are read

        # Return differentials of input.
        f̄ = NoTangent()
        z̄ = dz.dval
        ā = da.dval
        return f̄, z̄, ā
    end

    return result, call_example_kernel2_pullback
end

# Example input.
x = cu([1., 2, 3, 4])
y = cu([5., 6, 7, 8])
a = cu([9., 10, 11, 12])
z = call_example_kernel1(x, y)

# Calculation without gradients:
# call_all(x, y, a)

Jx, Jy = Zygote.jacobian(call_example_kernel1, x, y)
@show Jx;
@show Jy;

Jz, Ja = Zygote.jacobian(call_example_kernel2, z, a)
@show Jz;
@show Ja;

Jx, Jy, Ja = Zygote.jacobian(call_all, x, y, a)
@show Jx;
@show Jy;
@show Ja;
```

This seems rather inefficient to me in terms of repeated calculations. Is there a more efficient way to do it?
Not currently. Enzyme expects the gradient values to be passed in, while ChainRules expects them as the output, so there needs to be some mapping from the Enzyme convention to the ChainRules convention.
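For reference, here is a minimal CPU-only sketch of that mapping. It is not from this thread; it assumes a recent Enzyme.jl where reverse mode is requested with `Reverse`, and the names `f`, `shadow`, and `rrule_style` are illustrative only:

```julia
using Enzyme

f(x) = sum(abs2, x)  # simple scalar-valued function

# Enzyme convention: the caller allocates a shadow and passes it in;
# Enzyme accumulates the gradient into it in place.
x  = [1.0, 2.0, 3.0]
dx = zero(x)
Enzyme.autodiff(Reverse, f, Active, Duplicated(x, dx))
# dx is now [2.0, 4.0, 6.0]

# ChainRules convention: the pullback *returns* freshly built tangents,
# so the shadow allocation has to be hidden inside the pullback.
function rrule_style(f, x)
    y = f(x)
    function pullback(ȳ)
        shadow = zero(x)
        Enzyme.autodiff(Reverse, f, Active, Duplicated(x, shadow))
        return ȳ .* shadow  # tangent handed back as a return value
    end
    return y, pullback
end
```

This is exactly the shape of the rrules in the example above: allocate shadows, run the differentiated kernel, then return the shadows as tangents.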
Ok, thanks. Then I'm closing it for now.
This would also be a great addition to the Enzyme documentation: generally, how to use Enzyme to write a ChainRule. Could you add your code above to the Examples section?
Hi,

I'm new to Julia and the auto-diff world. My goal is to write differentiable GPU kernels for which I don't have an analytical solution. Writing the kernels and differentiating them works in Enzyme.jl. Now I'm trying to compose two kernels. Additionally, I would like to implement a loss function that converts the kernel output, which is a vector, to a scalar. I guess that if I can manage to compose two kernels, I can interface them with any AD library compatible with ChainRules. I've started to write an example, but I'm getting type errors for CuArrays.

Here is the error message:

I appreciate every opinion, tip, or other constructive feedback!
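For the scalar loss mentioned here, a hypothetical sketch could look as follows, assuming the rrules and the example inputs `x`, `y`, `a` defined above; the name `loss` and the choice of `sum(abs2, ...)` are illustrative, not from the thread:

```julia
# With the rrules in place, a scalar loss over the composed kernels
# can be differentiated end-to-end with Zygote.
loss(x, y, a) = sum(abs2, call_all(x, y, a))

∇x, ∇y, ∇a = Zygote.gradient(loss, x, y, a)
```

Because the loss is a scalar, `Zygote.gradient` calls each pullback once, avoiding the repeated pullback calls that `Zygote.jacobian` needs for vector-valued outputs.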