Skip to content

Commit

Permalink
Add CartesianIndices and LinearIndices
Browse files Browse the repository at this point in the history
Needed in particular to get linear indices once find() returns cartesian indices.
  • Loading branch information
nalimilan committed Jan 6, 2018
1 parent 1f40dd0 commit 8768c35
Show file tree
Hide file tree
Showing 3 changed files with 148 additions and 0 deletions.
5 changes: 5 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -222,6 +222,10 @@ Currently, the `@compat` macro supports the following syntaxes:
* `replace` accepts a pair of pattern and replacement, with the number of replacements as
a keyword argument ([#25165]).

* `CartesianIndices` and `LinearIndices` types represent cartesian and linear indices of
an array (respectively), and indexing such objects allows translating from one kind of index
to the other ([#25113]).

## Renaming


Expand Down Expand Up @@ -431,3 +435,4 @@ includes this fix. Find the minimum version from there.
[#25162]: https://github.com/JuliaLang/julia/issues/25162
[#25165]: https://github.com/JuliaLang/julia/issues/25165
[#25168]: https://github.com/JuliaLang/julia/issues/25168
[#25113]: https://github.com/JuliaLang/julia/issues/25113
126 changes: 126 additions & 0 deletions src/Compat.jl
Original file line number Diff line number Diff line change
Expand Up @@ -1070,6 +1070,132 @@ end
replace(s, first(pat_rep), last(pat_rep), count)
end

@static if VERSION < v"0.7.0-DEV.3025"
import Base: convert, ndims, getindex, size, length, eltype,
start, next, done, first, last
export CartesianIndices, LinearIndices

struct CartesianIndices{N,R<:NTuple{N,AbstractUnitRange{Int}}} <: AbstractArray{CartesianIndex{N},N}
indices::R
end

CartesianIndices(::Tuple{}) = CartesianIndices{0,typeof(())}(())
CartesianIndices(inds::NTuple{N,AbstractUnitRange{Int}}) where {N} =
CartesianIndices{N,typeof(inds)}(inds)
CartesianIndices(inds::Vararg{AbstractUnitRange{Int},N}) where {N} =
CartesianIndices(inds)
CartesianIndices(inds::NTuple{N,AbstractUnitRange{<:Integer}}) where {N} =
CartesianIndices(map(r->convert(AbstractUnitRange{Int}, r), inds))
CartesianIndices(inds::Vararg{AbstractUnitRange{<:Integer},N}) where {N} =
CartesianIndices(inds)

CartesianIndices(index::CartesianIndex) = CartesianIndices(index.I)
CartesianIndices(sz::NTuple{N,<:Integer}) where {N} = CartesianIndices(map(Base.OneTo, sz))
CartesianIndices(inds::NTuple{N,Union{<:Integer,AbstractUnitRange{<:Integer}}}) where {N} =
CartesianIndices(map(i->first(i):last(i), inds))

CartesianIndices(A::AbstractArray) = CartesianIndices(axes(A))

convert(::Type{Tuple{}}, R::CartesianIndices{0}) = ()
convert(::Type{NTuple{N,AbstractUnitRange{Int}}}, R::CartesianIndices{N}) where {N} =
R.indices

convert(::Type{NTuple{N,AbstractUnitRange}}, R::CartesianIndices{N}) where {N} =
convert(NTuple{N,AbstractUnitRange{Int}}, R)
convert(::Type{NTuple{N,UnitRange{Int}}}, R::CartesianIndices{N}) where {N} =
UnitRange{Int}.(convert(NTuple{N,AbstractUnitRange}, R))
convert(::Type{NTuple{N,UnitRange}}, R::CartesianIndices{N}) where {N} =
UnitRange.(convert(NTuple{N,AbstractUnitRange}, R))
convert(::Type{Tuple{Vararg{AbstractUnitRange{Int}}}}, R::CartesianIndices{N}) where {N} =
convert(NTuple{N,AbstractUnitRange{Int}}, R)
convert(::Type{Tuple{Vararg{AbstractUnitRange}}}, R::CartesianIndices) =
convert(Tuple{Vararg{AbstractUnitRange{Int}}}, R)
convert(::Type{Tuple{Vararg{UnitRange{Int}}}}, R::CartesianIndices{N}) where {N} =
convert(NTuple{N,UnitRange{Int}}, R)
convert(::Type{Tuple{Vararg{UnitRange}}}, R::CartesianIndices) =
convert(Tuple{Vararg{UnitRange{Int}}}, R)

# AbstractArray implementation
Base.IndexStyle(::Type{CartesianIndices{N,R}}) where {N,R} = IndexCartesian()
@inline Base.getindex(iter::CartesianIndices{N,R}, I::Vararg{Int, N}) where {N,R} = CartesianIndex(first.(iter.indices) .- 1 .+ I)

ndims(R::CartesianIndices) = ndims(typeof(R))
ndims(::Type{CartesianIndices{N}}) where {N} = N
ndims(::Type{CartesianIndices{N,TT}}) where {N,TT} = N

eltype(R::CartesianIndices) = eltype(typeof(R))
eltype(::Type{CartesianIndices{N}}) where {N} = CartesianIndex{N}
eltype(::Type{CartesianIndices{N,TT}}) where {N,TT} = CartesianIndex{N}
Base.iteratorsize(::Type{<:CartesianIndices}) = Base.HasShape()

@inline function start(iter::CartesianIndices)
iterfirst, iterlast = first(iter), last(iter)
if any(map(>, iterfirst.I, iterlast.I))
return iterlast+1
end
iterfirst
end
@inline function next(iter::CartesianIndices, state)
state, CartesianIndex(inc(state.I, first(iter).I, last(iter).I))
end
# increment & carry
@inline inc(::Tuple{}, ::Tuple{}, ::Tuple{}) = ()
@inline inc(state::Tuple{Int}, start::Tuple{Int}, stop::Tuple{Int}) = (state[1]+1,)
@inline function inc(state, start, stop)
if state[1] < stop[1]
return (state[1]+1,Base.tail(state)...)
end
newtail = inc(Base.tail(state), Base.tail(start), Base.tail(stop))
(start[1], newtail...)
end
@inline done(iter::CartesianIndices, state) = state.I[end] > last(iter.indices[end])

# 0-d cartesian ranges are special-cased to iterate once and only once
start(iter::CartesianIndices{0}) = false
next(iter::CartesianIndices{0}, state) = CartesianIndex(), true
done(iter::CartesianIndices{0}, state) = state

size(iter::CartesianIndices) = map(dimlength, first(iter).I, last(iter).I)
dimlength(start, stop) = stop-start+1

length(iter::CartesianIndices) = prod(size(iter))

first(iter::CartesianIndices) = CartesianIndex(map(first, iter.indices))
last(iter::CartesianIndices) = CartesianIndex(map(last, iter.indices))

@inline function in(i::CartesianIndex{N}, r::CartesianIndices{N}) where {N}
_in(true, i.I, first(r).I, last(r).I)
end
_in(b, ::Tuple{}, ::Tuple{}, ::Tuple{}) = b
@inline _in(b, i, start, stop) = _in(b & (start[1] <= i[1] <= stop[1]), tail(i), tail(start), tail(stop))

struct LinearIndices{N,R<:NTuple{N,AbstractUnitRange{Int}}} <: AbstractArray{Int,N}
indices::R
end

LinearIndices(inds::CartesianIndices{N,R}) where {N,R} = LinearIndices{N,R}(inds.indices)
LinearIndices(::Tuple{}) = LinearIndices(CartesianIndices(()))
LinearIndices(inds::NTuple{N,AbstractUnitRange{Int}}) where {N} = LinearIndices(CartesianIndices(inds))
LinearIndices(inds::Vararg{AbstractUnitRange{Int},N}) where {N} = LinearIndices(CartesianIndices(inds))
LinearIndices(inds::NTuple{N,AbstractUnitRange{<:Integer}}) where {N} = LinearIndices(CartesianIndices(inds))
LinearIndices(inds::Vararg{AbstractUnitRange{<:Integer},N}) where {N} = LinearIndices(CartesianIndices(inds))
LinearIndices(index::CartesianIndex) = LinearIndices(CartesianIndices(index))
LinearIndices(sz::NTuple{N,<:Integer}) where {N} = LinearIndices(CartesianIndices(sz))
LinearIndices(inds::NTuple{N,Union{<:Integer,AbstractUnitRange{<:Integer}}}) where {N} = LinearIndices(CartesianIndices(inds))
LinearIndices(A::AbstractArray) = LinearIndices(CartesianIndices(A))

# AbstractArray implementation
Base.IndexStyle(::Type{LinearIndices{N,R}}) where {N,R} = IndexCartesian()
Compat.axes(iter::LinearIndices{N,R}) where {N,R} = iter.indices
@inline function Base.getindex(iter::LinearIndices{N,R}, I::Vararg{Int, N}) where {N,R}
dims = length.(iter.indices)
#without the inbounds, this is slower than Base._sub2ind(iter.indices, I...)
@inbounds result = reshape(1:prod(dims), dims)[(I .- first.(iter.indices) .+ 1)...]
return result
end
end


include("deprecated.jl")

end # module Compat
17 changes: 17 additions & 0 deletions test/runtests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -1076,6 +1076,23 @@ end
@test replace("abcb", "b"=>"c") == "accc"
@test replace("abcb", "b"=>"c", count=1) == "accb"

# 0.7.0-DEV.3025
let c = CartesianIndices(1:3, 1:2), l = LinearIndices(1:3, 1:2)
@test LinearIndices(c) == collect(l)
@test CartesianIndices(l) == collect(c)
@test first(c) == CartesianIndex(1, 1)
@test CartesianIndex(1, 1) in c
@test first(l) == 1
@test size(c) == (3, 2)
@test c == collect(c) == [CartesianIndex(1, 1) CartesianIndex(1, 2)
CartesianIndex(2, 1) CartesianIndex(2, 2)
CartesianIndex(3, 1) CartesianIndex(3, 2)]
@test l == collect(l) == reshape(1:6, 3, 2)
@test c[1:6] == vec(c)
@test l == l[c] == map(i -> l[i], c)
@test l[vec(c)] == collect(1:6)
end

if VERSION < v"0.6.0"
include("deprecated.jl")
end
Expand Down

0 comments on commit 8768c35

Please sign in to comment.