Skip to content

Commit

Permalink
Merge branch 'master' into od/ambiguity
Browse files Browse the repository at this point in the history
  • Loading branch information
odow committed Jan 5, 2023
2 parents 50d134a + 9f61e95 commit b6d001b
Show file tree
Hide file tree
Showing 2 changed files with 209 additions and 0 deletions.
86 changes: 86 additions & 0 deletions src/Containers/DenseAxisArray.jl
Original file line number Diff line number Diff line change
Expand Up @@ -361,6 +361,17 @@ function Base.IndexStyle(::Type{DenseAxisArray{T,N,Ax}}) where {T,N,Ax}
return IndexAnyCartesian()
end

function Base.setindex!(
A::DenseAxisArray{T,N},
value::DenseAxisArray{T,N},
args...,
) where {T,N}
for key in Base.product(args...)
A[key...] = value[key...]
end
return A
end

########
# Keys #
########
Expand All @@ -376,6 +387,10 @@ end
Base.getindex(k::DenseAxisArrayKey, args...) = getindex(k.I, args...)
Base.getindex(a::DenseAxisArray, k::DenseAxisArrayKey) = a[k.I...]

function Base.setindex!(A::DenseAxisArray, value, key::DenseAxisArrayKey)
return setindex!(A, value, key.I...)
end

struct DenseAxisArrayKeys{T<:Tuple,S<:DenseAxisArrayKey,N} <: AbstractArray{S,N}
product_iter::Base.Iterators.ProductIterator{T}
function DenseAxisArrayKeys(a::DenseAxisArray{TT,N,Ax}) where {TT,N,Ax}
Expand Down Expand Up @@ -572,3 +587,74 @@ end
# but some users may depend on it's functionality so we have a work-around
# instead of just breaking code.
Base.repeat(x::DenseAxisArray; kwargs...) = repeat(x.data; kwargs...)

###
### view
###

_get_subaxis(::Colon, b) = b

function _get_subaxis(a::AbstractVector, b)
for ai in a
if !(ai in b)
throw(KeyError(ai))
end
end
return a
end

function _get_subaxis(a::T, b::AbstractVector{T}) where {T}
if !(a in b)
throw(KeyError(a))
end
return a
end

struct DenseAxisArrayView{T,N,D,A} <: AbstractArray{T,N}
data::D
axes::A
function DenseAxisArrayView(
x::Containers.DenseAxisArray{T,N},
args...,
) where {T,N}
axis = _get_subaxis.(args, axes(x))
return new{T,N,typeof(x),typeof(axis)}(x, axis)
end
end

function Base.view(A::Containers.DenseAxisArray, args...)
return DenseAxisArrayView(A, args...)
end

Base.size(x::DenseAxisArrayView) = length.(x.axes)

Base.axes(x::DenseAxisArrayView) = x.axes

function Base.getindex(x::DenseAxisArrayView, args...)
y = _get_subaxis.(args, x.axes)
return getindex(x.data, y...)
end

Base.getindex(a::DenseAxisArrayView, k::DenseAxisArrayKey) = a[k.I...]

function Base.setindex!(x::DenseAxisArrayView, args...)
return setindex!(x.data, args...)
end

function Base.eachindex(A::DenseAxisArrayView)
# Return a generator so that we lazily evaluate the product instead of
# collecting into a vector.
#
# In future, we might want to return the appropriate matrix of
# `CartesianIndex` to avoid having to do the lookups with
# `DenseAxisArrayKey`.
return (DenseAxisArrayKey(k) for k in Base.product(A.axes...))
end

Base.show(io::IO, x::DenseAxisArrayView) = print(io, x.data)

Base.print_array(io::IO, x::DenseAxisArrayView) = show(io, x)

function Base.summary(io::IO, x::DenseAxisArrayView)
return print(io, "view(::DenseAxisArray, ", join(x.axes, ", "), "), over")
end
123 changes: 123 additions & 0 deletions test/Containers/test_DenseAxisArray.jl
Original file line number Diff line number Diff line change
Expand Up @@ -490,4 +490,127 @@ function test_ambiguity_isassigned()
return
end

function test_containers_denseaxisarray_setindex_vector()
A = Containers.DenseAxisArray(zeros(3), 1:3)
A[2:3] .= 1.0
@test A.data == [0.0, 1.0, 1.0]
A = Containers.DenseAxisArray(zeros(3), 1:3)
A[[2, 3]] .= 1.0
@test A.data == [0.0, 1.0, 1.0]
A = Containers.DenseAxisArray(zeros(3), 1:3)
A[[1, 3]] .= 1.0
@test A.data == [1.0, 0.0, 1.0]
A = Containers.DenseAxisArray(zeros(3), 1:3)
A[[2]] .= 1.0
@test A.data == [0.0, 1.0, 0.0]
A[2:3] = Containers.DenseAxisArray([2.0, 3.0], 2:3)
@test A.data == [0.0, 2.0, 3.0]
A = Containers.DenseAxisArray(zeros(3), 1:3)
A[:] .= 1.0
@test A.data == [1.0, 1.0, 1.0]
return
end

function test_containers_denseaxisarray_setindex_matrix()
A = Containers.DenseAxisArray(zeros(3, 3), 1:3, [:a, :b, :c])
A[:, [:a, :b]] .= 1.0
@test A.data == [1.0 1.0 0.0; 1.0 1.0 0.0; 1.0 1.0 0.0]
A = Containers.DenseAxisArray(zeros(3, 3), 1:3, [:a, :b, :c])
A[2:3, [:a, :b]] .= 1.0
@test A.data == [0.0 0.0 0.0; 1.0 1.0 0.0; 1.0 1.0 0.0]
A = Containers.DenseAxisArray(zeros(3, 3), 1:3, [:a, :b, :c])
A[3:3, [:a, :b]] .= 1.0
@test A.data == [0.0 0.0 0.0; 0.0 0.0 0.0; 1.0 1.0 0.0]
A = Containers.DenseAxisArray(zeros(3, 3), 1:3, [:a, :b, :c])
A[[1, 3], [:a, :b]] .= 1.0
@test A.data == [1.0 1.0 0.0; 0.0 0.0 0.0; 1.0 1.0 0.0]
A = Containers.DenseAxisArray(zeros(3, 3), 1:3, [:a, :b, :c])
A[[1, 3], [:a, :c]] .= 1.0
@test A.data == [1.0 0.0 1.0; 0.0 0.0 0.0; 1.0 0.0 1.0]
return
end

function test_containers_denseaxisarray_view()
A = Containers.DenseAxisArray(zeros(3, 3), 1:3, [:a, :b, :c])
B = view(A, :, [:a, :b])
@test_throws KeyError view(A, :, [:d])
@test size(B) == (3, 2)
@test B[1, :a] == A[1, :a]
@test B[3, :a] == A[3, :a]
@test_throws KeyError B[3, :c]
@test sprint(show, B) == sprint(show, B.data)
@test sprint(Base.print_array, B) == sprint(show, B.data)
@test sprint(Base.summary, B) ==
"view(::DenseAxisArray, 1:3, [:a, :b]), over"
return
end

function test_containers_denseaxisarray_jump_3151()
D = Containers.DenseAxisArray(zeros(3), [:a, :b, :c])
E = Containers.DenseAxisArray(ones(3), [:a, :b, :c])
I = [:a, :b]
D[I] = E[I]
@test D.data == [1.0, 1.0, 0.0]
D = Containers.DenseAxisArray(zeros(3), [:a, :b, :c])
I = [:b, :c]
D[I] = E[I]
@test D.data == [0.0, 1.0, 1.0]
D = Containers.DenseAxisArray(zeros(3), [:a, :b, :c])
I = [:a, :c]
D[I] = E[I]
@test D.data == [1.0, 0.0, 1.0]
return
end

function test_containers_denseaxisarray_view_operations()
c = Containers.@container([i = 1:4, j = 2:3], i + 2 * j)
d = view(c, 2:3, :)
@test sum(c) == 60
@test sum(d) == 30
d .= 1
@test sum(d) == 4
@test sum(c) == 34
return
end

function test_containers_denseaxisarray_view_addition()
c = Containers.@container([i = 1:4, j = 2:3], i + 2 * j)
d = view(c, 2:3, :)
@test_throws MethodError d + d
return
end

function test_containers_denseaxisarray_view_colon()
c = Containers.@container([i = 1:4, j = 2:3], i + 2 * j)
d = view(c, 2:3, :)
@test d[:, 2] == Containers.@container([i = 2:3], i + 2 * 2)
return
end

function test_containers_denseaxisarray_setindex_invalid()
c = Containers.@container([i = 1:4, j = 2:3], 0)
d = Containers.@container([i = 1:4, j = 2:3], i + 2 * j)
setindex!(c, d, 1:4, 2:3)
@test c == d
c .= 0
setindex!(c, d, 1:4, 2:2)
@test c == Containers.@container([i = 1:4, j = 2:3], (4 + i) * (j == 2))
d = Containers.@container([i = 5:6, j = 2:3], i + 2 * j)
@test_throws KeyError setindex!(c, d, 1:4, 2:3)
return
end

function test_containers_denseaxisarray_setindex_keys()
c = Containers.@container([i = 1:4, j = 2:3], 0)
for (i, k) in enumerate(keys(c))
c[k] = c[k] + i
end
@test c == Containers.@container([i = 1:4, j = 2:3], 4 * (j - 2) + i)
for (i, k) in enumerate(keys(c))
c[k] = c[k] + i
end
@test c == Containers.@container([i = 1:4, j = 2:3], 2 * (4 * (j - 2) + i))
return
end

end # module

0 comments on commit b6d001b

Please sign in to comment.