From cb3037f92f77eb43fc89efdf0af51a102e629e5b Mon Sep 17 00:00:00 2001 From: William Moses Date: Mon, 18 Dec 2023 17:32:55 -0600 Subject: [PATCH] Disable sort on integers (#1207) * Disable sort on integers * fixup --- src/compiler.jl | 7 +++++++ src/internal_rules.jl | 15 ++++++++------- test/internal_rules.jl | 25 ++++++++++++++++++++++++- 3 files changed, 39 insertions(+), 8 deletions(-) diff --git a/src/compiler.jl b/src/compiler.jl index 46a49a07ff..9916ddf6a1 100644 --- a/src/compiler.jl +++ b/src/compiler.jl @@ -1672,6 +1672,13 @@ function julia_error(cstr::Cstring, val::LLVM.API.LLVMValueRef, errtype::API.Err ip = API.EnzymeTypeAnalyzerToString(data) sval = Base.unsafe_string(ip) API.EnzymeStringFree(ip) + + if isa(val, LLVM.Instruction) + mi, rt = enzyme_custom_extract_mi(LLVM.parent(LLVM.parent(val))::LLVM.Function, #=error=#false) + if mi !== nothing + msg *= "\n" * string(mi) * "\n" + end + end throw(IllegalTypeAnalysisException(msg, sval, ir, bt)) elseif errtype == API.ET_NoType @assert B != C_NULL diff --git a/src/internal_rules.jl b/src/internal_rules.jl index d073a78c47..6f6fa465f5 100644 --- a/src/internal_rules.jl +++ b/src/internal_rules.jl @@ -486,9 +486,9 @@ end function EnzymeRules.forward( ::Const{typeof(sort!)}, RT::Type{<:Union{Const, DuplicatedNoNeed, Duplicated}}, - xs::Duplicated; + xs::Duplicated{T}; kwargs... - ) + ) where {T <: AbstractArray{<:AbstractFloat}} inds = sortperm(xs.val; kwargs...) xs.val .= xs.val[inds] xs.dval .= xs.dval[inds] @@ -506,7 +506,7 @@ function EnzymeRules.forward( RT::Type{<:Union{Const, BatchDuplicatedNoNeed, BatchDuplicated}}, xs::BatchDuplicated{T, N}; kwargs... - ) where {T, N} + ) where {T <: AbstractArray{<:AbstractFloat}, N} inds = sortperm(xs.val; kwargs...) xs.val .= xs.val[inds] for i in 1:N @@ -521,13 +521,14 @@ function EnzymeRules.forward( end end + function EnzymeRules.augmented_primal( config::EnzymeRules.ConfigWidth{1}, ::Const{typeof(sort!)}, RT::Type{<:Union{Const, DuplicatedNoNeed, Duplicated}}, - xs::Duplicated; + xs::Duplicated{T}; kwargs... - ) + ) where {T <: AbstractArray{<:AbstractFloat}} inds = sortperm(xs.val; kwargs...) xs.val .= xs.val[inds] xs.dval .= xs.dval[inds] @@ -549,9 +550,9 @@ function EnzymeRules.reverse( ::Const{typeof(sort!)}, RT::Type{<:Union{Const, DuplicatedNoNeed, Duplicated}}, tape, - xs::Duplicated; + xs::Duplicated{T}; kwargs..., - ) + ) where {T <: AbstractArray{<:AbstractFloat}} inds = tape back_inds = sortperm(inds) xs.dval .= xs.dval[back_inds] diff --git a/test/internal_rules.jl b/test/internal_rules.jl index ccf61fef25..f1e6171f30 100644 --- a/test/internal_rules.jl +++ b/test/internal_rules.jl @@ -4,7 +4,19 @@ using Enzyme using Enzyme.EnzymeRules using Test -@testset "Internal rules" begin +struct TPair + a::Float64 + b::Float64 +end + +function sorterrfn(t, x) + function lt(a, b) + return a.a < b.a + end + return first(sortperm(t, lt=lt)) * x +end + +@testset "Sort rules" begin function f1(x) a = [1.0, 3.0, x] sort!(a) @@ -27,6 +39,17 @@ using Test @test autodiff(Forward, f2, Duplicated(2.0, 1.0))[1] == -3 @test autodiff(Forward, f2, BatchDuplicated(2.0, (1.0, 2.0)))[1] == (var"1"=-3.0, var"2"=-6.0) @test autodiff(Reverse, f2, Active, Active(2.0))[1][1] == -3 + + dd = Duplicated([TPair(1, 2), TPair(2, 3), TPair(0, 1)], [TPair(0, 0), TPair(0, 0), TPair(0, 0)]) + res = Enzyme.autodiff(Reverse, sorterrfn, dd, Active(1.0)) + + @test res[1][2] ≈ 3 + @test dd.dval[1].a ≈ 0 + @test dd.dval[1].b ≈ 0 + @test dd.dval[2].a ≈ 0 + @test dd.dval[2].b ≈ 0 + @test dd.dval[3].a ≈ 0 + @test dd.dval[3].b ≈ 0 end @testset "Linear Solve" begin