Higher order differentiation, autodiff_deferred and CustomRules #1059

Closed
Crown421 opened this issue Sep 17, 2023 · 2 comments

Comments

@Crown421

I have been trying to write a custom rule that also works under higher-order (nested) differentiation, but I cannot get it to work. I'm happy to help track the issue down further, but I don't know where to start.

The (relatively) minimal code is below. The custom rule does not return the correct differential; it is only meant to be simple enough to work with.

using Enzyme
import .EnzymeRules: forward
Enzyme.API.printall!(true)

fun(x, y) = (x - y)^2

function forward(func::Const{typeof(fun)}, o, x, y::Const)
    println("Custom Rule")
    return Duplicated(x.val + y.val, 1.0)
end

x = 1.
y = 1.1

df(y) = autodiff(Forward, fun, Duplicated, Duplicated(x, 1.0), y)
df(y)

only(autodiff(
    Forward,
    yt -> autodiff_deferred(Forward, fun, Duplicated, Duplicated(x, 1.0), yt),
    Duplicated,
    Duplicated(y, 1.0)))
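For reference, these are the quantities the two calls target if fun had a correct rule (the rule above is a placeholder that returns the primal x + y and the constant shadow 1.0). df seeds x with tangent 1.0 and treats y as constant, so it computes the partial derivative with respect to x. The nested call then differentiates that result with respect to y. A correct forward rule for this activity pattern (x Duplicated, y Const) would return the primal f(x, y) together with the shadow on the second line below.

```math
\begin{aligned}
f(x, y) &= (x - y)^2 \\
\dot f &= \frac{\partial f}{\partial x}\,\dot x = 2(x - y)\,\dot x,
  \qquad \left.\frac{\partial f}{\partial x}\right|_{x=1,\;y=1.1} = 2(1 - 1.1) = -0.2 \\
\frac{\partial}{\partial y}\,\frac{\partial f}{\partial x} &= \frac{\partial}{\partial y}\,2(x - y) = -2
\end{aligned}
```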

The nested (second-order) call at the end fails with the following error:

ERROR: KeyError: key "julia_fun_4703" not found
Stacktrace:
  [1] getindex
    @ ~/.julia/packages/LLVM/Od0DH/src/core/module.jl:245 [inlined]
  [2] codegen(output::Symbol, job::GPUCompiler.CompilerJob{Enzyme.Compiler.EnzymeTarget, Enzyme.Compiler.EnzymeCompilerParams}; libraries::Bool, deferred_codegen::Bool, optimize::Bool, toplevel::Bool, strip::Bool, validate::Bool, only_entry::Bool, parent_job::Nothing)
    @ Enzyme.Compiler ~/.julia/packages/Enzyme/0SYwj/src/compiler.jl:8874
  [3] codegen
    @ ~/.julia/packages/Enzyme/0SYwj/src/compiler.jl:8723 [inlined]
  [4] _thunk(job::GPUCompiler.CompilerJob{Enzyme.Compiler.EnzymeTarget, Enzyme.Compiler.EnzymeCompilerParams}, postopt::Bool)
    @ Enzyme.Compiler ~/.julia/packages/Enzyme/0SYwj/src/compiler.jl:9671
  [5] _thunk
    @ ~/.julia/packages/Enzyme/0SYwj/src/compiler.jl:9671 [inlined]
  [6] cached_compilation
    @ ~/.julia/packages/Enzyme/0SYwj/src/compiler.jl:9705 [inlined]
  [7] (::Enzyme.Compiler.var"#475#476"{DataType, UnionAll, DataType, Enzyme.API.CDerivativeMode, Tuple{Bool, Bool}, Int64, Bool, Bool, UInt64, DataType})(ctx::LLVM.Context)
    @ Enzyme.Compiler ~/.julia/packages/Enzyme/0SYwj/src/compiler.jl:9768
  [8] JuliaContext(f::Enzyme.Compiler.var"#475#476"{DataType, UnionAll, DataType, Enzyme.API.CDerivativeMode, Tuple{Bool, Bool}, Int64, Bool, Bool, UInt64, DataType})
    @ GPUCompiler ~/.julia/packages/GPUCompiler/YO8Uj/src/driver.jl:47
  [9] #s292#474
    @ ~/.julia/packages/Enzyme/0SYwj/src/compiler.jl:9723 [inlined]
 [10] var"#s292#474"(FA::Any, A::Any, TT::Any, Mode::Any, ModifiedBetween::Any, width::Any, ReturnPrimal::Any, ShadowInit::Any, World::Any, ABI::Any, ::Any, #unused#::Type, #unused#::Type, #unused#::Type, tt::Any, #unused#::Type, #unused#::Type, #unused#::Type, #unused#::Type, #unused#::Type, #unused#::Any)
    @ Enzyme.Compiler ./none:0
 [11] (::Core.GeneratedFunctionStub)(::Any, ::Vararg{Any})
    @ Core ./boot.jl:602
 [12] autodiff(#unused#::ForwardMode{FFIABI}, f::Const{var"#25#26"}, #unused#::Type{Duplicated}, args::Duplicated{Float64})
    @ Enzyme ~/.julia/packages/Enzyme/0SYwj/src/Enzyme.jl:328
 [13] autodiff(::ForwardMode{FFIABI}, ::var"#25#26", ::Type, ::Duplicated{Float64})
    @ Enzyme ~/.julia/packages/Enzyme/0SYwj/src/Enzyme.jl:222

and prints the following Enzyme trace:

after simplification :
; Function Attrs: mustprogress nofree nosync readnone willreturn
define double @preprocess_julia_fun_4671mustwrap_inner.1(double %0, double %1) local_unnamed_addr #3 {
entry:
  %2 = call double @julia_fun_4671(double %0, double %1) #4
  ret double %2
}

; Function Attrs: mustprogress nofree nosync readnone willreturn
define internal { double, double } @fwddiffejulia_fun_4671mustwrap_inner.1(double %0, double %"'", double %1) local_unnamed_addr #3 {
entry:
  %2 = alloca [2 x double], align 8
  %3 = call {}*** @julia.get_pgcstack()
  %4 = call {}*** @julia.get_pgcstack()
  %5 = bitcast {}*** %4 to {}**
  %6 = getelementptr inbounds {}*, {}** %5, i64 -13
  %7 = getelementptr inbounds {}*, {}** %6, i64 15
  %8 = bitcast {}** %7 to i8**
  %9 = load i8*, i8** %8, align 8
  %10 = call {} addrspace(10)* @julia.gc_alloc_obj({}** %6, i64 16, {} addrspace(10)* addrspacecast ({}* inttoptr (i64 140675504249552 to {}*) to {} addrspace(10)*))
  %11 = bitcast {} addrspace(10)* %10 to [2 x double] addrspace(10)*
  %12 = addrspacecast [2 x double] addrspace(10)* %11 to [2 x double] addrspace(11)*
  %13 = getelementptr inbounds [2 x double], [2 x double] addrspace(11)* %12, i64 0, i32 0
  store double %0, double addrspace(11)* %13, align 8
  %14 = getelementptr inbounds [2 x double], [2 x double] addrspace(11)* %12, i64 0, i32 1
  store double %"'", double addrspace(11)* %14, align 8
  %15 = bitcast {}*** %3 to {}**
  %16 = getelementptr inbounds {}*, {}** %15, i64 -13
  %17 = getelementptr inbounds {}*, {}** %16, i64 15
  %18 = bitcast {}** %17 to i8**
  %19 = load i8*, i8** %18, align 8
  %20 = call {} addrspace(10)* @julia.gc_alloc_obj({}** %16, i64 8, {} addrspace(10)* addrspacecast ({}* inttoptr (i64 140675500583760 to {}*) to {} addrspace(10)*))
  %21 = bitcast {} addrspace(10)* %20 to [1 x double] addrspace(10)*
  %22 = addrspacecast [1 x double] addrspace(10)* %21 to [1 x double] addrspace(11)*
  %23 = getelementptr inbounds [1 x double], [1 x double] addrspace(11)* %22, i64 0, i32 0
  store double %1, double addrspace(11)* %23, align 8
  call void @llvm.experimental.noalias.scope.decl(metadata !19)
  %24 = call {}*** @julia.get_pgcstack()
  call fastcc void @julia_println_4676() #8, !dbg !22, !noalias !19
  %25 = getelementptr inbounds [2 x double], [2 x double] addrspace(11)* %12, i64 0, i64 0, !dbg !24
  %26 = getelementptr inbounds [1 x double], [1 x double] addrspace(11)* %22, i64 0, i64 0, !dbg !24
  %27 = load double, double addrspace(11)* %25, align 8, !dbg !28, !tbaa !30, !alias.scope !34, !noalias !37
  %28 = load double, double addrspace(11)* %26, align 8, !dbg !28, !tbaa !30, !alias.scope !34, !noalias !37
  %29 = fadd double %27, %28, !dbg !28
  %.sroa.0.0..sroa_idx.i = getelementptr inbounds [2 x double], [2 x double]* %2, i64 0, i64 0, !dbg !27
  store double %29, double* %.sroa.0.0..sroa_idx.i, align 8, !dbg !27, !alias.scope !19, !noalias !42
  %.sroa.2.0..sroa_idx1.i = getelementptr inbounds [2 x double], [2 x double]* %2, i64 0, i64 1, !dbg !27
  store double 1.000000e+00, double* %.sroa.2.0..sroa_idx1.i, align 8, !dbg !27, !alias.scope !19, !noalias !42
  %30 = load [2 x double], [2 x double]* %2, align 8
  %31 = extractvalue [2 x double] %30, 0
  %32 = extractvalue [2 x double] %30, 1
  %33 = insertvalue { double, double } undef, double %31, 0
  %34 = insertvalue { double, double } %33, double %32, 1
  ret { double, double } %34
}

after simplification :
; Function Attrs: mustprogress nofree nosync readnone willreturn
define double @preprocess_julia_fun_4681mustwrap_inner.1(double %0, double %1) local_unnamed_addr #3 {
entry:
  %2 = call double @julia_fun_4681(double %0, double %1) #4
  ret double %2
}

; Function Attrs: mustprogress nofree nosync readnone willreturn
define internal { double, double } @fwddiffejulia_fun_4681mustwrap_inner.1(double %0, double %"'", double %1) local_unnamed_addr #3 {
entry:
  %2 = alloca [2 x double], align 8
  %3 = call {}*** @julia.get_pgcstack()
  %4 = call {}*** @julia.get_pgcstack()
  %5 = bitcast {}*** %4 to {}**
  %6 = getelementptr inbounds {}*, {}** %5, i64 -13
  %7 = getelementptr inbounds {}*, {}** %6, i64 15
  %8 = bitcast {}** %7 to i8**
  %9 = load i8*, i8** %8, align 8
  %10 = call {} addrspace(10)* @julia.gc_alloc_obj({}** %6, i64 16, {} addrspace(10)* addrspacecast ({}* inttoptr (i64 140675504249552 to {}*) to {} addrspace(10)*))
  %11 = bitcast {} addrspace(10)* %10 to [2 x double] addrspace(10)*
  %12 = addrspacecast [2 x double] addrspace(10)* %11 to [2 x double] addrspace(11)*
  %13 = getelementptr inbounds [2 x double], [2 x double] addrspace(11)* %12, i64 0, i32 0
  store double %0, double addrspace(11)* %13, align 8
  %14 = getelementptr inbounds [2 x double], [2 x double] addrspace(11)* %12, i64 0, i32 1
  store double %"'", double addrspace(11)* %14, align 8
  %15 = bitcast {}*** %3 to {}**
  %16 = getelementptr inbounds {}*, {}** %15, i64 -13
  %17 = getelementptr inbounds {}*, {}** %16, i64 15
  %18 = bitcast {}** %17 to i8**
  %19 = load i8*, i8** %18, align 8
  %20 = call {} addrspace(10)* @julia.gc_alloc_obj({}** %16, i64 8, {} addrspace(10)* addrspacecast ({}* inttoptr (i64 140675500583760 to {}*) to {} addrspace(10)*))
  %21 = bitcast {} addrspace(10)* %20 to [1 x double] addrspace(10)*
  %22 = addrspacecast [1 x double] addrspace(10)* %21 to [1 x double] addrspace(11)*
  %23 = getelementptr inbounds [1 x double], [1 x double] addrspace(11)* %22, i64 0, i32 0
  store double %1, double addrspace(11)* %23, align 8
  call void @llvm.experimental.noalias.scope.decl(metadata !19)
  %24 = call {}*** @julia.get_pgcstack()
  call fastcc void @julia_println_4686() #8, !dbg !22, !noalias !19
  %25 = getelementptr inbounds [2 x double], [2 x double] addrspace(11)* %12, i64 0, i64 0, !dbg !24
  %26 = getelementptr inbounds [1 x double], [1 x double] addrspace(11)* %22, i64 0, i64 0, !dbg !24
  %27 = load double, double addrspace(11)* %25, align 8, !dbg !28, !tbaa !30, !alias.scope !34, !noalias !37
  %28 = load double, double addrspace(11)* %26, align 8, !dbg !28, !tbaa !30, !alias.scope !34, !noalias !37
  %29 = fadd double %27, %28, !dbg !28
  %.sroa.0.0..sroa_idx.i = getelementptr inbounds [2 x double], [2 x double]* %2, i64 0, i64 0, !dbg !27
  store double %29, double* %.sroa.0.0..sroa_idx.i, align 8, !dbg !27, !alias.scope !19, !noalias !42
  %.sroa.2.0..sroa_idx1.i = getelementptr inbounds [2 x double], [2 x double]* %2, i64 0, i64 1, !dbg !27
  store double 1.000000e+00, double* %.sroa.2.0..sroa_idx1.i, align 8, !dbg !27, !alias.scope !19, !noalias !42
  %30 = load [2 x double], [2 x double]* %2, align 8
  %31 = extractvalue [2 x double] %30, 0
  %32 = extractvalue [2 x double] %30, 1
  %33 = insertvalue { double, double } undef, double %31, 0
  %34 = insertvalue { double, double } %33, double %32, 1
  ret { double, double } %34
}

after simplification :
; Function Attrs: mustprogress nofree nosync readnone willreturn
define double @preprocess_julia_fun_4693mustwrap_inner.1(double %0, double %1) local_unnamed_addr #3 {
entry:
  %2 = call double @julia_fun_4693(double %0, double %1) #4
  ret double %2
}

; Function Attrs: mustprogress nofree nosync readnone willreturn
define internal { double, double } @fwddiffejulia_fun_4693mustwrap_inner.1(double %0, double %"'", double %1) local_unnamed_addr #3 {
entry:
  %2 = alloca [2 x double], align 8
  %3 = call {}*** @julia.get_pgcstack()
  %4 = call {}*** @julia.get_pgcstack()
  %5 = bitcast {}*** %4 to {}**
  %6 = getelementptr inbounds {}*, {}** %5, i64 -13
  %7 = getelementptr inbounds {}*, {}** %6, i64 15
  %8 = bitcast {}** %7 to i8**
  %9 = load i8*, i8** %8, align 8
  %10 = call {} addrspace(10)* @julia.gc_alloc_obj({}** %6, i64 16, {} addrspace(10)* addrspacecast ({}* inttoptr (i64 140675504249552 to {}*) to {} addrspace(10)*))
  %11 = bitcast {} addrspace(10)* %10 to [2 x double] addrspace(10)*
  %12 = addrspacecast [2 x double] addrspace(10)* %11 to [2 x double] addrspace(11)*
  %13 = getelementptr inbounds [2 x double], [2 x double] addrspace(11)* %12, i64 0, i32 0
  store double %0, double addrspace(11)* %13, align 8
  %14 = getelementptr inbounds [2 x double], [2 x double] addrspace(11)* %12, i64 0, i32 1
  store double %"'", double addrspace(11)* %14, align 8
  %15 = bitcast {}*** %3 to {}**
  %16 = getelementptr inbounds {}*, {}** %15, i64 -13
  %17 = getelementptr inbounds {}*, {}** %16, i64 15
  %18 = bitcast {}** %17 to i8**
  %19 = load i8*, i8** %18, align 8
  %20 = call {} addrspace(10)* @julia.gc_alloc_obj({}** %16, i64 8, {} addrspace(10)* addrspacecast ({}* inttoptr (i64 140675500583760 to {}*) to {} addrspace(10)*))
  %21 = bitcast {} addrspace(10)* %20 to [1 x double] addrspace(10)*
  %22 = addrspacecast [1 x double] addrspace(10)* %21 to [1 x double] addrspace(11)*
  %23 = getelementptr inbounds [1 x double], [1 x double] addrspace(11)* %22, i64 0, i32 0
  store double %1, double addrspace(11)* %23, align 8
  call void @llvm.experimental.noalias.scope.decl(metadata !19)
  %24 = call {}*** @julia.get_pgcstack()
  call fastcc void @julia_println_4698() #8, !dbg !22, !noalias !19
  %25 = getelementptr inbounds [2 x double], [2 x double] addrspace(11)* %12, i64 0, i64 0, !dbg !24
  %26 = getelementptr inbounds [1 x double], [1 x double] addrspace(11)* %22, i64 0, i64 0, !dbg !24
  %27 = load double, double addrspace(11)* %25, align 8, !dbg !28, !tbaa !30, !alias.scope !34, !noalias !37
  %28 = load double, double addrspace(11)* %26, align 8, !dbg !28, !tbaa !30, !alias.scope !34, !noalias !37
  %29 = fadd double %27, %28, !dbg !28
  %.sroa.0.0..sroa_idx.i = getelementptr inbounds [2 x double], [2 x double]* %2, i64 0, i64 0, !dbg !27
  store double %29, double* %.sroa.0.0..sroa_idx.i, align 8, !dbg !27, !alias.scope !19, !noalias !42
  %.sroa.2.0..sroa_idx1.i = getelementptr inbounds [2 x double], [2 x double]* %2, i64 0, i64 1, !dbg !27
  store double 1.000000e+00, double* %.sroa.2.0..sroa_idx1.i, align 8, !dbg !27, !alias.scope !19, !noalias !42
  %30 = load [2 x double], [2 x double]* %2, align 8
  %31 = extractvalue [2 x double] %30, 0
  %32 = extractvalue [2 x double] %30, 1
  %33 = insertvalue { double, double } undef, double %31, 0
  %34 = insertvalue { double, double } %33, double %32, 1
  ret { double, double } %34
}

after simplification :
; Function Attrs: mustprogress nofree nosync readnone willreturn
define double @preprocess_julia_fun_4703mustwrap_inner.1(double %0, double %1) local_unnamed_addr #3 {
entry:
  %2 = call double @julia_fun_4703(double %0, double %1) #4
  ret double %2
}

; Function Attrs: mustprogress nofree nosync readnone willreturn
define internal { double, double } @fwddiffejulia_fun_4703mustwrap_inner.1(double %0, double %"'", double %1) local_unnamed_addr #3 {
entry:
  %2 = alloca [2 x double], align 8
  %3 = call {}*** @julia.get_pgcstack()
  %4 = call {}*** @julia.get_pgcstack()
  %5 = bitcast {}*** %4 to {}**
  %6 = getelementptr inbounds {}*, {}** %5, i64 -13
  %7 = getelementptr inbounds {}*, {}** %6, i64 15
  %8 = bitcast {}** %7 to i8**
  %9 = load i8*, i8** %8, align 8
  %10 = call {} addrspace(10)* @julia.gc_alloc_obj({}** %6, i64 16, {} addrspace(10)* addrspacecast ({}* inttoptr (i64 140675504249552 to {}*) to {} addrspace(10)*))
  %11 = bitcast {} addrspace(10)* %10 to [2 x double] addrspace(10)*
  %12 = addrspacecast [2 x double] addrspace(10)* %11 to [2 x double] addrspace(11)*
  %13 = getelementptr inbounds [2 x double], [2 x double] addrspace(11)* %12, i64 0, i32 0
  store double %0, double addrspace(11)* %13, align 8
  %14 = getelementptr inbounds [2 x double], [2 x double] addrspace(11)* %12, i64 0, i32 1
  store double %"'", double addrspace(11)* %14, align 8
  %15 = bitcast {}*** %3 to {}**
  %16 = getelementptr inbounds {}*, {}** %15, i64 -13
  %17 = getelementptr inbounds {}*, {}** %16, i64 15
  %18 = bitcast {}** %17 to i8**
  %19 = load i8*, i8** %18, align 8
  %20 = call {} addrspace(10)* @julia.gc_alloc_obj({}** %16, i64 8, {} addrspace(10)* addrspacecast ({}* inttoptr (i64 140675500583760 to {}*) to {} addrspace(10)*))
  %21 = bitcast {} addrspace(10)* %20 to [1 x double] addrspace(10)*
  %22 = addrspacecast [1 x double] addrspace(10)* %21 to [1 x double] addrspace(11)*
  %23 = getelementptr inbounds [1 x double], [1 x double] addrspace(11)* %22, i64 0, i32 0
  store double %1, double addrspace(11)* %23, align 8
  call void @llvm.experimental.noalias.scope.decl(metadata !19)
  %24 = call {}*** @julia.get_pgcstack()
  call fastcc void @julia_println_4708() #8, !dbg !22, !noalias !19
  %25 = getelementptr inbounds [2 x double], [2 x double] addrspace(11)* %12, i64 0, i64 0, !dbg !24
  %26 = getelementptr inbounds [1 x double], [1 x double] addrspace(11)* %22, i64 0, i64 0, !dbg !24
  %27 = load double, double addrspace(11)* %25, align 8, !dbg !28, !tbaa !30, !alias.scope !34, !noalias !37
  %28 = load double, double addrspace(11)* %26, align 8, !dbg !28, !tbaa !30, !alias.scope !34, !noalias !37
  %29 = fadd double %27, %28, !dbg !28
  %.sroa.0.0..sroa_idx.i = getelementptr inbounds [2 x double], [2 x double]* %2, i64 0, i64 0, !dbg !27
  store double %29, double* %.sroa.0.0..sroa_idx.i, align 8, !dbg !27, !alias.scope !19, !noalias !42
  %.sroa.2.0..sroa_idx1.i = getelementptr inbounds [2 x double], [2 x double]* %2, i64 0, i64 1, !dbg !27
  store double 1.000000e+00, double* %.sroa.2.0..sroa_idx1.i, align 8, !dbg !27, !alias.scope !19, !noalias !42
  %30 = load [2 x double], [2 x double]* %2, align 8
  %31 = extractvalue [2 x double] %30, 0
  %32 = extractvalue [2 x double] %30, 1
  %33 = insertvalue { double, double } undef, double %31, 0
  %34 = insertvalue { double, double } %33, double %32, 1
  ret { double, double } %34
}
@wsmoses
Member

wsmoses commented Sep 18, 2023

So at the moment, I don't think custom rules are guaranteed to apply at higher orders of differentiation, which is something we need to fix.

@wsmoses
Member

wsmoses commented Sep 21, 2024

This particular issue is fixed by #1877

@wsmoses wsmoses closed this as completed Sep 21, 2024