Skip to content

Commit

Permalink
1.11: more gcloaded work (#1999)
Browse files Browse the repository at this point in the history
* 1.11: more gcloaded work

* fix

* fix

* fix
  • Loading branch information
wsmoses authored Oct 22, 2024
1 parent 924a271 commit b76585a
Show file tree
Hide file tree
Showing 4 changed files with 88 additions and 12 deletions.
15 changes: 14 additions & 1 deletion src/compiler.jl
Original file line number Diff line number Diff line change
Expand Up @@ -3207,7 +3207,20 @@ function annotate!(mod, mode)
)
if haskey(fns, fname)
fn = fns[fname]
push!(function_attributes(fn), LLVM.EnumAttribute("readonly", 0))
if LLVM.version().major <= 15
push!(function_attributes(fn), LLVM.EnumAttribute("readonly", 0))
else
push!(function_attributes(fn),
EnumAttribute(
"memory",
MemoryEffect(
(MRI_Ref << getLocationPos(ArgMem)) |
(MRI_NoModRef << getLocationPos(InaccessibleMem)) |
(MRI_NoModRef << getLocationPos(Other)),
).data,
)
)
end
for u in LLVM.uses(fn)
c = LLVM.user(u)
if !isa(c, LLVM.CallInst)
Expand Down
38 changes: 27 additions & 11 deletions src/compiler/optimize.jl
Original file line number Diff line number Diff line change
Expand Up @@ -828,13 +828,14 @@ function nodecayed_phis!(mod::LLVM.Module)
base_1, off_1, _ = get_base_and_offset(operands(v)[1])

if o2 == rhs && base_1 == base_2 && off_1 == off_2
return v2, offset, true
return operands(v)[1], offset, true
end

rhs = ptrtoint!(b, get_memory_data(b, operands(v)[1]), offty)
lhs = ptrtoint!(b, operands(v)[2], offty)
off2 = nuwsub!(b, rhs, lhs)
return v2, nuwadd!(b, offset, off2), true
off2 = nuwsub!(b, lhs, rhs)
add = nuwadd!(b, offset, off2)
return operands(v)[1], add, true
end
end
end
Expand Down Expand Up @@ -905,8 +906,12 @@ function nodecayed_phis!(mod::LLVM.Module)
end

if isa(v, LLVM.BitCastInst)
preop = operands(v)[1]
while isa(preop, LLVM.BitCastInst)
preop = operands(preop)[1]
end
v2, offset, skipload =
getparent(operands(v)[1], offset, hasload)
getparent(preop, offset, hasload)
v2 = bitcast!(
b,
v2,
Expand Down Expand Up @@ -1059,7 +1064,7 @@ function nodecayed_phis!(mod::LLVM.Module)
end

nb = IRBuilder()
position!(nb, inst)
position!(nb, nonphi)

offset = goffsets[inst]
append!(LLVM.incoming(offset), offsets)
Expand All @@ -1068,15 +1073,26 @@ function nodecayed_phis!(mod::LLVM.Module)
end

nphi = nextvs[inst]
if !all(x -> x[1] == nvs[1][1], nvs)
append!(LLVM.incoming(nphi), nvs)
else
replace_uses!(nphi, nvs[1][1])

function ogbc(x)
while isa(x, LLVM.BitCastInst)
x = operands(x)[1]
end
return x
end

if all(x -> ogbc(x[1]) == ogbc(nvs[1][1]), nvs)
bc = ogbc(nvs[1][1])
if value_type(bc) != value_type(nphi)
bc = bitcast!(nb, bc, value_type(nphi))
end
replace_uses!(nphi, bc)
LLVM.API.LLVMInstructionEraseFromParent(nphi)
nphi = nvs[1][1]
nphi = bc
else
append!(LLVM.incoming(nphi), nvs)
end

position!(nb, nonphi)
if addr == 13
@static if VERSION < v"1.11-"
nphi = bitcast!(nb, nphi, LLVM.PointerType(ty, 10))
Expand Down
46 changes: 46 additions & 0 deletions test/optimize.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,46 @@
using Enzyme, LinearAlgebra, Test

function gcloaded_fixup(dest, src)
N = size(src)
dat = src.data
len = N[1]

i = 1
while true
j = 1
while true
ld = @inbounds if i <= j
dat[(i-1) * 2 + j]
else
dat[(j-1) * 2 + i]
end
@inbounds dest[(i-1) * 2 + j] = ld
if j == len
break
end
j += 1
end
if i == len
break
end
i += 1
end
return nothing
end

@testset "GCLoaded fixup" begin
H = Hermitian(Matrix([4.0 1.0; 2.0 5.0]))
dest = Matrix{Float64}(undef, 2, 2)

Enzyme.autodiff(
ForwardWithPrimal,
gcloaded_fixup,
Const,
Const(dest),
Const(H),
)[1]
@test dest [4.0 2.0; 2.0 5.0]
dest = Matrix{Float64}(undef, 2, 2)
gcloaded_fixup(dest, H)
@test dest [4.0 2.0; 2.0 5.0]
end
1 change: 1 addition & 0 deletions test/runtests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -73,6 +73,7 @@ end

include("abi.jl")
include("typetree.jl")
include("optimize.jl")

include("rules.jl")
include("rrules.jl")
Expand Down

0 comments on commit b76585a

Please sign in to comment.