Skip to content

Commit

Permalink
Add 2-arg versions of findmax/min, argmax/min
Browse files Browse the repository at this point in the history
Fixes JuliaLang#27613. Related: JuliaLang#27639, JuliaLang#27612, JuliaLang#34674.

Thanks to @tkf, @StefanKarpinski and @drewrobson for their assistance
with this PR.
  • Loading branch information
cmcaine committed Dec 17, 2020
1 parent b2f1111 commit 8767109
Show file tree
Hide file tree
Showing 4 changed files with 229 additions and 134 deletions.
1 change: 1 addition & 0 deletions NEWS.md
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@ New library functions
---------------------

* New function `isgreater(a, b)` defines a descending total order where unorderable values and missing are ordered smaller than any regular value.
* Two argument methods `findmax(f, domain)`, `argmax(f, domain)` and the corresponding `min` versions ([#27613]).

New library features
--------------------
Expand Down
134 changes: 0 additions & 134 deletions base/array.jl
Original file line number Diff line number Diff line change
Expand Up @@ -2201,140 +2201,6 @@ findall(x::Bool) = x ? [1] : Vector{Int}()
findall(testf::Function, x::Number) = testf(x) ? [1] : Vector{Int}()
findall(p::Fix2{typeof(in)}, x::Number) = x in p.x ? [1] : Vector{Int}()

"""
findmax(itr) -> (x, index)
Return the maximum element of the collection `itr` and its index or key.
If there are multiple maximal elements, then the first one will be returned.
If any data element is `NaN`, this element is returned.
The result is in line with `max`.
The collection must not be empty.
# Examples
```jldoctest
julia> findmax([8,0.1,-9,pi])
(8.0, 1)
julia> findmax([1,7,7,6])
(7, 2)
julia> findmax([1,7,7,NaN])
(NaN, 4)
```
"""
findmax(a) = _findmax(a, :)

function _findmax(a, ::Colon)
p = pairs(a)
y = iterate(p)
if y === nothing
throw(ArgumentError("collection must be non-empty"))
end
(mi, m), s = y
i = mi
while true
y = iterate(p, s)
y === nothing && break
m != m && break
(i, ai), s = y
if ai != ai || isless(m, ai)
m = ai
mi = i
end
end
return (m, mi)
end

"""
findmin(itr) -> (x, index)
Return the minimum element of the collection `itr` and its index or key.
If there are multiple minimal elements, then the first one will be returned.
If any data element is `NaN`, this element is returned.
The result is in line with `min`.
The collection must not be empty.
# Examples
```jldoctest
julia> findmin([8,0.1,-9,pi])
(-9.0, 3)
julia> findmin([7,1,1,6])
(1, 2)
julia> findmin([7,1,1,NaN])
(NaN, 4)
```
"""
findmin(a) = _findmin(a, :)

function _findmin(a, ::Colon)
p = pairs(a)
y = iterate(p)
if y === nothing
throw(ArgumentError("collection must be non-empty"))
end
(mi, m), s = y
i = mi
while true
y = iterate(p, s)
y === nothing && break
m != m && break
(i, ai), s = y
if ai != ai || isless(ai, m)
m = ai
mi = i
end
end
return (m, mi)
end

"""
argmax(itr)
Return the index or key of the maximum element in a collection.
If there are multiple maximal elements, then the first one will be returned.
The collection must not be empty.
# Examples
```jldoctest
julia> argmax([8,0.1,-9,pi])
1
julia> argmax([1,7,7,6])
2
julia> argmax([1,7,7,NaN])
4
```
"""
argmax(a) = findmax(a)[2]

"""
argmin(itr)
Return the index or key of the minimum element in a collection.
If there are multiple minimal elements, then the first one will be returned.
The collection must not be empty.
# Examples
```jldoctest
julia> argmin([8,0.1,-9,pi])
3
julia> argmin([7,1,1,6])
2
julia> argmin([7,1,1,NaN])
4
```
"""
argmin(a) = findmin(a)[2]

# similar to Matlab's ismember
"""
indexin(a, b)
Expand Down
196 changes: 196 additions & 0 deletions base/reduce.jl
Original file line number Diff line number Diff line change
Expand Up @@ -762,6 +762,202 @@ Inf
"""
minimum(a; kw...) = mapreduce(identity, min, a; kw...)

## findmax, findmin, argmax & argmin

"""
findmax(f, domain) -> (f(x), x)
findmax(f)
Returns a pair of a value in the codomain (outputs of `f`) and the corresponding
value in the `domain` (inputs to `f`) such that `f(x)` is maximised. If there
are multiple maximal points, then the first one will be returned.
When `domain` is provided it may be any iterable and must not be empty.
When `domain` is omitted, `f` must have an implicit domain. In particular, if
`f` is an indexable collection, it is interpreted as a function mapping keys
(domain) to values (codomain), i.e. `findmax(itr)` returns the maximal element
of the collection `itr` and its index.
Values are compared with `isless`.
# Examples
```jldoctest
julia> findmax(identity, 5:9)
(9, 9)
julia> findmax(-, 1:10)
(-1, 1)
julia> findmax(cos, 0:π/2:2π)
(1.0, 0.0)
julia> findmax([8,0.1,-9,pi])
(8.0, 1)
julia> findmax([1,7,7,6])
(7, 2)
julia> findmax([1,7,7,NaN])
(NaN, 4)
```
"""
findmax(f, domain) = mapfoldl(x -> (f(x), x), _rf_findmax, domain)
_rf_findmax((fm, m), (fx, x)) = isless(fm, fx) ? (fx, x) : (fm, m)

"""
findmin(f, domain) -> (f(x), x)
findmin(f)
Returns a pair of a value in the codomain (outputs of `f`) and the corresponding
value in the `domain` (inputs to `f`) such that `f(x)` is minimised. If there
are multiple minimal points, then the first one will be returned.
When `domain` is provided it may be any iterable and must not be empty.
When `domain` is omitted, `f` must have an implicit domain. In particular, if
`f` is an indexable collection, it is interpreted as a function mapping keys
(domain) to values (codomain), i.e. `findmin(itr)` returns the minimal element
of the collection `itr` and its index.
Values are compared with `isgreater`.
# Examples
```jldoctest
julia> findmin(identity, 5:9)
(5, 5)
julia> findmin(-, 1:10)
(-10, 10)
julia> findmin(cos, 0:π/2:2π)
(-1.0, 3.141592653589793)
julia> findmin([8,0.1,-9,pi])
(-9, 3)
julia> findmin([1,7,7,6])
(1, 1)
julia> findmin([1,7,7,NaN])
(NaN, 4)
```
"""
findmin(f, domain) = mapfoldl(x -> (f(x), x), _rf_findmin, domain)
_rf_findmin((fm, m), (fx, x)) = isgreater(fm, fx) ? (fx, x) : (fm, m)

findmax(a) = _findmax(a, :)

function _findmax(a, ::Colon)
p = pairs(a)
y = iterate(p)
if y === nothing
throw(ArgumentError("collection must be non-empty"))
end
(mi, m), s = y
i = mi
while true
y = iterate(p, s)
y === nothing && break
m != m && break
(i, ai), s = y
if ai != ai || isless(m, ai)
m = ai
mi = i
end
end
return (m, mi)
end

findmin(a) = _findmin(a, :)

function _findmin(a, ::Colon)
p = pairs(a)
y = iterate(p)
if y === nothing
throw(ArgumentError("collection must be non-empty"))
end
(mi, m), s = y
i = mi
while true
y = iterate(p, s)
y === nothing && break
m != m && break
(i, ai), s = y
if ai != ai || isless(ai, m)
m = ai
mi = i
end
end
return (m, mi)
end

"""
argmax(f, domain)
argmax(f)
Return a value `x` in the domain of `f` for which `f(x)` is maximised.
If there are multiple maximal values for `f(x)` then the first one will be found.
When `domain` is provided it may be any iterable and must not be empty.
When `domain` is omitted, `f` must have an implicit domain. In particular, if
`f` is an indexable collection, it is interpreted as a function mapping keys
(domain) to values (codomain), i.e. `argmax(itr)` returns the index of the
maximal element in `itr`.
Values are compared with `isless`.
# Examples
```jldoctest
julia> argmax([8,0.1,-9,pi])
1
julia> argmax([1,7,7,6])
2
julia> argmax([1,7,7,NaN])
4
```
"""
argmax(f, domain) = findmax(f, domain)[2]
argmax(f) = findmax(f)[2]

"""
argmin(f, domain)
argmin(f)
Return a value `x` in the domain of `f` for which `f(x)` is minimised.
If there are multiple minimal values for `f(x)` then the first one will be found.
When `domain` is provided it may be any iterable and must not be empty.
When `domain` is omitted, `f` must have an implicit domain. In particular, if
`f` is an indexable collection, it is interpreted as a function mapping keys
(domain) to values (codomain), i.e. `argmin(itr)` returns the index of the
minimal element in `itr`.
Values are compared with `isgreater`.
# Examples
```jldoctest
julia> argmin([8,0.1,-9,pi])
3
julia> argmin([7,1,1,6])
2
julia> argmin([7,1,1,NaN])
4
```
"""
argmin(f, domain) = findmin(f, domain)[2]
argmin(f) = findmin(f)[2]

## all & any

"""
Expand Down
32 changes: 32 additions & 0 deletions test/reduce.jl
Original file line number Diff line number Diff line change
Expand Up @@ -387,6 +387,38 @@ A = circshift(reshape(1:24,2,3,4), (0,1,1))
end
end

# findmin, findmax, argmin, argmax

@testset "findmin(f, domain)" begin
@test findmin(-, 1:10) == (-10, 10)
@test findmin(identity, [1, 2, 3, missing]) === (missing, missing)
@test findmin(identity, [1, NaN, 3, missing]) === (missing, missing)
@test findmin(identity, [1, missing, NaN, 3]) === (missing, missing)
@test findmin(identity, [1, NaN, 3]) === (NaN, NaN)
@test findmin(identity, [1, 3, NaN]) === (NaN, NaN)
@test all(findmin(cos, 0:π/2:2π) .≈ (-1.0, π))
end

@testset "findmax(f, domain)" begin
@test findmax(-, 1:10) == (-1, 1)
@test findmax(identity, [1, 2, 3, missing]) === (missing, missing)
@test findmax(identity, [1, NaN, 3, missing]) === (missing, missing)
@test findmax(identity, [1, missing, NaN, 3]) === (missing, missing)
@test findmax(identity, [1, NaN, 3]) === (NaN, NaN)
@test findmax(identity, [1, 3, NaN]) === (NaN, NaN)
@test findmax(cos, 0:π/2:2π) == (1.0, 0.0)
end

@testset "argmin(f, domain)" begin
@test argmin(-, 1:10) == 10
@test argmin(sum, Iterators.product(1:5, 1:5)) == (1, 1)
end

@testset "argmax(f, domain)" begin
@test argmax(-, 1:10) == 1
@test argmax(sum, Iterators.product(1:5, 1:5)) == (5, 5)
end

# any & all

@test @inferred any([]) == false
Expand Down

0 comments on commit 8767109

Please sign in to comment.