Make searchsorted*/findnext/findprev return values of keytype #32978

Merged on Apr 28, 2020 (13 commits)
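
As a rough illustration of the behaviour the title describes (this is not part of the PR), the sketch below assumes a Julia release that already includes this change. The arrays `a` and `v` are made-up sample data. The start index or the search bounds are passed as `UInt` or `BigInt`, and the results are expected to come back as `keytype` of the collection, which is `Int` for a plain `Vector`.

```julia
# Hypothetical sample data, for illustration only.
a = [0, 3, 0, 7]
v = [1, 2, 3, 4, 5]

# The start position may be any Integer; the returned index has keytype(a).
i = findnext(!iszero, a, UInt(1))
j = findprev(!iszero, a, big(4))

# The five-argument searchsorted* methods accept non-Int bounds as well.
k = searchsortedfirst(v, 2, big(1), big(5), Base.Order.Forward)
r = searchsorted(v, 2, big(1), big(5), Base.Order.Forward)

println(typeof(i), " ", typeof(j), " ", typeof(k), " ", typeof(r))
```

Before the change, results like these could take on the type of the start index or bounds instead of the key type of the collection.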
8 changes: 4 additions & 4 deletions base/array.jl
@@ -1653,7 +1653,7 @@ CartesianIndex(2, 1)
"""
function findnext(A, start)
l = last(keys(A))
- i = start
+ i = oftype(l, start)
i > l && return nothing
while true
A[i] && return i
@@ -1735,7 +1735,7 @@ CartesianIndex(1, 1)
"""
function findnext(testf::Function, A, start)
l = last(keys(A))
- i = start
+ i = oftype(l, start)
i > l && return nothing
while true
testf(A[i]) && return i
@@ -1839,8 +1839,8 @@ CartesianIndex(2, 1)
```
"""
function findprev(A, start)
- i = start
f = first(keys(A))
+ i = oftype(f, start)
i < f && return nothing
while true
A[i] && return i
@@ -1930,8 +1930,8 @@ CartesianIndex(2, 1)
```
"""
function findprev(testf::Function, A, start)
- i = start
f = first(keys(A))
+ i = oftype(f, start)
i < f && return nothing
while true
testf(A[i]) && return i
8 changes: 5 additions & 3 deletions base/bitarray.jl
@@ -1374,7 +1374,7 @@ end

count(B::BitArray) = bitcount(B.chunks)

- function unsafe_bitfindnext(Bc::Vector{UInt64}, start::Integer)
+ function unsafe_bitfindnext(Bc::Vector{UInt64}, start::Int)
chunk_start = _div64(start-1)+1
within_chunk_start = _mod64(start-1)
mask = _msk64 << within_chunk_start
@@ -1397,13 +1397,14 @@ end
function findnext(B::BitArray, start::Integer)
start > 0 || throw(BoundsError(B, start))
start > length(B) && return nothing
- unsafe_bitfindnext(B.chunks, start)
+ unsafe_bitfindnext(B.chunks, Int(start))
Member

unsafe_bitfindnext (and prev, too) accepts a start::Integer. I was about to make a comment that we should just move this there and/or restrict its signature, but then I realized it's also used by BitSet. BitSet demands an Int64 result, whereas its use for BitArray demands an Int result. So this seems like the right way to go about it. 👍

Member

Maybe it changed, but my impression is that BitSet calls unsafe_bitfindnext by feeding it an Int, not an Int64, so I would also favor restricting unsafe_bitfindnext to Int input, as this is an internal function. But the PR is also fine as is!

Contributor Author

I have now restricted unsafe_bitfindnext to Int.

Member

OK, sorry for the nitpick, but please change unsafe_bitfindprev accordingly too :)

Contributor Author

Done :)

Member

That was fast!
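
To make the outcome of this thread concrete, here is a minimal sketch that uses only the public `findnext`/`findprev` methods for `BitArray`. It assumes a Julia release that includes this change; the bit vector `b` is made-up sample data. The internal helpers are restricted to `Int`, so the public methods are the place where a `UInt` or `BigInt` start position gets converted.

```julia
# Hypothetical sample data: bits 2 and 4 are set.
b = BitVector([false, true, false, true])

# Start positions given as UInt and BigInt. The public methods are expected
# to convert them to Int before scanning the underlying UInt64 chunks.
p = findnext(b, UInt(1))
q = findprev(b, big(3))

println(p, " ", typeof(p))
println(q, " ", typeof(q))
```

Converting once at the public entry point keeps the internal chunk arithmetic in a single integer type, which is the design the reviewers converge on above.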

end

#findfirst(B::BitArray) = findnext(B, 1) ## defined in array.jl

# aux function: same as findnext(~B, start), but performed without temporaries
function findnextnot(B::BitArray, start::Integer)
+ start = Int(start)
start > 0 || throw(BoundsError(B, start))
start > length(B) && return nothing

@@ -1480,10 +1481,11 @@ end
function findprev(B::BitArray, start::Integer)
start > 0 || return nothing
start > length(B) && throw(BoundsError(B, start))
- unsafe_bitfindprev(B.chunks, start)
+ unsafe_bitfindprev(B.chunks, Int(start))
end

function findprevnot(B::BitArray, start::Integer)
+ start = Int(start)
start > 0 || return nothing
start > length(B) && throw(BoundsError(B, start))

24 changes: 12 additions & 12 deletions base/sort.jl
@@ -10,7 +10,7 @@ using .Base: copymutable, LinearIndices, length, (:),
AbstractVector, @inbounds, AbstractRange, @eval, @inline, Vector, @noinline,
AbstractMatrix, AbstractUnitRange, isless, identity, eltype, >, <, <=, >=, |, +, -, *, !,
extrema, sub_with_overflow, add_with_overflow, oneunit, div, getindex, setindex!,
- length, resize!, fill, Missing, require_one_based_indexing
+ length, resize!, fill, Missing, require_one_based_indexing, keytype

using .Base: >>>, !==

@@ -174,7 +174,7 @@ midpoint(lo::Integer, hi::Integer) = midpoint(promote(lo, hi)...)

# index of the first value of vector a that is greater than or equal to x;
# returns length(v)+1 if x is greater than all values in v.
- function searchsortedfirst(v::AbstractVector, x, lo::T, hi::T, o::Ordering) where T<:Integer
+ function searchsortedfirst(v::AbstractVector, x, lo::T, hi::T, o::Ordering)::keytype(v) where T<:Integer
u = T(1)
lo = lo - u
hi = hi + u
@@ -191,7 +191,7 @@ end

# index of the last value of vector a that is less than or equal to x;
# returns 0 if x is less than all values of v.
- function searchsortedlast(v::AbstractVector, x, lo::T, hi::T, o::Ordering) where T<:Integer
+ function searchsortedlast(v::AbstractVector, x, lo::T, hi::T, o::Ordering)::keytype(v) where T<:Integer
u = T(1)
lo = lo - u
hi = hi + u
@@ -209,7 +209,7 @@ end
# returns the range of indices of v equal to x
# if v does not contain x, returns a 0-length range
# indicating the insertion point of x
- function searchsorted(v::AbstractVector, x, ilo::T, ihi::T, o::Ordering) where T<:Integer
+ function searchsorted(v::AbstractVector, x, ilo::T, ihi::T, o::Ordering)::UnitRange{keytype(v)} where T<:Integer
u = T(1)
lo = ilo - u
hi = ihi + u
@@ -228,7 +228,7 @@ function searchsorted(v::AbstractVector, x, ilo::T, ihi::T, o::Ordering) where T
return (lo + 1) : (hi - 1)
end

- function searchsortedlast(a::AbstractRange{<:Real}, x::Real, o::DirectOrdering)
+ function searchsortedlast(a::AbstractRange{<:Real}, x::Real, o::DirectOrdering)::keytype(a)
require_one_based_indexing(a)
if step(a) == 0
lt(o, x, first(a)) ? 0 : length(a)
@@ -238,7 +238,7 @@ function searchsortedlast(a::AbstractRange{<:Real}, x::Real, o::DirectOrdering)
end
end

- function searchsortedfirst(a::AbstractRange{<:Real}, x::Real, o::DirectOrdering)
+ function searchsortedfirst(a::AbstractRange{<:Real}, x::Real, o::DirectOrdering)::keytype(a)
require_one_based_indexing(a)
if step(a) == 0
lt(o, first(a), x) ? length(a) + 1 : 1
@@ -248,7 +248,7 @@ function searchsortedfirst(a::AbstractRange{<:Real}, x::Real, o::DirectOrdering)
end
end

- function searchsortedlast(a::AbstractRange{<:Integer}, x::Real, o::DirectOrdering)
+ function searchsortedlast(a::AbstractRange{<:Integer}, x::Real, o::DirectOrdering)::keytype(a)
require_one_based_indexing(a)
h = step(a)
if h == 0
@@ -270,7 +270,7 @@ function searchsortedlast(a::AbstractRange{<:Integer}, x::Real, o::DirectOrderin
end
end

- function searchsortedfirst(a::AbstractRange{<:Integer}, x::Real, o::DirectOrdering)
+ function searchsortedfirst(a::AbstractRange{<:Integer}, x::Real, o::DirectOrdering)::keytype(a)
require_one_based_indexing(a)
h = step(a)
if h == 0
@@ -285,14 +285,14 @@ function searchsortedfirst(a::AbstractRange{<:Integer}, x::Real, o::DirectOrderi
lastindex(a) + 1
else
if o isa ForwardOrdering
- -fld(floor(Integer, -x) + first(a), h) + 1
+ -fld(floor(Integer, -x) + Signed(first(a)), h) + 1
Member

So the Signed here is for the case where x is not an Integer? Just out of curiosity: without this conversion, is the type conversion attached to the function's return value (::keytype(a)) not enough?

Contributor Author (@sostock, Apr 24, 2020)

No, it is for the case when x is Signed and first(a) is Unsigned. For example, if floor(Integer, -x) == Int64(-5) and first(a) == UInt64(1), adding them yields the huge number 0xfffffffffffffffc (the negative Int64 gets promoted to UInt64). The result of the complete line is then a UInt64 that is too large to convert to keytype(a).

Member

Ah right, I was off, since Signed is not applied to x. And OK, the ::keytype(a) conversion fails then. Thanks.
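
The arithmetic from this thread can be reproduced in isolation. The sketch below is only an illustration: `neg` and `lo` stand in for `floor(Integer, -x)` and `first(a)` with the values from the author's example, and `as_int` is a hypothetical helper that mimics a return-type annotation. It is not code from the PR.

```julia
neg = Int64(-5)    # stands in for floor(Integer, -x)
lo  = UInt64(1)    # stands in for first(a) of an unsigned range

# Mixed Int64/UInt64 addition promotes to UInt64, so the negative value wraps.
wrapped = neg + lo

# Signed(lo) is an Int64, so the sum stays signed and small.
fixed = neg + Signed(lo)

# A return-type annotation converts the result and throws if it does not fit.
as_int(v)::Int = v    # hypothetical helper, for illustration only

threw = try
    as_int(wrapped)
    false
catch err
    err isa InexactError
end

println(repr(wrapped))
println(fixed)
println(threw)
```

This is why the return annotation alone is not sufficient here: the value has already wrapped around by the time the conversion to `keytype(a)` is attempted.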

else
- -fld(ceil(Integer, -x) + first(a), h) + 1
+ -fld(ceil(Integer, -x) + Signed(first(a)), h) + 1
end
end
end

- function searchsortedfirst(a::AbstractRange{<:Integer}, x::Unsigned, o::DirectOrdering)
+ function searchsortedfirst(a::AbstractRange{<:Integer}, x::Unsigned, o::DirectOrdering)::keytype(a)
require_one_based_indexing(a)
if lt(o, first(a), x)
if step(a) == 0
@@ -305,7 +305,7 @@ function searchsortedfirst(a::AbstractRange{<:Integer}, x::Unsigned, o::DirectOr
end
end

- function searchsortedlast(a::AbstractRange{<:Integer}, x::Unsigned, o::DirectOrdering)
+ function searchsortedlast(a::AbstractRange{<:Integer}, x::Unsigned, o::DirectOrdering)::keytype(a)
require_one_based_indexing(a)
if lt(o, x, first(a))
0
5 changes: 3 additions & 2 deletions base/strings/search.jl
@@ -125,6 +125,7 @@ findfirst(ch::AbstractChar, string::AbstractString) = findfirst(==(ch), string)

# AbstractString implementation of the generic findnext interface
function findnext(testf::Function, s::AbstractString, i::Integer)
+ i = Int(i)
z = ncodeunits(s) + 1
1 ≤ i ≤ z || throw(BoundsError(s, i))
@inbounds i == z || isvalid(s, i) || string_index_err(s, i)
@@ -272,7 +273,7 @@ julia> findnext("Lang", "JuliaLang", 2)
6:9
```
"""
- findnext(t::AbstractString, s::AbstractString, i::Integer) = _search(s, t, i)
+ findnext(t::AbstractString, s::AbstractString, i::Integer) = _search(s, t, Int(i))

"""
findnext(ch::AbstractChar, string::AbstractString, start::Integer)
@@ -484,7 +485,7 @@ julia> findprev("Julia", "JuliaLang", 6)
1:5
```
"""
- findprev(t::AbstractString, s::AbstractString, i::Integer) = _rsearch(s, t, i)
+ findprev(t::AbstractString, s::AbstractString, i::Integer) = _rsearch(s, t, Int(i))

"""
findprev(ch::AbstractChar, string::AbstractString, start::Integer)
16 changes: 16 additions & 0 deletions stdlib/SparseArrays/test/sparse.jl
@@ -2462,6 +2462,22 @@ end
@test findnext(!iszero, z,i) == findnext(!iszero, z_sp,i)
@test findprev(!iszero, z,i) == findprev(!iszero, z_sp,i)
end

+ # issue 32568
+ for T = (UInt, BigInt)
+ @test findnext(!iszero, x_sp, T(4)) isa keytype(x_sp)
+ @test findnext(!iszero, x_sp, T(5)) isa keytype(x_sp)
+ @test findprev(!iszero, x_sp, T(5)) isa keytype(x_sp)
+ @test findprev(!iszero, x_sp, T(6)) isa keytype(x_sp)
+ @test findnext(iseven, x_sp, T(4)) isa keytype(x_sp)
+ @test findnext(iseven, x_sp, T(5)) isa keytype(x_sp)
+ @test findprev(iseven, x_sp, T(4)) isa keytype(x_sp)
+ @test findprev(iseven, x_sp, T(5)) isa keytype(x_sp)
+ @test findnext(!iszero, z_sp, T(4)) isa keytype(z_sp)
+ @test findnext(!iszero, z_sp, T(5)) isa keytype(z_sp)
+ @test findprev(!iszero, z_sp, T(4)) isa keytype(z_sp)
+ @test findprev(!iszero, z_sp, T(5)) isa keytype(z_sp)
+ end
end

# #20711
12 changes: 12 additions & 0 deletions test/arrayops.jl
@@ -564,6 +564,18 @@ end
@test findlast(isequal(0x00), [0x01, 0x00]) == 2
@test findnext(isequal(0x00), [0x00, 0x01, 0x00], 2) == 3
@test findprev(isequal(0x00), [0x00, 0x01, 0x00], 2) == 1

+ @testset "issue 32568" for T = (UInt, BigInt)
+ @test findnext(!iszero, a, T(1)) isa keytype(a)
+ @test findnext(!iszero, a, T(2)) isa keytype(a)
+ @test findprev(!iszero, a, T(4)) isa keytype(a)
+ @test findprev(!iszero, a, T(5)) isa keytype(a)
+ b = [true, false, true]
+ @test findnext(b, T(2)) isa keytype(b)
+ @test findnext(b, T(3)) isa keytype(b)
+ @test findprev(b, T(1)) isa keytype(b)
+ @test findprev(b, T(2)) isa keytype(b)
+ end
end
@testset "find with Matrix" begin
A = [1 2 0; 3 4 0]
15 changes: 15 additions & 0 deletions test/bitarray.jl
@@ -1307,6 +1307,21 @@ timesofar("find")
@test_throws BoundsError findprev(x->true, b1, 11)
@test_throws BoundsError findnext(x->true, b1, -1)

+ @testset "issue 32568" for T = (UInt, BigInt)
+ for x = (1, 2)
+ @test findnext(evens, T(x)) isa keytype(evens)
+ @test findnext(iseven, evens, T(x)) isa keytype(evens)
+ @test findnext(isequal(true), evens, T(x)) isa keytype(evens)
+ @test findnext(isequal(false), evens, T(x)) isa keytype(evens)
+ end
+ for x = (3, 4)
+ @test findprev(evens, T(x)) isa keytype(evens)
+ @test findprev(iseven, evens, T(x)) isa keytype(evens)
+ @test findprev(isequal(true), evens, T(x)) isa keytype(evens)
+ @test findprev(isequal(false), evens, T(x)) isa keytype(evens)
+ end
+ end

for l = [1, 63, 64, 65, 127, 128, 129]
f = falses(l)
t = trues(l)
15 changes: 14 additions & 1 deletion test/sorting.jl
@@ -141,6 +141,20 @@ end
@test searchsortedlast(500:1.0:600, -1.0e20) == 0
@test searchsortedlast(500:1.0:600, 1.0e20) == 101
end

+ @testset "issue 32568" begin
+ for R in numTypes, T in numTypes
+ for arr in [R[1:5;], R(1):R(5), R(1):2:R(5)]
+ @test eltype(searchsorted(arr, T(2))) == keytype(arr)
+ @test eltype(searchsorted(arr, T(2), big(1), big(4), Forward)) == keytype(arr)
+ @test searchsortedfirst(arr, T(2)) isa keytype(arr)
+ @test searchsortedfirst(arr, T(2), big(1), big(4), Forward) isa keytype(arr)
+ @test searchsortedlast(arr, T(2)) isa keytype(arr)
+ @test searchsortedlast(arr, T(2), big(1), big(4), Forward) isa keytype(arr)
+ end
+ end
+ end

@testset "issue #34157" begin
@test searchsorted(1:2.0, -Inf) === 1:0
@test searchsorted([1,2], -Inf) === 1:0
@@ -173,7 +187,6 @@ end
@test searchsortedlast(reverse(coll), -huge, rev=true) === lastindex(coll)
@test searchsorted(reverse(coll), huge, rev=true) === firstindex(coll):firstindex(coll) - 1
@test searchsorted(reverse(coll), -huge, rev=true) === lastindex(coll)+1:lastindex(coll)

end
end
end
19 changes: 19 additions & 0 deletions test/strings/search.jl
@@ -389,3 +389,22 @@ s_18109 = "fooα🐨βcd3"
@test findall("aa", "aaaaaa") == [1:2, 3:4, 5:6]
@test findall("aa", "aaaaaa", overlap=true) == [1:2, 2:3, 3:4, 4:5, 5:6]
end

+ # issue 32568
+ for T = (UInt, BigInt)
+ for x = (4, 5)
+ @test eltype(findnext(r"l", astr, T(x))) == Int
+ @test findnext(isequal('l'), astr, T(x)) isa Int
+ @test findprev(isequal('l'), astr, T(x)) isa Int
+ @test findnext('l', astr, T(x)) isa Int
+ @test findprev('l', astr, T(x)) isa Int
+ end
+ for x = (5, 6)
+ @test eltype(findprev(",b", "foo,bar,baz", T(x))) == Int
+ end
+ for x = (7, 8)
+ @test eltype(findnext(",b", "foo,bar,baz", T(x))) == Int
+ @test findnext(isletter, astr, T(x)) isa Int
+ @test findprev(isletter, astr, T(x)) isa Int
+ end
+ end
7 changes: 7 additions & 0 deletions test/tuple.jl
@@ -479,6 +479,13 @@ end
@test findprev(isequal(1), (1, 1), 1) == 1
@test findnext(isequal(1), (2, 3), 1) === nothing
@test findprev(isequal(1), (2, 3), 2) === nothing

+ @testset "issue 32568" begin
+ @test findnext(isequal(1), (1, 2), big(1)) isa Int
+ @test findprev(isequal(1), (1, 2), big(2)) isa Int
+ @test findnext(isequal(1), (1, 1), UInt(2)) isa Int
+ @test findprev(isequal(1), (1, 1), UInt(1)) isa Int
+ end
end

@testset "properties" begin