Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Some work on supporting non-standard indexing #44

Merged
merged 8 commits into from
Feb 12, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion Project.toml
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
name = "HybridArrays"
uuid = "1baab800-613f-4b0a-84e4-9cd3431bfbb9"
authors = ["Mateusz Baran <mateuszbaran89@gmail.com>"]
version = "0.4.6"
version = "0.4.7"

[deps]
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
Expand Down
26 changes: 21 additions & 5 deletions src/HybridArrays.jl
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@ import Base: convert,
import Base.Array

using StaticArrays
using StaticArrays: Dynamic
using StaticArrays: Dynamic, StaticIndexing
import StaticArrays: _setindex!_scalar, Size

using LinearAlgebra
Expand Down Expand Up @@ -55,11 +55,10 @@ end

destatizing = (inds[i] <: AbstractArray && !(
inds[i] <: StaticArray ||
inds[i] <: Base.Slice{<:StaticArray} ||
inds[i] <: SOneTo ||
inds[i] <: Base.Slice{<:SOneTo}))
inds[i] <: Base.Slice ||
inds[i] <: SOneTo))

nonstatizing = inds[i] == Colon || destatizing
nonstatizing = inds[i] == Colon || inds[i] <: Base.Slice || destatizing

if destatizing || (isa(param, Dynamic) && nonstatizing)
all_fixed = false
Expand All @@ -75,6 +74,23 @@ end
end
end

function has_dynamic(::Type{Size}) where Size<:Tuple
for param in Size.parameters
if isa(param, Dynamic)
return true
end
end
return false
end

@generated function all_dynamic_fixed_val(::Type{Size}, inds::Union{Colon, Base.Slice}) where Size<:Tuple
if has_dynamic(Size)
return Val(:dynamic_fixed_false)
else
return Val(:dynamic_fixed_true)
end
end

@generated function tuple_nodynamic_prod(::Type{Size}) where Size<:Tuple
i = 1
for s ∈ Size.parameters
Expand Down
63 changes: 44 additions & 19 deletions src/indexing.jl
Original file line number Diff line number Diff line change
@@ -1,7 +1,3 @@
@inline function getindex(sa::HybridArray{S}, ::Colon) where S
return HybridArray{S}(getindex(sa.data, :))
end

Base.@propagate_inbounds function getindex(sa::HybridArray{S}, inds::Int...) where S
return getindex(sa.data, inds...)
end
Expand All @@ -10,39 +6,52 @@ Base.@propagate_inbounds function getindex(sa::HybridArray{S}, inds::Union{Int,
_getindex(all_dynamic_fixed_val(S, inds...), sa, inds...)
end

Base.@propagate_inbounds function _getindex(::Val{:dynamic_fixed_true}, sa::HybridArray, inds::Union{Int, StaticArray{<:Tuple, Int}, Colon}...)
# This plugs into a deeper level of indexing in base to catch custom
# indexing schemes based on `to_indices`.
# A minor version Julia release could potentially break this (though it seems unlikely).
Base.@propagate_inbounds function Base._getindex(l::IndexLinear, sa::HybridArray{S}, inds::Int...) where S
return Base._getindex(l, sa.data, inds...)
end
Base.@propagate_inbounds function Base._getindex(::IndexLinear, sa::HybridArray{S}, inds::Union{Int, StaticVector, Colon, Base.Slice}...) where S
_getindex(all_dynamic_fixed_val(S, inds...), sa, inds...)
end

Base.@propagate_inbounds function _getindex(::Val{:dynamic_fixed_true}, sa::HybridArray, inds::Union{Int, StaticVector, Colon, Base.Slice}...)
return _getindex_all_static(sa, inds...)
end

function _get_indices(i::Tuple{}, j::Int)
return ()
end

function _get_indices(i::Tuple, j::Int, i1::Type{Int}, inds...)
function _get_indices(i::Tuple, j::Int, ::Type{Int}, inds...)
return (:(inds[$j]), _get_indices(i, j+1, inds...)...)
end

function _get_indices(i::Tuple, j::Int, i1::Type{T}, inds...) where T<:StaticArray{<:Tuple, Int}
function _get_indices(i::Tuple, j::Int, ::Type{T}, inds...) where T<:StaticVector
return (:(inds[$j][$(i[1])]), _get_indices(i[2:end], j+1, inds...)...)
end

function _get_indices(i::Tuple, j::Int, i1::Type{Colon}, inds...)
function _get_indices(i::Tuple, j::Int, ::Type{<:Union{Colon, Base.Slice}}, inds...)
return (i[1], _get_indices(i[2:end], j+1, inds...)...)
end

_totally_linear() = true
_totally_linear(inds...) = false
_totally_linear(inds::Type{Int}...) = true
_totally_linear(inds::Type{<:Base.Slice}...) = true
_totally_linear(inds::Type{Colon}...) = true
_totally_linear(i1::Type{Colon}, inds...) = _totally_linear(inds...)
_totally_linear(::Type{<:Base.Slice}, inds...) = _totally_linear(inds...)
_totally_linear(::Type{Colon}, inds...) = _totally_linear(inds...)

function new_out_size_nongen(::Type{Size}, inds...) where Size
os = []
@assert length(Size.parameters) == length(inds)
map(Size.parameters, inds) do s, i
if i == Int
elseif i <: StaticVector
push!(os, length(i))
elseif i == Colon
elseif i == Colon || i <: Base.Slice
push!(os, s)
else
error("Unknown index type: $i")
Expand All @@ -51,6 +60,14 @@ function new_out_size_nongen(::Type{Size}, inds...) where Size
return tuple(os...)
end

function new_out_size_nongen(::Type{Size}, i::Type{<:Union{Colon, Base.Slice}}) where Size
if has_dynamic(Size)
return (Dynamic(),)
else
return (tuple_nodynamic_prod(Size),)
end
end

"""
_get_linear_inds(S, inds...)

Expand Down Expand Up @@ -108,7 +125,7 @@ function _get_linear_inds(S, inds...)
end
end

@generated function _getindex_all_static(sa::HybridArray{S,T}, inds::Union{Int, StaticArray{<:Tuple, Int}, Colon}...) where {S,T}
@generated function _getindex_all_static(sa::HybridArray{S,T}, inds::Union{Int, StaticIndexing, Base.Slice, Colon, StaticArray}...) where {S,T}
newsize = new_out_size_nongen(S, inds...)
exprs = Vector{Expr}(undef, length(newsize))

Expand Down Expand Up @@ -138,17 +155,13 @@ end
end
end

function new_out_size(S::Type{Size}, inds::StaticArrays.StaticIndexing...) where Size
return new_out_size(S, map(StaticArrays.unwrap, inds)...)
end


# _get_static_vector_length is used in a generated function so using a generic function
# may not be a good idea
_get_static_vector_length(::Type{<:StaticVector{N}}) where {N} = N

@generated function new_out_size(::Type{Size}, inds...) where Size
os = []
@assert length(Size.parameters) === length(inds)
map(Size.parameters, inds) do s, i
if i == Int
elseif i <: StaticVector
Expand All @@ -166,9 +179,21 @@ _get_static_vector_length(::Type{<:StaticVector{N}}) where {N} = N
return Tuple{os...}
end

@inline function _getindex(::Val{:dynamic_fixed_false}, sa::HybridArray{S}, inds::Union{Int, StaticArray{<:Tuple, Int}, Colon}...) where S
newsize = new_out_size(S, inds...)
return HybridArray{newsize}(getindex(sa.data, inds...))
@generated function new_out_size(::Type{Size}, ::Union{Colon, Base.Slice}) where Size
if has_dynamic(Size)
return Tuple{Dynamic()}
else
return Tuple{tuple_nodynamic_prod(Size)}
end
end

maybe_unwrap(i) = i
maybe_unwrap(i::StaticIndexing) = i.ind

@inline function _getindex(::Val{:dynamic_fixed_false}, sa::HybridArray{S}, inds::Union{Int, StaticIndexing, StaticVector, Base.Slice, Colon}...) where S
uinds = map(maybe_unwrap, inds)
newsize = new_out_size(S, uinds...)
return HybridArray{newsize}(getindex(sa.data, uinds...))
end

# setindex stuff
Expand Down
80 changes: 41 additions & 39 deletions test/nonstandard_indices.jl
Original file line number Diff line number Diff line change
Expand Up @@ -6,60 +6,62 @@ using HybridArrays
using Test


@testset "Compatibility with EllipsisNotation" begin
u_array = rand(2, 10)
u_hybrid = HybridArray{Tuple{2, HybridArrays.Dynamic()}}(copy(u_array))
@test u_hybrid ≈ u_array
@test u_hybrid[1, ..] ≈ u_array[1, ..]
@test u_hybrid[.., 1] ≈ u_array[.., 1]
@test u_hybrid[..] ≈ u_array[..]

@test_broken typeof(u_hybrid[1, ..]) == typeof(u_hybrid[1, :])
@test_broken typeof(u_hybrid[.., 1]) == typeof(u_hybrid[:, 1])
@test_broken typeof(u_hybrid[..]) == typeof(u_hybrid[:])

@inferred u_hybrid[1, ..]
@inferred u_hybrid[.., 1]
@inferred u_hybrid[..]

let new_values = rand(10)
u_array[1, ..] .= new_values
u_hybrid[1, ..] .= new_values
if VERSION >= v"1.5"
@testset "Compatibility with EllipsisNotation" begin
u_array = rand(2, 10)
u_hybrid = HybridArray{Tuple{2, HybridArrays.Dynamic()}}(copy(u_array))
@test u_hybrid ≈ u_array
@test u_hybrid[1, ..] ≈ u_array[1, ..]
@test u_hybrid[.., 1] ≈ u_array[.., 1]
@test u_hybrid[..] ≈ u_array[..]
end

let new_values = rand(2)
u_array[.., 1] .= new_values
u_hybrid[.., 1] .= new_values
@test u_hybrid ≈ u_array
@test u_hybrid[1, ..] ≈ u_array[1, ..]
@test u_hybrid[.., 1] ≈ u_array[.., 1]
@test u_hybrid[..] ≈ u_array[..]
end

let new_values = rand(2, 10)
u_array .= new_values
u_hybrid .= new_values
@test u_hybrid ≈ u_array
@test u_hybrid[1, ..] ≈ u_array[1, ..]
@test u_hybrid[.., 1] ≈ u_array[.., 1]
@test u_hybrid[..] ≈ u_array[..]
@test typeof(u_hybrid[1, ..]) == typeof(u_hybrid[1, :])
@test typeof(u_hybrid[.., 1]) == typeof(u_hybrid[:, 1])
@test typeof(u_hybrid[..]) == typeof(u_hybrid[:])
@test typeof(u_hybrid[..,..]) == typeof(u_hybrid[:, :])

@inferred u_hybrid[1, ..]
@inferred u_hybrid[.., 1]
@inferred u_hybrid[..]

let new_values = rand(10)
u_array[1, ..] .= new_values
u_hybrid[1, ..] .= new_values
@test u_hybrid ≈ u_array
@test u_hybrid[1, ..] ≈ u_array[1, ..]
@test u_hybrid[.., 1] ≈ u_array[.., 1]
@test u_hybrid[..] ≈ u_array[..]
end

let new_values = rand(2)
u_array[.., 1] .= new_values
u_hybrid[.., 1] .= new_values
@test u_hybrid ≈ u_array
@test u_hybrid[1, ..] ≈ u_array[1, ..]
@test u_hybrid[.., 1] ≈ u_array[.., 1]
@test u_hybrid[..] ≈ u_array[..]
end

let new_values = rand(2, 10)
u_array .= new_values
u_hybrid .= new_values
@test u_hybrid ≈ u_array
@test u_hybrid[1, ..] ≈ u_array[1, ..]
@test u_hybrid[.., 1] ≈ u_array[.., 1]
@test u_hybrid[..] ≈ u_array[..]
end
end
end


@testset "Compatibility with Cartesian indices" begin
u_array = rand(2, 3, 4)
u_hybrid = HybridArray{Tuple{2, 3, 4}}(copy(u_array))
@test u_hybrid ≈ u_array
@test u_hybrid[CartesianIndex(1, 2), :] ≈ u_array[CartesianIndex(1, 2), :]
@test u_hybrid[:, CartesianIndex(1, 2)] ≈ u_array[:, CartesianIndex(1, 2)]

@test_broken typeof(u_hybrid[CartesianIndex(1, 2), :]) == typeof(u_hybrid[1, 2, :])
@test_broken typeof(u_hybrid[:, CartesianIndex(1, 2)]) == typeof(u_hybrid[:, 1, 2])
@test typeof(u_hybrid[CartesianIndex(1, 2), :]) == typeof(u_hybrid[1, 2, :])
@test typeof(u_hybrid[:, CartesianIndex(1, 2)]) == typeof(u_hybrid[:, 1, 2])

@inferred u_hybrid[CartesianIndex(1, 2), :]
@inferred u_hybrid[:, CartesianIndex(1, 2)]
Expand Down