Skip to content

Commit

Permalink
Fix method ambiguities in SparseArrays (#30120)
Browse files Browse the repository at this point in the history
* Remove unused struct CapturedScalars

* Fix method ambiguities in SparseArrays

* Fix HigherOrderFns._copy(f)

(cherry picked from commit f10530e)
  • Loading branch information
tkf authored and KristofferC committed Dec 30, 2018
1 parent 7352037 commit f47a4f4
Show file tree
Hide file tree
Showing 4 changed files with 49 additions and 5 deletions.
11 changes: 6 additions & 5 deletions stdlib/SparseArrays/src/higherorderfns.jl
Original file line number Diff line number Diff line change
Expand Up @@ -973,6 +973,7 @@ function _copy(f, args...)
parevalf, passedargstup = capturescalars(f, args)
return _copy(parevalf, passedargstup...)
end
_copy(f) = throw(MethodError(_copy, (f,))) # avoid method ambiguity

function _shapecheckbc(f, args...)
_aresameshape(args...) ? _noshapecheck_map(f, args...) : _diffshape_broadcast(f, args...)
Expand Down Expand Up @@ -1006,10 +1007,6 @@ end
_copyto!(parevalf, dest, passedsrcargstup...)
end

struct CapturedScalars{F, Args, Order}
args::Args
end

# capturescalars takes a function (f) and a tuple of mixed sparse vectors/matrices and
# broadcast scalar arguments (mixedargs), and returns a function (parevalf, i.e. partially
# evaluated f) and a reduced argument tuple (passedargstup) containing only the sparse
Expand All @@ -1024,9 +1021,13 @@ end
# Work around losing Type{T}s as DataTypes within the tuple that makeargs creates
@inline capturescalars(f, mixedargs::Tuple{Ref{Type{T}}, Vararg{Any}}) where {T} =
capturescalars((args...)->f(T, args...), Base.tail(mixedargs))
@inline capturescalars(f, mixedargs::Tuple{Ref{Type{T}}, Ref{Type{S}}, Vararg{Any}}) where {T, S} =
# This definition is identical to the one above and necessary only for
# avoiding method ambiguity.
capturescalars((args...)->f(T, args...), Base.tail(mixedargs))
@inline capturescalars(f, mixedargs::Tuple{SparseVecOrMat, Ref{Type{T}}, Vararg{Any}}) where {T} =
capturescalars((a1, args...)->f(a1, T, args...), (mixedargs[1], Base.tail(Base.tail(mixedargs))...))
@inline capturescalars(f, mixedargs::Tuple{Union{Ref,AbstractArray{0}}, Ref{Type{T}}, Vararg{Any}}) where {T} =
@inline capturescalars(f, mixedargs::Tuple{Union{Ref,AbstractArray{<:Any,0}}, Ref{Type{T}}, Vararg{Any}}) where {T} =
capturescalars((args...)->f(mixedargs[1], T, args...), Base.tail(Base.tail(mixedargs)))

nonscalararg(::SparseVecOrMat) = true
Expand Down
4 changes: 4 additions & 0 deletions stdlib/SparseArrays/test/ambiguous_exec.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,4 @@
# This file is a part of Julia. License is MIT: https://julialang.org/license

using Test, SparseArrays
@test detect_ambiguities(SparseArrays; imported=true, recursive=true) == []
24 changes: 24 additions & 0 deletions stdlib/SparseArrays/test/higherorderfns.jl
Original file line number Diff line number Diff line change
Expand Up @@ -632,4 +632,28 @@ end
@test minimum(sparse([1, 2], [1, 2], ones(Int32, 2)), dims = 1) isa Matrix
end

@testset "Issue #30118" begin
@test ((_, x) -> x).(Int, spzeros(3)) == spzeros(3)
@test ((_, _, x) -> x).(Int, Int, spzeros(3)) == spzeros(3)
@test ((_, _, _, x) -> x).(Int, Int, Int, spzeros(3)) == spzeros(3)
@test_broken ((_, _, _, _, x) -> x).(Int, Int, Int, Int, spzeros(3)) == spzeros(3)
end

using SparseArrays.HigherOrderFns: SparseVecStyle

@testset "Issue #30120: method ambiguity" begin
# HigherOrderFns._copy(f) was ambiguous. It may be impossible to
# invoke this from dot notation and it is an error anyway. But
# when someone invokes it by accident, we want it to produce a
# meaningful error.
err = try
copy(Broadcast.Broadcasted{SparseVecStyle}(rand, ()))
catch err
err
end
@test err isa MethodError
@test !occursin("is ambiguous", sprint(showerror, err))
@test occursin("no method matching _copy(::typeof(rand))", sprint(showerror, err))
end

end # module
15 changes: 15 additions & 0 deletions stdlib/SparseArrays/test/sparse.jl
Original file line number Diff line number Diff line change
Expand Up @@ -2341,4 +2341,19 @@ end
@test m2.module == SparseArrays
end

@testset "sprandn with type $T" for T in (Float64, Float32, Float16, ComplexF64, ComplexF32, ComplexF16)
@test sprandn(T, 5, 5, 0.5) isa AbstractSparseMatrix{T}
end
@testset "sprandn with invalid type $T" for T in (AbstractFloat, BigFloat, Complex)
@test_throws MethodError sprandn(T, 5, 5, 0.5)
end

@testset "method ambiguity" begin
# Ambiguity test is run inside a clean process.
# https://github.com/JuliaLang/julia/issues/28804
script = joinpath(@__DIR__, "ambiguous_exec.jl")
cmd = `$(Base.julia_cmd()) --startup-file=no $script`
@test success(pipeline(cmd; stdout=stdout, stderr=stderr))
end

end # module

0 comments on commit f47a4f4

Please sign in to comment.