diff --git a/src/compiler.jl b/src/compiler.jl index 36ba2c8656..bb7ec835f8 100644 --- a/src/compiler.jl +++ b/src/compiler.jl @@ -3207,7 +3207,20 @@ function annotate!(mod, mode) ) if haskey(fns, fname) fn = fns[fname] - push!(function_attributes(fn), LLVM.EnumAttribute("readonly", 0)) + if LLVM.version().major <= 15 + push!(function_attributes(fn), LLVM.EnumAttribute("readonly", 0)) + else + push!(function_attributes(fn), + EnumAttribute( + "memory", + MemoryEffect( + (MRI_Ref << getLocationPos(ArgMem)) | + (MRI_NoModRef << getLocationPos(InaccessibleMem)) | + (MRI_NoModRef << getLocationPos(Other)), + ).data, + ) + ) + end for u in LLVM.uses(fn) c = LLVM.user(u) if !isa(c, LLVM.CallInst) diff --git a/src/compiler/optimize.jl b/src/compiler/optimize.jl index 86760b423f..a35de5608f 100644 --- a/src/compiler/optimize.jl +++ b/src/compiler/optimize.jl @@ -828,13 +828,14 @@ function nodecayed_phis!(mod::LLVM.Module) base_1, off_1, _ = get_base_and_offset(operands(v)[1]) if o2 == rhs && base_1 == base_2 && off_1 == off_2 - return v2, offset, true + return operands(v)[1], offset, true end rhs = ptrtoint!(b, get_memory_data(b, operands(v)[1]), offty) lhs = ptrtoint!(b, operands(v)[2], offty) - off2 = nuwsub!(b, rhs, lhs) - return v2, nuwadd!(b, offset, off2), true + off2 = nuwsub!(b, lhs, rhs) + add = nuwadd!(b, offset, off2) + return operands(v)[1], add, true end end end @@ -905,8 +906,12 @@ function nodecayed_phis!(mod::LLVM.Module) end if isa(v, LLVM.BitCastInst) + preop = operands(v)[1] + while isa(preop, LLVM.BitCastInst) + preop = operands(preop)[1] + end v2, offset, skipload = - getparent(operands(v)[1], offset, hasload) + getparent(preop, offset, hasload) v2 = bitcast!( b, v2, @@ -1059,7 +1064,7 @@ function nodecayed_phis!(mod::LLVM.Module) end nb = IRBuilder() - position!(nb, inst) + position!(nb, nonphi) offset = goffsets[inst] append!(LLVM.incoming(offset), offsets) @@ -1068,15 +1073,26 @@ function nodecayed_phis!(mod::LLVM.Module) end nphi = nextvs[inst] - if !all(x -> x[1] == nvs[1][1], nvs) - append!(LLVM.incoming(nphi), nvs) - else - replace_uses!(nphi, nvs[1][1]) + + function ogbc(x) + while isa(x, LLVM.BitCastInst) + x = operands(x)[1] + end + return x + end + + if all(x -> ogbc(x[1]) == ogbc(nvs[1][1]), nvs) + bc = ogbc(nvs[1][1]) + if value_type(bc) != value_type(nphi) + bc = bitcast!(nb, bc, value_type(nphi)) + end + replace_uses!(nphi, bc) LLVM.API.LLVMInstructionEraseFromParent(nphi) - nphi = nvs[1][1] + nphi = bc + else + append!(LLVM.incoming(nphi), nvs) end - position!(nb, nonphi) if addr == 13 @static if VERSION < v"1.11-" nphi = bitcast!(nb, nphi, LLVM.PointerType(ty, 10)) diff --git a/test/optimize.jl b/test/optimize.jl new file mode 100644 index 0000000000..a4fcc1768f --- /dev/null +++ b/test/optimize.jl @@ -0,0 +1,46 @@ +using Enzyme, LinearAlgebra, Test + +function gcloaded_fixup(dest, src) + N = size(src) + dat = src.data + len = N[1] + + i = 1 + while true + j = 1 + while true + ld = @inbounds if i <= j + dat[(i-1) * 2 + j] + else + dat[(j-1) * 2 + i] + end + @inbounds dest[(i-1) * 2 + j] = ld + if j == len + break + end + j += 1 + end + if i == len + break + end + i += 1 + end + return nothing +end + +@testset "GCLoaded fixup" begin + H = Hermitian(Matrix([4.0 1.0; 2.0 5.0])) + dest = Matrix{Float64}(undef, 2, 2) + + Enzyme.autodiff( + ForwardWithPrimal, + gcloaded_fixup, + Const, + Const(dest), + Const(H), + )[1] + @test dest ≈ [4.0 2.0; 2.0 5.0] + dest = Matrix{Float64}(undef, 2, 2) + gcloaded_fixup(dest, H) + @test dest ≈ [4.0 2.0; 2.0 5.0] +end diff --git a/test/runtests.jl b/test/runtests.jl index 8c7ca39abc..b3a64a2a21 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -73,6 +73,7 @@ end include("abi.jl") include("typetree.jl") +include("optimize.jl") include("rules.jl") include("rrules.jl")