Skip to content

Commit

Permalink
Continuing 1.11 stuff (#1984)
Browse files Browse the repository at this point in the history
* Continuing 1.11 stuff

* cleanup

* fix

* fix

* fix

* fixup

* fixup

* bypass for now

* more info and utter confusion

* more stringent assertions

* correct checks

* s

* better prints

* clean
  • Loading branch information
wsmoses authored Oct 18, 2024
1 parent 6987986 commit da53c03
Show file tree
Hide file tree
Showing 5 changed files with 298 additions and 30 deletions.
2 changes: 1 addition & 1 deletion src/compiler.jl
Original file line number Diff line number Diff line change
Expand Up @@ -2319,7 +2319,7 @@ function shadow_alloc_rewrite(V::LLVM.API.LLVMValueRef, gutils::API.EnzymeGradie
world = enzyme_extract_world(fn)
has, Ty, byref = abs_typeof(V)
if !has
throw(AssertionError("Allocation could not have its type statically determined $(string(V))"))
throw(AssertionError("$(string(fn))\n Allocation could not have its type statically determined $(string(V))"))
end
rt = active_reg_inner(Ty, (), world)
if rt == ActiveState || rt == MixedState
Expand Down
38 changes: 38 additions & 0 deletions src/compiler/interpreter.jl
Original file line number Diff line number Diff line change
Expand Up @@ -287,6 +287,21 @@ struct AutodiffCallInfo <: CallInfo
info::CallInfo
end

@static if VERSION < v"1.11.0-"
else
@inline function myunsafe_copyto!(dest::MemoryRef{T}, src::MemoryRef{T}, n) where {T}
Base.@_terminates_globally_notaskstate_meta
@boundscheck memoryref(dest, n), memoryref(src, n)
t1 = Base.@_gc_preserve_begin dest
t2 = Base.@_gc_preserve_begin src
Base.memmove(pointer(dest), pointer(src), n * Base.aligned_sizeof(T))
Base.@_gc_preserve_end t2
Base.@_gc_preserve_end t1
return dest
end
end


function abstract_call_known(
interp::EnzymeInterpreter,
@nospecialize(f),
Expand Down Expand Up @@ -322,6 +337,29 @@ function abstract_call_known(
end
end

@static if VERSION < v"1.11.0-"
else
if f === Base.unsafe_copyto! && length(argtypes) == 4 &&
widenconst(argtypes[2]) <: Base.MemoryRef &&
widenconst(argtypes[3]) == widenconst(argtypes[2]) &&
Base.allocatedinline(eltype(widenconst(argtypes[2]))) && Base.isbitstype(eltype(widenconst(argtypes[2])))

arginfo2 = ArgInfo(
fargs isa Nothing ? nothing :
[:(Enzyme.Compiler.Interpreter.myunsafe_copyto!), fargs[2:end]...],
[Core.Const(Enzyme.Compiler.Interpreter.myunsafe_copyto!), argtypes[2:end]...],
)
return abstract_call_known(
interp,
Enzyme.Compiler.Interpreter.myunsafe_copyto!,
arginfo2,
si,
sv,
max_methods,
)
end
end

if f === Enzyme.autodiff && length(argtypes) >= 4
if widenconst(argtypes[2]) <: Enzyme.Mode &&
widenconst(argtypes[3]) <: Enzyme.Annotation &&
Expand Down
76 changes: 74 additions & 2 deletions src/compiler/optimize.jl
Original file line number Diff line number Diff line change
Expand Up @@ -783,6 +783,7 @@ function nodecayed_phis!(mod::LLVM.Module)
b = IRBuilder()
position!(b, terminator(pb))


v0 = v
@inline function getparent(v, offset, hasload)
if addr == 11 && addrspace(value_type(v)) == 10
Expand All @@ -794,16 +795,87 @@ function nodecayed_phis!(mod::LLVM.Module)
if addr == 13 && !hasload
if isa(v, LLVM.LoadInst)
v2, o2, hl2 = getparent(operands(v)[1], LLVM.ConstantInt(offty, 0), true)
@assert o2 == LLVM.ConstantInt(offty, 0)
rhs = LLVM.ConstantInt(offty, 0)
if o2 != rhs
msg = sprint() do io::IO
println(
io,
"Enzyme internal error addr13 load doesn't keep offset 0",
)
println(io, "v=", string(v))
println(io, "v2=", string(v2))
println(io, "o2=", string(o2))
println(io, "hl2=", string(hl2))
println(io, "offty=", string(offty))
println(io, "rhs=", string(rhs))
end
throw(AssertionError(msg))
end
return v2, offset, true
end
if isa(v, LLVM.CallInst)
cf = LLVM.called_operand(v)
if isa(cf, LLVM.Function) && LLVM.name(cf) == "julia.gc_loaded"
ld = operands(v)[2]
while isa(ld, LLVM.BitCastInst) || isa(ld, LLVM.AddrSpaceCastInst)
ld = operands(ld)[1]
end
if isa(ld, LLVM.LoadInst)
v2, o2, hl2 = getparent(operands(ld)[1], LLVM.ConstantInt(offty, 0), true)
@assert o2 == LLVM.ConstantInt(offty, sizeof(Int))
rhs = LLVM.ConstantInt(offty, sizeof(Int))
if o2 != rhs
msg = sprint() do io::IO
println(
io,
"Enzyme internal error addr13 load doesn't keep offset 0",
)
println(io, "mod=", string(LLVM.parent(f)))
println(io, "f=", string(f))
println(io, "v=", string(v))
println(io, "opv[1]=", string(operands(v)[1]))
println(io, "opv[2]=", string(operands(v)[2]))
println(io, "ld=", string(ld))
println(io, "ld_op[1]=", string(operands(ld)[1]))

println(io, "v2=", string(v2))
println(io, "o2=", string(o2))
println(io, "hl2=", string(hl2))

println(io, "offty=", string(offty))
println(io, "rhs=", string(rhs))
end
throw(AssertionError(msg))
end

# We currently only support gc_loaded(mem, ptr) where ptr = (({size_t, {}*}*)mem)->second
# [aka a load of the second element of mem]
base_2, off_2, _ = get_base_and_offset(v2)
base_1, off_1, _ = get_base_and_offset(operands(v)[1])
if base_1 != base_2 || off_1 != off_2
msg = sprint() do io::IO
println(
io,
"Enzyme internal error addr13 load data isn't offset of mem",
)
println(io, "f=", string(f))
println(io, "v=", string(v))
println(io, "opv[1]=", string(operands(v)[1]))
println(io, "opv[2]=", string(operands(v)[2]))
println(io, "ld=", string(ld))
println(io, "ld_op[1]=", string(operands(ld)[1]))

println(io, "v2=", string(v2))
println(io, "o2=", string(o2))
println(io, "hl2=", string(hl2))

println(io, "base_1=", string(base_1))
println(io, "base_2=", string(base_2))
println(io, "off_1=", string(off_1))
println(io, "off_2=", string(off_2))
end
throw(AssertionError(msg))
end

return v2, offset, true
end
end
Expand Down
178 changes: 167 additions & 11 deletions src/compiler/validation.jl
Original file line number Diff line number Diff line change
Expand Up @@ -157,6 +157,24 @@ function restore_lookups(mod::LLVM.Module)
eraseInst(mod, f)
end
end
for f in functions(mod)
for fattr in collect(function_attributes(f))
if isa(fattr, LLVM.StringAttribute)
if kind(fattr) == "enzymejl_needs_restoration"
v = parse(UInt, LLVM.value(fattr))
replace_uses!(
f,
LLVM.Value(
LLVM.API.LLVMConstIntToPtr(
ConstantInt(T_size_t, convert(UInt, v)),
value_type(f),
),
),
)
end
end
end
end
end

function check_ir(job, mod::LLVM.Module)
Expand Down Expand Up @@ -457,25 +475,163 @@ function check_ir!(job, errors, imported, f::LLVM.Function, deletedfns)

initfn = unwrap_ptr_casts(LLVM.initializer(fn_got))
loadfn = first(instructions(first(blocks(initfn))))::LLVM.LoadInst
opv = operands(loadfn)[1]::LLVM.GlobalVariable

if startswith(fname, "jl_") || startswith(fname, "ijl_")
else
@assert "unsupported jl got"
opv = operands(loadfn)[1]
if !isa(opv, LLVM.GlobalVariable)
msg = sprint() do io::IO
println(
io,
"Enzyme internal error unsupported got",
"Enzyme internal error unsupported got(load)",
)
println(io, "inst=", inst)
println(io, "fname=", fname)
println(io, "FT=", FT)
println(io, "fn_got=", fn_got)
println(io, "init=", string(initfn))
println(io, "mod=", string(mod))
println(io, "initfn=", string(initfn))
println(io, "loadfn=", string(loadfn))
println(io, "opv=", string(opv))
end
throw(AssertionError(msg))
end
opv = opv::LLVM.GlobalVariable

if startswith(fname, "jl_") || startswith(fname, "ijl_") || startswith(fname, "_j_")
else
found = nothing
for lbb in blocks(initfn), linst in collect(instructions(lbb))
if !isa(linst, LLVM.CallInst)
continue
end
cv = LLVM.called_value(linst)
if !isa(cv, LLVM.Function)
continue
end
if LLVM.name(cv) == "ijl_load_and_lookup"
found = linst
break
end
end
if found == nothing
msg = sprint() do io::IO
println(
io,
"Enzyme internal error unsupported got",
)
println(io, "inst=", inst)
println(io, "fname=", fname)
println(io, "FT=", FT)
println(io, "fn_got=", fn_got)
println(io, "init=", string(initfn))
println(io, "opv=", string(opv))
end
throw(AssertionError(msg))
end

legal1, arg1 = abs_cstring(operands(found)[1])
if legal1
else
arg1 = operands(found)[1]

while isa(arg1, ConstantExpr)
if opcode(arg1) == LLVM.API.LLVMAddrSpaceCast ||
opcode(arg1) == LLVM.API.LLVMBitCast ||
opcode(arg1) == LLVM.API.LLVMIntToPtr
arg1 = operands(arg1)[1]
else
break
end
end
if !isa(arg1, LLVM.ConstantInt)
msg = sprint() do io::IO
println(
io,
"Enzyme internal error unsupported got(arg1)",
)
println(io, "inst=", inst)
println(io, "fname=", fname)
println(io, "FT=", FT)
println(io, "fn_got=", fn_got)
println(io, "init=", string(initfn))
println(io, "opv=", string(opv))
println(io, "found=", string(found))
println(io, "arg1=", string(arg1))
end
throw(AssertionError(msg))
end

arg1 = reinterpret(Ptr{Cvoid}, convert(UInt, arg1))
end

legal2, fname = abs_cstring(operands(found)[2])
if !legal2
msg = sprint() do io::IO
println(
io,
"Enzyme internal error unsupported got(fname)",
)
println(io, "inst=", inst)
println(io, "fname=", fname)
println(io, "FT=", FT)
println(io, "fn_got=", fn_got)
println(io, "init=", string(initfn))
println(io, "opv=", string(opv))
println(io, "found=", string(found))
println(io, "fname=", string(operands(found)[2]))
end
throw(AssertionError(msg))
end

hnd = operands(found)[3]

if !isa(hnd, LLVM.GlobalVariable)
msg = sprint() do io::IO
println(
io,
"Enzyme internal error unsupported got(hnd)",
)
println(io, "inst=", inst)
println(io, "fname=", fname)
println(io, "FT=", FT)
println(io, "fn_got=", fn_got)
println(io, "init=", string(initfn))
println(io, "opv=", string(opv))
println(io, "found=", string(found))
println(io, "hnd=", string(hnd))
end
throw(AssertionError(msg))
end
hnd = LLVM.name(hnd)
# println(string(mod))

# TODO we don't restore/lookup now because this fails
# @vchuravy / @gbaraldi this needs help looking at how to get the actual handle and setup

if true
res = nothing
elseif arg1 isa AbstractString
res = ccall(
:ijl_load_and_lookup,
Ptr{Cvoid},
(Cstring, Cstring, Ptr{Cvoid}),
arg1,
fname,
reinterpret(Ptr{Cvoid}, JIT.lookup(nothing, hnd).ptr),
)
else
res = ccall(
:ijl_load_and_lookup,
Ptr{Cvoid},
(Ptr{Cvoid}, Cstring, Ptr{Cvoid}),
arg1,
fname,
reinterpret(Ptr{Cvoid}, JIT.lookup(nothing, hnd).ptr),
)
end

if res !== nothing
push!(function_attributes(newf), StringAttribute("enzymejl_needs_restoration", string(convert(UInt, res))))
end
# TODO we can make this relocatable if desired by having restore lookups re-create this got initializer/etc
# metadata(newf)["enzymejl_flib"] = flib
# metadata(newf)["enzymejl_flib"] = flib

end

if value_type(newf) != value_type(inst)
newf = const_pointercast(newf, value_type(inst))
Expand Down
Loading

0 comments on commit da53c03

Please sign in to comment.