Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Make searchsorted*/findnext/findprev return values of keytype #32978

Merged
merged 13 commits into from
Apr 28, 2020
8 changes: 4 additions & 4 deletions base/array.jl
Original file line number Diff line number Diff line change
Expand Up @@ -1617,7 +1617,7 @@ CartesianIndex(2, 1)
"""
function findnext(A, start)
l = last(keys(A))
i = start
i = oftype(l, start)
i > l && return nothing
while true
A[i] && return i
Expand Down Expand Up @@ -1699,7 +1699,7 @@ CartesianIndex(1, 1)
"""
function findnext(testf::Function, A, start)
l = last(keys(A))
i = start
i = oftype(l, start)
i > l && return nothing
while true
testf(A[i]) && return i
Expand Down Expand Up @@ -1796,8 +1796,8 @@ CartesianIndex(2, 1)
```
"""
function findprev(A, start)
i = start
f = first(keys(A))
i = oftype(f, start)
i < f && return nothing
while true
A[i] && return i
Expand Down Expand Up @@ -1887,8 +1887,8 @@ CartesianIndex(2, 1)
```
"""
function findprev(testf::Function, A, start)
i = start
f = first(keys(A))
i = oftype(f, start)
i < f && return nothing
while true
testf(A[i]) && return i
Expand Down
14 changes: 8 additions & 6 deletions base/bitarray.jl
Original file line number Diff line number Diff line change
Expand Up @@ -1385,7 +1385,7 @@ end
function findnext(B::BitArray, start::Integer)
start > 0 || throw(BoundsError(B, start))
start > length(B) && return nothing
unsafe_bitfindnext(B.chunks, start)
unsafe_bitfindnext(B.chunks, Int(start))
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

unsafe_bitfindnext (and prev, too) accepts a start::Integer. I was about to make a comment that we should just move this there and/or restrict its signature, but then I realized it's also used by BitSet. BitSet demands an Int64 result, whereas its use for BitArray demands an Int result. So this seems like the right way to go about it. 👍

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Maybe it changed, but my impression is that BitSet calls unsafe_bitfindnext by feeding it an Int, not an Int64, so I would also favor restricting unsafe_bitfindnext to Int input, as this is an internal function. But the PR is also fine as is!

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I have now restricted unsafe_bitfindnext to Int.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Ok, sorry for the nitpick, but please change also unsafe_bitfindprev accordingly too :)

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Done :)

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

That was fast!

end

#findfirst(B::BitArray) = findnext(B, 1) ## defined in array.jl
Expand All @@ -1399,8 +1399,9 @@ function findnextnot(B::BitArray, start::Integer)
l = length(Bc)
l == 0 && return nothing

chunk_start = _div64(start-1)+1
within_chunk_start = _mod64(start-1)
st = Int(start)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

As an aside, note that for this kind of thing, there is no problem writing start = Int(start), which saves you from having to find another name :)

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Ok, then I will change it.

chunk_start = _div64(st-1)+1
within_chunk_start = _mod64(st-1)
mask = ~(_msk64 << within_chunk_start)

@inbounds if chunk_start < l
Expand Down Expand Up @@ -1468,7 +1469,7 @@ end
function findprev(B::BitArray, start::Integer)
start > 0 || return nothing
start > length(B) && throw(BoundsError(B, start))
unsafe_bitfindprev(B.chunks, start)
unsafe_bitfindprev(B.chunks, Int(start))
end

function findprevnot(B::BitArray, start::Integer)
Expand All @@ -1477,8 +1478,9 @@ function findprevnot(B::BitArray, start::Integer)

Bc = B.chunks

chunk_start = _div64(start-1)+1
mask = ~_msk_end(start)
st = Int(start)
chunk_start = _div64(st-1)+1
mask = ~_msk_end(st)

@inbounds begin
if Bc[chunk_start] | mask != _msk64
Expand Down
18 changes: 9 additions & 9 deletions base/sort.jl
Original file line number Diff line number Diff line change
Expand Up @@ -170,7 +170,7 @@ partialsort(v::AbstractVector, k::Union{Int,OrdinalRange}; kws...) =

# index of the first value of vector a that is greater than or equal to x;
# returns length(v)+1 if x is greater than all values in v.
function searchsortedfirst(v::AbstractVector, x, lo::T, hi::T, o::Ordering) where T<:Integer
function searchsortedfirst(v::AbstractVector, x, lo::T, hi::T, o::Ordering)::keytype(v) where T<:Integer
u = T(1)
lo = lo - u
hi = hi + u
Expand All @@ -187,7 +187,7 @@ end

# index of the last value of vector a that is less than or equal to x;
# returns 0 if x is less than all values of v.
function searchsortedlast(v::AbstractVector, x, lo::T, hi::T, o::Ordering) where T<:Integer
function searchsortedlast(v::AbstractVector, x, lo::T, hi::T, o::Ordering)::keytype(v) where T<:Integer
u = T(1)
lo = lo - u
hi = hi + u
Expand Down Expand Up @@ -221,10 +221,10 @@ function searchsorted(v::AbstractVector, x, ilo::T, ihi::T, o::Ordering) where T
return a : b
end
end
return (lo + 1) : (hi - 1)
return convert(keytype(v), lo + 1) : convert(keytype(v), hi - 1)
sostock marked this conversation as resolved.
Show resolved Hide resolved
end

function searchsortedlast(a::AbstractRange{<:Real}, x::Real, o::DirectOrdering)
function searchsortedlast(a::AbstractRange{<:Real}, x::Real, o::DirectOrdering)::keytype(a)
require_one_based_indexing(a)
if step(a) == 0
lt(o, x, first(a)) ? 0 : length(a)
Expand All @@ -234,7 +234,7 @@ function searchsortedlast(a::AbstractRange{<:Real}, x::Real, o::DirectOrdering)
end
end

function searchsortedfirst(a::AbstractRange{<:Real}, x::Real, o::DirectOrdering)
function searchsortedfirst(a::AbstractRange{<:Real}, x::Real, o::DirectOrdering)::keytype(a)
require_one_based_indexing(a)
if step(a) == 0
lt(o, first(a), x) ? length(a) + 1 : 1
Expand All @@ -244,7 +244,7 @@ function searchsortedfirst(a::AbstractRange{<:Real}, x::Real, o::DirectOrdering)
end
end

function searchsortedlast(a::AbstractRange{<:Integer}, x::Real, o::DirectOrdering)
function searchsortedlast(a::AbstractRange{<:Integer}, x::Real, o::DirectOrdering)::keytype(a)
require_one_based_indexing(a)
if step(a) == 0
lt(o, x, first(a)) ? 0 : length(a)
Expand All @@ -253,7 +253,7 @@ function searchsortedlast(a::AbstractRange{<:Integer}, x::Real, o::DirectOrderin
end
end

function searchsortedfirst(a::AbstractRange{<:Integer}, x::Real, o::DirectOrdering)
function searchsortedfirst(a::AbstractRange{<:Integer}, x::Real, o::DirectOrdering)::keytype(a)
require_one_based_indexing(a)
if step(a) == 0
lt(o, first(a), x) ? length(a)+1 : 1
Expand All @@ -262,7 +262,7 @@ function searchsortedfirst(a::AbstractRange{<:Integer}, x::Real, o::DirectOrderi
end
end

function searchsortedfirst(a::AbstractRange{<:Integer}, x::Unsigned, o::DirectOrdering)
function searchsortedfirst(a::AbstractRange{<:Integer}, x::Unsigned, o::DirectOrdering)::keytype(a)
require_one_based_indexing(a)
if lt(o, first(a), x)
if step(a) == 0
Expand All @@ -275,7 +275,7 @@ function searchsortedfirst(a::AbstractRange{<:Integer}, x::Unsigned, o::DirectOr
end
end

function searchsortedlast(a::AbstractRange{<:Integer}, x::Unsigned, o::DirectOrdering)
function searchsortedlast(a::AbstractRange{<:Integer}, x::Unsigned, o::DirectOrdering)::keytype(a)
require_one_based_indexing(a)
if lt(o, x, first(a))
0
Expand Down
6 changes: 3 additions & 3 deletions base/strings/search.jl
Original file line number Diff line number Diff line change
Expand Up @@ -130,7 +130,7 @@ function findnext(testf::Function, s::AbstractString, i::Integer)
@inbounds i == z || isvalid(s, i) || string_index_err(s, i)
for (j, d) in pairs(SubString(s, i))
if testf(d)
return i + j - 1
return Int(i + j - 1)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

As a conversion is needed somewhere in this function, my personal preference would favor converting at the begining (i = Int(i), which might help having to compile less function specializations, like isvalid or SubString, but this also might not matter). But keep it as you prefer.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Ok, I have changed it.

end
end
return nothing
Expand Down Expand Up @@ -272,7 +272,7 @@ julia> findnext("Lang", "JuliaLang", 2)
6:9
```
"""
findnext(t::AbstractString, s::AbstractString, i::Integer) = _search(s, t, i)
findnext(t::AbstractString, s::AbstractString, i::Integer) = _search(s, t, Int(i))

"""
findnext(ch::AbstractChar, string::AbstractString, start::Integer)
Expand Down Expand Up @@ -484,7 +484,7 @@ julia> findprev("Julia", "JuliaLang", 6)
1:5
```
"""
findprev(t::AbstractString, s::AbstractString, i::Integer) = _rsearch(s, t, i)
findprev(t::AbstractString, s::AbstractString, i::Integer) = _rsearch(s, t, Int(i))

"""
findprev(ch::AbstractChar, string::AbstractString, start::Integer)
Expand Down
26 changes: 26 additions & 0 deletions stdlib/SparseArrays/test/sparse.jl
Original file line number Diff line number Diff line change
Expand Up @@ -2333,6 +2333,32 @@ end
@test findnext(!iszero, z,i) == findnext(!iszero, z_sp,i)
@test findprev(!iszero, z,i) == findprev(!iszero, z_sp,i)
end

# issue 32568
@test findnext(!iszero, x_sp, big(4)) isa keytype(x_sp)
@test findnext(!iszero, x_sp, big(5)) isa keytype(x_sp)
@test findnext(!iszero, x_sp, UInt(4)) isa keytype(x_sp)
@test findnext(!iszero, x_sp, UInt(5)) isa keytype(x_sp)
@test findprev(!iszero, x_sp, big(5)) isa keytype(x_sp)
@test findprev(!iszero, x_sp, big(6)) isa keytype(x_sp)
@test findprev(!iszero, x_sp, UInt(5)) isa keytype(x_sp)
@test findprev(!iszero, x_sp, UInt(6)) isa keytype(x_sp)
@test findnext(iseven, x_sp, big(4)) isa keytype(x_sp)
@test findnext(iseven, x_sp, big(5)) isa keytype(x_sp)
@test findnext(iseven, x_sp, UInt(4)) isa keytype(x_sp)
@test findnext(iseven, x_sp, UInt(5)) isa keytype(x_sp)
@test findprev(iseven, x_sp, big(4)) isa keytype(x_sp)
@test findprev(iseven, x_sp, big(5)) isa keytype(x_sp)
@test findprev(iseven, x_sp, UInt(4)) isa keytype(x_sp)
@test findprev(iseven, x_sp, UInt(5)) isa keytype(x_sp)
@test findnext(!iszero, z_sp, big(4)) isa keytype(z_sp)
@test findnext(!iszero, z_sp, big(5)) isa keytype(z_sp)
@test findnext(!iszero, z_sp, UInt(4)) isa keytype(z_sp)
@test findnext(!iszero, z_sp, UInt(5)) isa keytype(z_sp)
@test findprev(!iszero, z_sp, big(4)) isa keytype(z_sp)
@test findprev(!iszero, z_sp, big(5)) isa keytype(z_sp)
@test findprev(!iszero, z_sp, UInt(4)) isa keytype(z_sp)
@test findprev(!iszero, z_sp, UInt(5)) isa keytype(z_sp)
end

# #20711
Expand Down
20 changes: 20 additions & 0 deletions test/arrayops.jl
Original file line number Diff line number Diff line change
Expand Up @@ -561,6 +561,26 @@ end
@test findlast(isequal(0x00), [0x01, 0x00]) == 2
@test findnext(isequal(0x00), [0x00, 0x01, 0x00], 2) == 3
@test findprev(isequal(0x00), [0x00, 0x01, 0x00], 2) == 1

@testset "issue 32568" begin
@test findnext(!iszero,a,big(1)) isa keytype(a)
@test findnext(!iszero,a,big(2)) isa keytype(a)
@test findnext(!iszero,a,UInt(1)) isa keytype(a)
@test findnext(!iszero,a,UInt(2)) isa keytype(a)
@test findprev(!iszero,a,big(4)) isa keytype(a)
@test findprev(!iszero,a,big(5)) isa keytype(a)
@test findprev(!iszero,a,UInt(4)) isa keytype(a)
@test findprev(!iszero,a,UInt(5)) isa keytype(a)
b = [true,false,true]
@test findnext(b,big(2)) isa keytype(b)
@test findnext(b,big(3)) isa keytype(b)
@test findnext(b,UInt(2)) isa keytype(b)
@test findnext(b,UInt(3)) isa keytype(b)
@test findprev(b,big(1)) isa keytype(b)
@test findprev(b,big(2)) isa keytype(b)
@test findprev(b,UInt(1)) isa keytype(b)
@test findprev(b,UInt(2)) isa keytype(b)
end
end
@testset "find with Matrix" begin
A = [1 2 0; 3 4 0]
Expand Down
35 changes: 35 additions & 0 deletions test/bitarray.jl
Original file line number Diff line number Diff line change
Expand Up @@ -1301,6 +1301,41 @@ timesofar("find")
@test_throws BoundsError findprev(x->true, b1, 11)
@test_throws BoundsError findnext(x->true, b1, -1)

@testset "issue 32568" begin
@test findnext(evens, big(1)) isa keytype(evens)
@test findnext(evens, big(2)) isa keytype(evens)
@test findnext(evens, UInt(1)) isa keytype(evens)
@test findnext(evens, UInt(2)) isa keytype(evens)
@test findprev(evens, big(3)) isa keytype(evens)
@test findprev(evens, big(4)) isa keytype(evens)
@test findprev(evens, UInt(3)) isa keytype(evens)
@test findprev(evens, UInt(4)) isa keytype(evens)
@test findnext(iseven, evens, big(1)) isa keytype(evens)
@test findnext(iseven, evens, big(2)) isa keytype(evens)
@test findnext(iseven, evens, UInt(1)) isa keytype(evens)
@test findnext(iseven, evens, UInt(2)) isa keytype(evens)
@test findprev(iseven, evens, big(3)) isa keytype(evens)
@test findprev(iseven, evens, big(4)) isa keytype(evens)
@test findprev(iseven, evens, UInt(3)) isa keytype(evens)
@test findprev(iseven, evens, UInt(4)) isa keytype(evens)
@test findnext(isequal(true), evens, big(1)) isa keytype(evens)
@test findnext(isequal(true), evens, big(2)) isa keytype(evens)
@test findnext(isequal(true), evens, UInt(1)) isa keytype(evens)
@test findnext(isequal(true), evens, UInt(2)) isa keytype(evens)
@test findprev(isequal(true), evens, big(3)) isa keytype(evens)
@test findprev(isequal(true), evens, big(4)) isa keytype(evens)
@test findprev(isequal(true), evens, UInt(3)) isa keytype(evens)
@test findprev(isequal(true), evens, UInt(4)) isa keytype(evens)
@test findnext(isequal(false), evens, big(1)) isa keytype(evens)
@test findnext(isequal(false), evens, big(2)) isa keytype(evens)
@test findnext(isequal(false), evens, UInt(1)) isa keytype(evens)
@test findnext(isequal(false), evens, UInt(2)) isa keytype(evens)
@test findprev(isequal(false), evens, big(3)) isa keytype(evens)
@test findprev(isequal(false), evens, big(4)) isa keytype(evens)
@test findprev(isequal(false), evens, UInt(3)) isa keytype(evens)
@test findprev(isequal(false), evens, UInt(4)) isa keytype(evens)
end

for l = [1, 63, 64, 65, 127, 128, 129]
f = falses(l)
t = trues(l)
Expand Down
13 changes: 13 additions & 0 deletions test/sorting.jl
Original file line number Diff line number Diff line change
Expand Up @@ -99,6 +99,19 @@ end
@test searchsortedlast(500:1.0:600, -1.0e20) == 0
@test searchsortedlast(500:1.0:600, 1.0e20) == 101
end

@testset "issue 32568" begin
for R in numTypes, T in numTypes
for arr in [R[1:5;], R(1):R(5), R(1):2:R(5)]
@test eltype(searchsorted(arr, T(2))) == keytype(arr)
@test eltype(searchsorted(arr, T(2), big(1), big(4), Forward)) == keytype(arr)
@test searchsortedfirst(arr, T(2)) isa keytype(arr)
@test searchsortedfirst(arr, T(2), big(1), big(4), Forward) isa keytype(arr)
@test searchsortedlast(arr, T(2)) isa keytype(arr)
@test searchsortedlast(arr, T(2), big(1), big(4), Forward) isa keytype(arr)
end
end
end
end
# exercise the codepath in searchsorted* methods for ranges that check for zero step range
struct ConstantRange{T} <: AbstractRange{T}
Expand Down
38 changes: 38 additions & 0 deletions test/strings/search.jl
Original file line number Diff line number Diff line change
Expand Up @@ -382,3 +382,41 @@ s_18109 = "fooα🐨βcd3"
@test findall("aa", "aaaaaa") == [1:2, 3:4, 5:6]
@test findall("aa", "aaaaaa", overlap=true) == [1:2, 2:3, 3:4, 4:5, 5:6]
end

# issue 32568
@test eltype(findnext(r"l", astr, big(4))) == Int
@test eltype(findnext(r"l", astr, big(5))) == Int
@test eltype(findnext(r"l", astr, UInt(4))) == Int
@test eltype(findnext(r"l", astr, UInt(5))) == Int
@test findnext(isequal('l'), astr, big(4)) isa Int
@test findnext(isequal('l'), astr, big(5)) isa Int
@test findnext(isequal('l'), astr, UInt(4)) isa Int
@test findnext(isequal('l'), astr, UInt(5)) isa Int
@test findprev(isequal('l'), astr, big(5)) isa Int
@test findprev(isequal('l'), astr, big(4)) isa Int
@test findprev(isequal('l'), astr, UInt(4)) isa Int
@test findprev(isequal('l'), astr, UInt(5)) isa Int
@test findnext('l', astr, big(4)) isa Int
@test findnext('l', astr, big(5)) isa Int
@test findnext('l', astr, UInt(4)) isa Int
@test findnext('l', astr, UInt(5)) isa Int
@test findprev('l', astr, big(4)) isa Int
@test findprev('l', astr, big(5)) isa Int
@test findprev('l', astr, UInt(4)) isa Int
@test findprev('l', astr, UInt(5)) isa Int
@test findnext(isletter, astr, big(7)) isa Int
@test findnext(isletter, astr, big(8)) isa Int
@test findnext(isletter, astr, UInt(7)) isa Int
@test findnext(isletter, astr, UInt(8)) isa Int
@test findprev(isletter, astr, big(7)) isa Int
@test findprev(isletter, astr, big(8)) isa Int
@test findprev(isletter, astr, UInt(7)) isa Int
@test findprev(isletter, astr, UInt(8)) isa Int
@test eltype(findnext(",b", "foo,bar,baz", big(7))) == Int
@test eltype(findnext(",b", "foo,bar,baz", big(8))) == Int
@test eltype(findnext(",b", "foo,bar,baz", UInt(7))) == Int
@test eltype(findnext(",b", "foo,bar,baz", UInt(8))) == Int
@test eltype(findprev(",b", "foo,bar,baz", big(5))) == Int
@test eltype(findprev(",b", "foo,bar,baz", big(6))) == Int
@test eltype(findprev(",b", "foo,bar,baz", UInt(5))) == Int
@test eltype(findprev(",b", "foo,bar,baz", UInt(6))) == Int
7 changes: 7 additions & 0 deletions test/tuple.jl
Original file line number Diff line number Diff line change
Expand Up @@ -439,6 +439,13 @@ end
@test findprev(isequal(1), (1, 1), 1) == 1
@test findnext(isequal(1), (2, 3), 1) === nothing
@test findprev(isequal(1), (2, 3), 2) === nothing

@testset "issue 32568" begin
@test findnext(isequal(1), (1, 2), big(1)) isa Int
@test findprev(isequal(1), (1, 2), big(2)) isa Int
@test findnext(isequal(1), (1, 1), UInt(2)) isa Int
@test findprev(isequal(1), (1, 1), UInt(1)) isa Int
end
end

@testset "properties" begin
Expand Down