From 0a8fbcb727f5f0efac54e297edc49001c5a207d2 Mon Sep 17 00:00:00 2001 From: Milan Bouchet-Valat Date: Sat, 25 Nov 2017 16:34:42 +0100 Subject: [PATCH] Change find() to return the same index type as pairs() This does not change anything for AbstractVectors and general iterables, which continue to use linear indices. For other AbstractArrays, return CartesianIndexes (rather than linear indices). For Dicts, return keys (previously not supported at all). Relying on collect() to choose the return element type allows supporting any definition of pairs(), including that for Dict, which creates a standard Generator for which eltype() returns Any. --- base/array.jl | 88 ++++++++++++++++++------------------- base/sparse/sparsematrix.jl | 2 +- test/arrayops.jl | 10 +++++ 3 files changed, 55 insertions(+), 45 deletions(-) diff --git a/base/array.jl b/base/array.jl index 1e3ea06975a83..d6fcf5f213003 100644 --- a/base/array.jl +++ b/base/array.jl @@ -1723,48 +1723,60 @@ findlast(testf::Function, A) = findprev(testf, A, endof(A)) """ find(f::Function, A) -Return a vector `I` of the linear indexes of `A` where `f(A[I])` returns `true`. +Return a vector `I` of the indices or keys of `A` where `f(A[I])` returns `true`. If there are no such elements of `A`, return an empty array. +Indices or keys are of the same type as those returned by [`keys(A)`](@ref) +and [`pairs(A)`](@ref) for `AbstractArray` and `Associative` objects, +and are linear indices of type `Int` for other iterables. + # Examples ```jldoctest -julia> A = [1 2 0; 3 4 0] -2×3 Array{Int64,2}: - 1 2 0 - 3 4 0 +julia> x = [1, 3, 4] +3-element Array{Int64,1}: + 1 + 3 + 4 -julia> find(isodd, A) +julia> find(isodd, x) 2-element Array{Int64,1}: 1 2 + julia> A = [1 2 0; 3 4 0] + 2×3 Array{Int64,2}: + 1 2 0 + 3 4 0 + + julia> find(isodd, A) + 2-element Array{CartesianIndex{2},1}: + CartesianIndex(1, 1) + CartesianIndex(2, 1) + julia> find(!iszero, A) -4-element Array{Int64,1}: - 1 - 2 - 3 - 4 +4-element Array{CartesianIndex{2},1}: +CartesianIndex(1, 1) +CartesianIndex(2, 1) +CartesianIndex(1, 2) +CartesianIndex(2, 2) + +julia> d = Dict(:A => 10, :B => -1, :C => 0) +Dict{Symbol,Int64} with 3 entries: + :A => 10 + :B => -1 + :C => 0 + +julia> find(x -> x >= 0, d) +2-element Array{Symbol,1}: + :A + :C -julia> find(isodd, [2, 4]) -0-element Array{Int64,1} ``` """ -function find(testf::Function, A) - # use a dynamic-length array to store the indexes, then copy to a non-padded - # array for the return - tmpI = Vector{Int}() - inds = _index_remapper(A) - for (i,a) = enumerate(A) - if testf(a) - push!(tmpI, inds[i]) - end - end - I = Vector{Int}(uninitialized, length(tmpI)) - copy!(I, tmpI) - return I -end -_index_remapper(A::AbstractArray) = linearindices(A) -_index_remapper(iter) = OneTo(typemax(Int)) # safe for objects that don't implement length +find(testf::Function, A) = collect(first(p) for p in _pairs(A) if testf(last(p))) + +_pairs(A::Union{AbstractArray, Associative}) = pairs(A) +_pairs(iter) = zip(OneTo(typemax(Int)), iter) # safe for objects that don't implement length """ find(A) @@ -1789,22 +1801,10 @@ julia> find(falses(3)) ``` """ function find(A) - nnzA = count(t -> t != 0, A) - I = Vector{Int}(uninitialized, nnzA) - cnt = 1 - inds = _index_remapper(A) - warned = false - for (i,a) in enumerate(A) - if !warned && !(a isa Bool) - depwarn("In the future `find(A)` will only work on boolean collections. Use `find(x->x!=0, A)` instead.", :find) - warned = true - end - if a != 0 - I[cnt] = inds[i] - cnt += 1 - end + if !(eltype(A) === Bool) && !all(x -> x isa Bool, A) + depwarn("In the future `find(A)` will only work on boolean collections. Use `find(x->x!=0, A)` instead.", :find) end - return I + collect(first(p) for p in _pairs(A) if last(p) != 0) end find(x::Bool) = x ? [1] : Vector{Int}() diff --git a/base/sparse/sparsematrix.jl b/base/sparse/sparsematrix.jl index d9af41a9ab18b..0e4ba14fb1b50 100644 --- a/base/sparse/sparsematrix.jl +++ b/base/sparse/sparsematrix.jl @@ -1264,7 +1264,7 @@ function find(p::Function, S::SparseMatrixCSC) end sz = size(S) I, J = _findn(p, S) - return sub2ind(sz, I, J) + return CartesianIndex.(I, J) end findn(S::SparseMatrixCSC{Tv,Ti}) where {Tv,Ti} = _findn(x->x!=0, S) diff --git a/test/arrayops.jl b/test/arrayops.jl index 5a8e2c2107cc0..f18fc347d3136 100644 --- a/test/arrayops.jl +++ b/test/arrayops.jl @@ -464,6 +464,16 @@ end @test findprev(isodd, [2,4,5,3,9,2,0], 7) == 5 @test findprev(isodd, [2,4,5,3,9,2,0], 2) == 0 end +@testset "find with Matrix" begin + A = [1 2 0; 3 4 0] + @test find(isodd, A) == [CartesianIndex(1, 1), CartesianIndex(2, 1)] + @test find(!iszero, A) == [CartesianIndex(1, 1), CartesianIndex(2, 1), + CartesianIndex(1, 2), CartesianIndex(2, 2)] +end +@testset "find with Dict" begin + d = Dict(:A => 10, :B => -1, :C => 0) + @test sort(find(x -> x >= 0, d)) == [:A, :C] +end @testset "find with general iterables" begin s = "julia" @test find(c -> c == 'l', s) == [3]