Skip to content

Commit

Permalink
Disable sort on integers (#1207)
Browse files Browse the repository at this point in the history
* Disable sort on integers

* fixup
  • Loading branch information
wsmoses authored Dec 18, 2023
1 parent ab9c91a commit cb3037f
Show file tree
Hide file tree
Showing 3 changed files with 39 additions and 8 deletions.
7 changes: 7 additions & 0 deletions src/compiler.jl
Original file line number Diff line number Diff line change
Expand Up @@ -1672,6 +1672,13 @@ function julia_error(cstr::Cstring, val::LLVM.API.LLVMValueRef, errtype::API.Err
ip = API.EnzymeTypeAnalyzerToString(data)
sval = Base.unsafe_string(ip)
API.EnzymeStringFree(ip)

if isa(val, LLVM.Instruction)
mi, rt = enzyme_custom_extract_mi(LLVM.parent(LLVM.parent(val))::LLVM.Function, #=error=#false)
if mi !== nothing
msg *= "\n" * string(mi) * "\n"
end
end
throw(IllegalTypeAnalysisException(msg, sval, ir, bt))
elseif errtype == API.ET_NoType
@assert B != C_NULL
Expand Down
15 changes: 8 additions & 7 deletions src/internal_rules.jl
Original file line number Diff line number Diff line change
Expand Up @@ -486,9 +486,9 @@ end
function EnzymeRules.forward(
::Const{typeof(sort!)},
RT::Type{<:Union{Const, DuplicatedNoNeed, Duplicated}},
xs::Duplicated;
xs::Duplicated{T};
kwargs...
)
) where {T <: AbstractArray{<:AbstractFloat}}
inds = sortperm(xs.val; kwargs...)
xs.val .= xs.val[inds]
xs.dval .= xs.dval[inds]
Expand All @@ -506,7 +506,7 @@ function EnzymeRules.forward(
RT::Type{<:Union{Const, BatchDuplicatedNoNeed, BatchDuplicated}},
xs::BatchDuplicated{T, N};
kwargs...
) where {T, N}
) where {T <: AbstractArray{<:AbstractFloat}, N}
inds = sortperm(xs.val; kwargs...)
xs.val .= xs.val[inds]
for i in 1:N
Expand All @@ -521,13 +521,14 @@ function EnzymeRules.forward(
end
end


function EnzymeRules.augmented_primal(
config::EnzymeRules.ConfigWidth{1},
::Const{typeof(sort!)},
RT::Type{<:Union{Const, DuplicatedNoNeed, Duplicated}},
xs::Duplicated;
xs::Duplicated{T};
kwargs...
)
) where {T <: AbstractArray{<:AbstractFloat}}
inds = sortperm(xs.val; kwargs...)
xs.val .= xs.val[inds]
xs.dval .= xs.dval[inds]
Expand All @@ -549,9 +550,9 @@ function EnzymeRules.reverse(
::Const{typeof(sort!)},
RT::Type{<:Union{Const, DuplicatedNoNeed, Duplicated}},
tape,
xs::Duplicated;
xs::Duplicated{T};
kwargs...,
)
) where {T <: AbstractArray{<:AbstractFloat}}
inds = tape
back_inds = sortperm(inds)
xs.dval .= xs.dval[back_inds]
Expand Down
25 changes: 24 additions & 1 deletion test/internal_rules.jl
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,19 @@ using Enzyme
using Enzyme.EnzymeRules
using Test

@testset "Internal rules" begin
struct TPair
a::Float64
b::Float64
end

function sorterrfn(t, x)
function lt(a, b)
return a.a < b.a
end
return first(sortperm(t, lt=lt)) * x
end

@testset "Sort rules" begin
function f1(x)
a = [1.0, 3.0, x]
sort!(a)
Expand All @@ -27,6 +39,17 @@ using Test
@test autodiff(Forward, f2, Duplicated(2.0, 1.0))[1] == -3
@test autodiff(Forward, f2, BatchDuplicated(2.0, (1.0, 2.0)))[1] == (var"1"=-3.0, var"2"=-6.0)
@test autodiff(Reverse, f2, Active, Active(2.0))[1][1] == -3

dd = Duplicated([TPair(1, 2), TPair(2, 3), TPair(0, 1)], [TPair(0, 0), TPair(0, 0), TPair(0, 0)])
res = Enzyme.autodiff(Reverse, sorterrfn, dd, Active(1.0))

@test res[1][2] 3
@test dd.dval[1].a 0
@test dd.dval[1].b 0
@test dd.dval[2].a 0
@test dd.dval[2].b 0
@test dd.dval[3].a 0
@test dd.dval[3].b 0
end

@testset "Linear Solve" begin
Expand Down

0 comments on commit cb3037f

Please sign in to comment.