Skip to content

Commit

Permalink
1.11: more methodinstance stuff (#1989)
Browse files Browse the repository at this point in the history
* 1.11: more methodinstance stuff

* fixup

* fix

* fix elsz issue

* fix

* fix

* fix
  • Loading branch information
wsmoses authored Oct 21, 2024
1 parent 72763e9 commit 924a271
Show file tree
Hide file tree
Showing 6 changed files with 74 additions and 90 deletions.
2 changes: 1 addition & 1 deletion Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@ BFloat16s = "0.2, 0.3, 0.4, 0.5"
CEnum = "0.4, 0.5"
ChainRulesCore = "1"
EnzymeCore = "0.8.4"
Enzyme_jll = "0.0.155"
Enzyme_jll = "0.0.156"
GPUCompiler = "0.21, 0.22, 0.23, 0.24, 0.25, 0.26, 0.27, 1"
LLVM = "6.1, 7, 8, 9"
LogExpFunctions = "0.3"
Expand Down
24 changes: 0 additions & 24 deletions src/compiler.jl
Original file line number Diff line number Diff line change
Expand Up @@ -3816,30 +3816,6 @@ function enzyme!(
LLVM.API.LLVMValueRef,
)
),
"julia.gc_loaded" => @cfunction(
inoutgcloaded_rule,
UInt8,
(
Cint,
API.CTypeTreeRef,
Ptr{API.CTypeTreeRef},
Ptr{API.IntList},
Csize_t,
LLVM.API.LLVMValueRef,
)
),
"julia.pointer_from_objref" => @cfunction(
inout_rule,
UInt8,
(
Cint,
API.CTypeTreeRef,
Ptr{API.CTypeTreeRef},
Ptr{API.IntList},
Csize_t,
LLVM.API.LLVMValueRef,
)
),
"jl_inactive_inout" => @cfunction(
inout_rule,
UInt8,
Expand Down
19 changes: 9 additions & 10 deletions src/rules/llvmrules.jl
Original file line number Diff line number Diff line change
Expand Up @@ -628,13 +628,11 @@ function arraycopy_common(fwd, B, orig, shadowsrc, gutils, shadowdst; len=nothin

if memory
if fwd
shadowsrc = inttoptr!(B, memoryptr, LLVM.PointerType(LLVM.IntType(8)))
lookup_src = false
shadowsrc = invert_pointer(gutils, memoryptr, B)
else
shadowsrc = invert_pointer(gutils, shadowsrc, B)
if !fwd
shadowsrc = lookup_value(gutils, shadowsrc, B)
end
shadowsrc = invert_pointer(gutils, shadowsrc, B)
shadowsrc = lookup_value(gutils, shadowsrc, B)
end
else
shadowsrc = invert_pointer(gutils, shadowsrc, B)
Expand Down Expand Up @@ -674,12 +672,13 @@ function arraycopy_common(fwd, B, orig, shadowsrc, gutils, shadowdst; len=nothin
# src already has done the lookup from the argument
shadowsrc0 = if lookup_src
if memory
# TODO this may not be at the same offset as the start of the copy, e.g. get_memory_data(src) != memoryptr
get_memory_data(B, evsrc)
else
get_array_data(B, evsrc)
end
else
evsrc
inttoptr!(B, evsrc, LLVM.PointerType(LLVM.IntType(8)))
end

shadowdst0 = if memory
Expand Down Expand Up @@ -781,7 +780,7 @@ end
false,
) #=lookup=#
if is_constant_value(gutils, origops[1])
elSize = get_array_elsz(B, ev)
elSize = get_memory_elsz(B, ev)
elSize = LLVM.zext!(B, elSize, LLVM.IntType(8 * sizeof(Csize_t)))
length = LLVM.mul!(B, len, elSize)
bt = GPUCompiler.backtrace(orig)
Expand All @@ -792,7 +791,7 @@ end
GPUCompiler.@safe_warn "TODO forward zero-set of memorycopy used memset rather than runtime type $btstr"
LLVM.memset!(
B,
ev2,
inttoptr!(B, ev2, LLVM.PointerType(LLVM.IntType(8))),
LLVM.ConstantInt(i8, 0, false),
length,
algn,
Expand Down Expand Up @@ -838,7 +837,7 @@ end
shadowres = LLVM.Value(unsafe_load(shadowR))

len = new_from_original(gutils, origops[3])
memoryptr = new_from_original(gutils, origops[2])
memoryptr = origops[2]
arraycopy_common(true, B, orig, origops[1], gutils, shadowres; len, memoryptr)
end

Expand All @@ -849,7 +848,7 @@ end
origops = LLVM.operands(orig)
if !is_constant_value(gutils, origops[1]) && !is_constant_value(gutils, orig)
len = new_from_original(gutils, origops[3])
memoryptr = new_from_original(gutils, origops[2])
memoryptr = origops[2]
arraycopy_common(false, B, orig, origops[1], gutils, nothing; len, memoryptr)
end

Expand Down
40 changes: 0 additions & 40 deletions src/rules/typerules.jl
Original file line number Diff line number Diff line change
Expand Up @@ -92,43 +92,3 @@ function inoutcopyslice_rule(
end
return UInt8(false)
end

function inoutgcloaded_rule(
direction::Cint,
ret::API.CTypeTreeRef,
args::Ptr{API.CTypeTreeRef},
known_values::Ptr{API.IntList},
numArgs::Csize_t,
val::LLVM.API.LLVMValueRef,
)::UInt8
if numArgs != 1
return UInt8(false)
end
inst = LLVM.Instruction(val)

legal, typ = abs_typeof(inst)

if legal
if (direction & API.DOWN) != 0
ctx = LLVM.context(inst)
dl = string(LLVM.datalayout(LLVM.parent(LLVM.parent(LLVM.parent(inst)))))
if GPUCompiler.deserves_retbox(typ)
typ = Ptr{typ}
end
rest = typetree(typ, ctx, dl)
changed, legal = API.EnzymeCheckedMergeTypeTree(ret, rest)
@assert legal
end
return UInt8(false)
end

if (direction & API.UP) != 0
changed, legal = API.EnzymeCheckedMergeTypeTree(unsafe_load(args, 2), ret)
@assert legal
end
if (direction & API.DOWN) != 0
changed, legal = API.EnzymeCheckedMergeTypeTree(ret, unsafe_load(args, 2))
@assert legal
end
return UInt8(false)
end
75 changes: 62 additions & 13 deletions src/utils.jl
Original file line number Diff line number Diff line change
Expand Up @@ -253,24 +253,73 @@ export codegen_world_age

if VERSION >= v"1.11.0-DEV.1552"


const prevmethodinstance = GPUCompiler.generic_methodinstance

function methodinstance_generator(world::UInt, source, self, ft::Type, tt::Type)
@nospecialize
@assert Core.Compiler.isType(ft) && Core.Compiler.isType(tt)
ft = ft.parameters[1]
tt = tt.parameters[1]

stub = Core.GeneratedFunctionStub(identity, Core.svec(:methodinstance, :ft, :tt), Core.svec())

# look up the method match
method_error = :(throw(MethodError(ft, tt, $world)))
sig = Tuple{ft, tt.parameters...}
min_world = Ref{UInt}(typemin(UInt))
max_world = Ref{UInt}(typemax(UInt))
match = ccall(:jl_gf_invoke_lookup_worlds, Any,
(Any, Any, Csize_t, Ref{Csize_t}, Ref{Csize_t}),
sig, #=mt=# nothing, world, min_world, max_world)
match === nothing && return stub(world, source, method_error)

# look up the method and code instance
mi = ccall(:jl_specializations_get_linfo, Ref{MethodInstance},
(Any, Any, Any), match.method, match.spec_types, match.sparams)
ci = Core.Compiler.retrieve_code_info(mi, world)

# prepare a new code info
new_ci = copy(ci)
empty!(new_ci.code)
empty!(new_ci.codelocs)
empty!(new_ci.linetable)
empty!(new_ci.ssaflags)
new_ci.ssavaluetypes = 0

# propagate edge metadata
new_ci.min_world = min_world[]
new_ci.max_world = max_world[]
new_ci.edges = MethodInstance[mi]

# prepare the slots
new_ci.slotnames = Symbol[Symbol("#self#"), :ft, :tt]
new_ci.slotflags = UInt8[0x00 for i = 1:3]

# return the method instance
push!(new_ci.code, Core.Compiler.ReturnNode(mi))
push!(new_ci.ssaflags, 0x00)
push!(new_ci.linetable, GPUCompiler.@LineInfoNode(methodinstance))
push!(new_ci.codelocs, 1)
new_ci.ssavaluetypes += 1

return new_ci
end

@eval function prevmethodinstance(ft, tt)
$(Expr(:meta, :generated_only))
$(Expr(:meta, :generated, methodinstance_generator))
end

# XXX: version of Base.method_instance that uses a function type
@inline function my_methodinstance(@nospecialize(ft::Type), @nospecialize(tt::Type),
world::Integer=tls_world_age())
sig = GPUCompiler.signature_type_by_tt(ft, tt)
# @assert Base.isdispatchtuple(sig) # JuliaLang/julia#52233

mi = ccall(:jl_method_lookup_by_tt, Any,
(Any, Csize_t, Any),
sig, world, #=method_table=# nothing)
mi === nothing && throw(MethodError(ft, tt, world))
mi = mi::MethodInstance

# `jl_method_lookup_by_tt` and `jl_method_lookup` can return a unspecialized mi
if !Base.isdispatchtuple(mi.specTypes)
mi = Core.Compiler.specialize_method(mi.def, sig, mi.sparam_vals)::MethodInstance
if Base.isdispatchtuple(sig) # JuliaLang/julia#52233
return GPUCompiler.methodinstance(ft, tt, world)
else
return prevmethodinstance(ft, tt, world)
end

return mi
end
else
import GPUCompiler: methodinstance as my_methodinstance
Expand Down
4 changes: 2 additions & 2 deletions test/internal_rules.jl
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,6 @@ module InternalRules

using Enzyme
using Enzyme.EnzymeRules
using EnzymeTestUtils
using FiniteDifferences
using LinearAlgebra
using SparseArrays
using Test
Expand Down Expand Up @@ -155,6 +153,7 @@ function tr_solv(A, B, uplo, trans, diag, idx)
end


using FiniteDifferences
@testset "Reverse triangular solve" begin
A = [0.7550523937508613 0.7979976952197996 0.29318222271218364; 0.4416768066117529 0.4335305304334933 0.8895389673238051; 0.07752980210005678 0.05978245503334367 0.4504482683752542]
B = [0.10527381151977078 0.5450388247476627 0.3179106723232359 0.43919576779182357 0.20974326586875847; 0.7551160501548224 0.049772782182839426 0.09284926395551141 0.07862188927391855 0.17346407477062986; 0.6258040138863172 0.5928022963567454 0.24251650865340169 0.6626410383247967 0.32752198021506784]
Expand Down Expand Up @@ -576,6 +575,7 @@ end
@test Enzyme.gradient(Reverse, chol_upper, x)[1] [0.05270807565639728, 0.0, 0.0, 0.0, 0.9999999999999999, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0]
end

using EnzymeTestUtils
@testset "Linear solve for triangular matrices" begin
@testset for T in (UpperTriangular, LowerTriangular, UnitUpperTriangular, UnitLowerTriangular),
TE in (Float64, ComplexF64), sizeB in ((3,), (3, 3))
Expand Down

0 comments on commit 924a271

Please sign in to comment.