diff --git a/Project.toml b/Project.toml index 480409a3f0..b1c9a8ed17 100644 --- a/Project.toml +++ b/Project.toml @@ -36,7 +36,7 @@ BFloat16s = "0.2, 0.3, 0.4, 0.5" CEnum = "0.4, 0.5" ChainRulesCore = "1" EnzymeCore = "0.8.4" -Enzyme_jll = "0.0.155" +Enzyme_jll = "0.0.156" GPUCompiler = "0.21, 0.22, 0.23, 0.24, 0.25, 0.26, 0.27, 1" LLVM = "6.1, 7, 8, 9" LogExpFunctions = "0.3" diff --git a/src/compiler.jl b/src/compiler.jl index cf879000b5..36ba2c8656 100644 --- a/src/compiler.jl +++ b/src/compiler.jl @@ -3816,30 +3816,6 @@ function enzyme!( LLVM.API.LLVMValueRef, ) ), - "julia.gc_loaded" => @cfunction( - inoutgcloaded_rule, - UInt8, - ( - Cint, - API.CTypeTreeRef, - Ptr{API.CTypeTreeRef}, - Ptr{API.IntList}, - Csize_t, - LLVM.API.LLVMValueRef, - ) - ), - "julia.pointer_from_objref" => @cfunction( - inout_rule, - UInt8, - ( - Cint, - API.CTypeTreeRef, - Ptr{API.CTypeTreeRef}, - Ptr{API.IntList}, - Csize_t, - LLVM.API.LLVMValueRef, - ) - ), "jl_inactive_inout" => @cfunction( inout_rule, UInt8, diff --git a/src/rules/llvmrules.jl b/src/rules/llvmrules.jl index edb2bcd6e6..667d6f8abb 100644 --- a/src/rules/llvmrules.jl +++ b/src/rules/llvmrules.jl @@ -628,13 +628,11 @@ function arraycopy_common(fwd, B, orig, shadowsrc, gutils, shadowdst; len=nothin if memory if fwd - shadowsrc = inttoptr!(B, memoryptr, LLVM.PointerType(LLVM.IntType(8))) lookup_src = false + shadowsrc = invert_pointer(gutils, memoryptr, B) else - shadowsrc = invert_pointer(gutils, shadowsrc, B) - if !fwd - shadowsrc = lookup_value(gutils, shadowsrc, B) - end + shadowsrc = invert_pointer(gutils, shadowsrc, B) + shadowsrc = lookup_value(gutils, shadowsrc, B) end else shadowsrc = invert_pointer(gutils, shadowsrc, B) @@ -674,12 +672,13 @@ function arraycopy_common(fwd, B, orig, shadowsrc, gutils, shadowdst; len=nothin # src already has done the lookup from the argument shadowsrc0 = if lookup_src if memory + # TODO this may not be at the same offset as the start of the copy, e.g. get_memory_data(src) != memoryptr get_memory_data(B, evsrc) else get_array_data(B, evsrc) end else - evsrc + inttoptr!(B, evsrc, LLVM.PointerType(LLVM.IntType(8))) end shadowdst0 = if memory @@ -781,7 +780,7 @@ end false, ) #=lookup=# if is_constant_value(gutils, origops[1]) - elSize = get_array_elsz(B, ev) + elSize = get_memory_elsz(B, ev) elSize = LLVM.zext!(B, elSize, LLVM.IntType(8 * sizeof(Csize_t))) length = LLVM.mul!(B, len, elSize) bt = GPUCompiler.backtrace(orig) @@ -792,7 +791,7 @@ end GPUCompiler.@safe_warn "TODO forward zero-set of memorycopy used memset rather than runtime type $btstr" LLVM.memset!( B, - ev2, + inttoptr!(B, ev2, LLVM.PointerType(LLVM.IntType(8))), LLVM.ConstantInt(i8, 0, false), length, algn, @@ -838,7 +837,7 @@ end shadowres = LLVM.Value(unsafe_load(shadowR)) len = new_from_original(gutils, origops[3]) - memoryptr = new_from_original(gutils, origops[2]) + memoryptr = origops[2] arraycopy_common(true, B, orig, origops[1], gutils, shadowres; len, memoryptr) end @@ -849,7 +848,7 @@ end origops = LLVM.operands(orig) if !is_constant_value(gutils, origops[1]) && !is_constant_value(gutils, orig) len = new_from_original(gutils, origops[3]) - memoryptr = new_from_original(gutils, origops[2]) + memoryptr = origops[2] arraycopy_common(false, B, orig, origops[1], gutils, nothing; len, memoryptr) end diff --git a/src/rules/typerules.jl b/src/rules/typerules.jl index de11d3c1cd..2cba33c14e 100644 --- a/src/rules/typerules.jl +++ b/src/rules/typerules.jl @@ -92,43 +92,3 @@ function inoutcopyslice_rule( end return UInt8(false) end - -function inoutgcloaded_rule( - direction::Cint, - ret::API.CTypeTreeRef, - args::Ptr{API.CTypeTreeRef}, - known_values::Ptr{API.IntList}, - numArgs::Csize_t, - val::LLVM.API.LLVMValueRef, -)::UInt8 - if numArgs != 1 - return UInt8(false) - end - inst = LLVM.Instruction(val) - - legal, typ = abs_typeof(inst) - - if legal - if (direction & API.DOWN) != 0 - ctx = LLVM.context(inst) - dl = string(LLVM.datalayout(LLVM.parent(LLVM.parent(LLVM.parent(inst))))) - if GPUCompiler.deserves_retbox(typ) - typ = Ptr{typ} - end - rest = typetree(typ, ctx, dl) - changed, legal = API.EnzymeCheckedMergeTypeTree(ret, rest) - @assert legal - end - return UInt8(false) - end - - if (direction & API.UP) != 0 - changed, legal = API.EnzymeCheckedMergeTypeTree(unsafe_load(args, 2), ret) - @assert legal - end - if (direction & API.DOWN) != 0 - changed, legal = API.EnzymeCheckedMergeTypeTree(ret, unsafe_load(args, 2)) - @assert legal - end - return UInt8(false) -end \ No newline at end of file diff --git a/src/utils.jl b/src/utils.jl index cd28ae9a20..0b441b1b25 100644 --- a/src/utils.jl +++ b/src/utils.jl @@ -253,24 +253,73 @@ export codegen_world_age if VERSION >= v"1.11.0-DEV.1552" + +const prevmethodinstance = GPUCompiler.generic_methodinstance + +function methodinstance_generator(world::UInt, source, self, ft::Type, tt::Type) + @nospecialize + @assert Core.Compiler.isType(ft) && Core.Compiler.isType(tt) + ft = ft.parameters[1] + tt = tt.parameters[1] + + stub = Core.GeneratedFunctionStub(identity, Core.svec(:methodinstance, :ft, :tt), Core.svec()) + + # look up the method match + method_error = :(throw(MethodError(ft, tt, $world))) + sig = Tuple{ft, tt.parameters...} + min_world = Ref{UInt}(typemin(UInt)) + max_world = Ref{UInt}(typemax(UInt)) + match = ccall(:jl_gf_invoke_lookup_worlds, Any, + (Any, Any, Csize_t, Ref{Csize_t}, Ref{Csize_t}), + sig, #=mt=# nothing, world, min_world, max_world) + match === nothing && return stub(world, source, method_error) + + # look up the method and code instance + mi = ccall(:jl_specializations_get_linfo, Ref{MethodInstance}, + (Any, Any, Any), match.method, match.spec_types, match.sparams) + ci = Core.Compiler.retrieve_code_info(mi, world) + + # prepare a new code info + new_ci = copy(ci) + empty!(new_ci.code) + empty!(new_ci.codelocs) + empty!(new_ci.linetable) + empty!(new_ci.ssaflags) + new_ci.ssavaluetypes = 0 + + # propagate edge metadata + new_ci.min_world = min_world[] + new_ci.max_world = max_world[] + new_ci.edges = MethodInstance[mi] + + # prepare the slots + new_ci.slotnames = Symbol[Symbol("#self#"), :ft, :tt] + new_ci.slotflags = UInt8[0x00 for i = 1:3] + + # return the method instance + push!(new_ci.code, Core.Compiler.ReturnNode(mi)) + push!(new_ci.ssaflags, 0x00) + push!(new_ci.linetable, GPUCompiler.@LineInfoNode(methodinstance)) + push!(new_ci.codelocs, 1) + new_ci.ssavaluetypes += 1 + + return new_ci +end + +@eval function prevmethodinstance(ft, tt) + $(Expr(:meta, :generated_only)) + $(Expr(:meta, :generated, methodinstance_generator)) +end + # XXX: version of Base.method_instance that uses a function type @inline function my_methodinstance(@nospecialize(ft::Type), @nospecialize(tt::Type), world::Integer=tls_world_age()) sig = GPUCompiler.signature_type_by_tt(ft, tt) - # @assert Base.isdispatchtuple(sig) # JuliaLang/julia#52233 - - mi = ccall(:jl_method_lookup_by_tt, Any, - (Any, Csize_t, Any), - sig, world, #=method_table=# nothing) - mi === nothing && throw(MethodError(ft, tt, world)) - mi = mi::MethodInstance - - # `jl_method_lookup_by_tt` and `jl_method_lookup` can return a unspecialized mi - if !Base.isdispatchtuple(mi.specTypes) - mi = Core.Compiler.specialize_method(mi.def, sig, mi.sparam_vals)::MethodInstance + if Base.isdispatchtuple(sig) # JuliaLang/julia#52233 + return GPUCompiler.methodinstance(ft, tt, world) + else + return prevmethodinstance(ft, tt, world) end - - return mi end else import GPUCompiler: methodinstance as my_methodinstance diff --git a/test/internal_rules.jl b/test/internal_rules.jl index 3635ce07e2..fb5926a1d3 100644 --- a/test/internal_rules.jl +++ b/test/internal_rules.jl @@ -2,8 +2,6 @@ module InternalRules using Enzyme using Enzyme.EnzymeRules -using EnzymeTestUtils -using FiniteDifferences using LinearAlgebra using SparseArrays using Test @@ -155,6 +153,7 @@ function tr_solv(A, B, uplo, trans, diag, idx) end +using FiniteDifferences @testset "Reverse triangular solve" begin A = [0.7550523937508613 0.7979976952197996 0.29318222271218364; 0.4416768066117529 0.4335305304334933 0.8895389673238051; 0.07752980210005678 0.05978245503334367 0.4504482683752542] B = [0.10527381151977078 0.5450388247476627 0.3179106723232359 0.43919576779182357 0.20974326586875847; 0.7551160501548224 0.049772782182839426 0.09284926395551141 0.07862188927391855 0.17346407477062986; 0.6258040138863172 0.5928022963567454 0.24251650865340169 0.6626410383247967 0.32752198021506784] @@ -576,6 +575,7 @@ end @test Enzyme.gradient(Reverse, chol_upper, x)[1] ≈ [0.05270807565639728, 0.0, 0.0, 0.0, 0.9999999999999999, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0] end +using EnzymeTestUtils @testset "Linear solve for triangular matrices" begin @testset for T in (UpperTriangular, LowerTriangular, UnitUpperTriangular, UnitLowerTriangular), TE in (Float64, ComplexF64), sizeB in ((3,), (3, 3))