Skip to content

Commit

Permalink
More 1.11 stuff (#2015)
Browse files Browse the repository at this point in the history
* More 1.11 stuff

* fixup
  • Loading branch information
wsmoses authored Oct 28, 2024
1 parent cdee028 commit ecd490c
Show file tree
Hide file tree
Showing 3 changed files with 55 additions and 6 deletions.
38 changes: 32 additions & 6 deletions src/compiler.jl
Original file line number Diff line number Diff line change
Expand Up @@ -1615,8 +1615,8 @@ function julia_error(
legal2, obj = absint(cur)

# Only do so for the immediate operand/etc to a phi, since otherwise we will make multiple
if legal2 &&
active_reg_inner(TT, (), world) == ActiveState &&
if legal2
if active_reg_inner(TT, (), world) == ActiveState &&
isa(cur, LLVM.ConstantExpr) &&
cur == data2
if width == 1
Expand All @@ -1634,6 +1634,14 @@ function julia_error(
end
return shadowres
end
end

@static if VERSION < v"1.11-"
else
if obj isa Memory && obj == typeof(obj).instance
return make_batched(ncur, prevbb)
end
end
end

badval = if legal2
Expand All @@ -1652,10 +1660,8 @@ function julia_error(
if isa(cur, LLVM.UndefValue)
return make_batched(ncur, prevbb)
end
@static if LLVM.version() >= v"12"
if isa(cur, LLVM.PoisonValue)
return make_batched(ncur, prevbb)
end
if isa(cur, LLVM.PoisonValue)
return make_batched(ncur, prevbb)
end
if isa(cur, LLVM.ConstantAggregateZero)
return make_batched(ncur, prevbb)
Expand Down Expand Up @@ -1794,6 +1800,18 @@ function julia_error(
return shadowres
end
end

if isa(cur, LLVM.LoadInst) || isa(cur, LLVM.BitCastInst) || isa(cur, LLVM.AddrSpaceCastInst) || (isa(cur, LLVM.GetElementPtrInst) && all(x->isa(x, LLVM.ConstantInt), operands(cur)[2:end]))
lhs = make_replacement(operands(cur)[1], prevbb)
if illegal
return ncur
end
if lhs == operands(ncur)[1]
return make_batched(ncur, prevbb)
elseif width != 1 && isa(lhs, LLVM.InsertValueInst) && operands(lhs)[2] == operands(ncur)[1]
return make_batched(ncur, prevbb)
end
end

if isa(cur, LLVM.PHIInst)
Bphi = IRBuilder()
Expand Down Expand Up @@ -6322,6 +6340,14 @@ function GPUCompiler.codegen(

func = mi.specTypes.parameters[1]

@static if VERSION < v"1.11-"
else
if func == typeof(Core.memoryref)
attributes = function_attributes(llvmfn)
push!(attributes, EnumAttribute("alwaysinline", 0))
end
end

meth = mi.def
name = meth.name
jlmod = meth.module
Expand Down
5 changes: 5 additions & 0 deletions src/compiler/interpreter.jl
Original file line number Diff line number Diff line change
Expand Up @@ -111,6 +111,11 @@ end

function is_alwaysinline_func(@nospecialize(TT))
isa(TT, DataType) || return false
@static if VERSION v"1.11-"
if TT.parameters[1] == typeof(Core.memoryref)
return true
end
end
return false
end

Expand Down
18 changes: 18 additions & 0 deletions test/abi.jl
Original file line number Diff line number Diff line change
Expand Up @@ -590,5 +590,23 @@ end

Enzyme.autodiff(Forward, byrefs, BatchDuplicated([1.0], ([1.0], [1.0])), BatchDuplicated([1.0], ([1.0], [1.0]) ) )
end

function myunique0()
return Vector{Float64}(undef, 0)
end
@static if VERSION < v"1.11-"
@testset "Forward mode array construct" begin
autodiff(Forward, myunique0, Duplicated)
end
else
function myunique()
m = Memory{Float64}.instance
return Core.memoryref(m)
end
@testset "Forward mode array construct" begin
autodiff(Forward, myunique, Duplicated)
autodiff(Forward, myunique0, Duplicated)
end
end

include("usermixed.jl")

0 comments on commit ecd490c

Please sign in to comment.