Skip to content

Commit

Permalink
UnionArray is minimally tested; I declare the PR done.
Browse files Browse the repository at this point in the history
  • Loading branch information
jpivarski committed Aug 26, 2023
1 parent 1c8fcb5 commit 2c68bed
Show file tree
Hide file tree
Showing 2 changed files with 236 additions and 27 deletions.
86 changes: 59 additions & 27 deletions src/AwkwardArray.jl
Original file line number Diff line number Diff line change
Expand Up @@ -630,6 +630,17 @@ RecordArray(
behavior = behavior,
)

RecordArray{CONTENTS}(;
parameters::Parameters = Parameters(),
behavior::Symbol = :default,
) where {CONTENTS<:NamedTuple} = RecordArray(
NamedTuple{CONTENTS.parameters[1]}(
Base.Tuple(x() for x in CONTENTS.parameters[2].parameters),
),
parameters = parameters,
behavior = behavior,
)

struct Record{ARRAY<:RecordArray}
array::ARRAY
at::Int64
Expand Down Expand Up @@ -752,6 +763,15 @@ TupleArray(
behavior = behavior,
)

TupleArray{CONTENTS}(;
parameters::Parameters = Parameters(),
behavior::Symbol = :default,
) where {CONTENTS<:Base.Tuple} = TupleArray(
Base.Tuple(x() for x in CONTENTS.parameters),
parameters = parameters,
behavior = behavior,
)

struct Tuple{ARRAY<:TupleArray}
array::ARRAY
at::Int64
Expand Down Expand Up @@ -1447,7 +1467,19 @@ UnionArray{TAGS,INDEX,CONTENTS}(
) where {TAGS<:Index8,INDEX<:IndexBig,CONTENTS<:Base.Tuple} =
UnionArray(TAGS([]), INDEX([]), contents, parameters = parameters, behavior = behavior)

struct Specialization{TAG<:Int64,ARRAY<:UnionArray,TAGGED<:Content}
UnionArray{TAGS,INDEX,CONTENTS}(;
parameters::Parameters = Parameters(),
behavior::Symbol = :default,
) where {TAGS<:Index8,INDEX<:IndexBig,CONTENTS<:Base.Tuple} = UnionArray(
TAGS([]),
INDEX([]),
Base.Tuple(x() for x in CONTENTS.parameters),
parameters = parameters,
behavior = behavior,
)

struct Specialization{ARRAY<:UnionArray,TAGGED<:Content}
tag::Int64
array::ARRAY
tagged::TAGGED
end
Expand Down Expand Up @@ -1487,7 +1519,7 @@ function copy(
end

function is_valid(layout::UnionArray)
if length(tags) > length(index)
if length(layout.tags) > length(layout.index)
return false
end
adjustment = firstindex(layout.tags) - firstindex(layout.index)
Expand Down Expand Up @@ -1535,68 +1567,68 @@ Base.getindex(layout::UnionArray, f::Symbol) =
copy(layout, contents = Base.Tuple(x[f] for x in layout.contents))

specialization(layout::UnionArray, tag::Int64) =
Specialization{tag}(layout, layout.contents[tag])
Specialization(tag, layout, layout.contents[tag])

function push!(
special::Specialization{TAG,ARRAY,TAGGED},
special::Specialization{ARRAY,TAGGED},
x::ITEM,
) where {TAG<:Int64,ITEM,ARRAY<:UnionArray,TAGGED<:PrimitiveArray{ITEM}}
) where {ITEM,ARRAY<:UnionArray,TAGGED<:PrimitiveArray{ITEM}}
tmp = length(special.tagged)
push!(special.tagged, x)
Base.push!(special.array.tags, TAG)
Base.push!(special.array.tags, special.tag - firstindex(special.array.contents))
Base.push!(special.array.index, tmp)
layout
special
end

function push!(
special::Specialization{TAG,ARRAY,TAGGED},
special::Specialization{ARRAY,TAGGED},
x::ITEM,
) where {TAG<:Int64,ITEM,ARRAY<:UnionArray,TAGGED<:OptionType}
) where {ITEM,ARRAY<:UnionArray,TAGGED<:OptionType}
tmp = length(special.tagged)
push!(special.tagged, x)
Base.push!(special.array.tags, TAG)
Base.push!(special.array.tags, special.tag - firstindex(special.array.contents))
Base.push!(special.array.index, tmp)
layout
special
end

function end_list!(
special::Specialization{TAG,ARRAY,TAGGED},
) where {TAG<:Int64,ARRAY<:UnionArray,TAGGED<:Content}
special::Specialization{ARRAY,TAGGED},
) where {ARRAY<:UnionArray,TAGGED<:Content}
tmp = length(special.tagged)
end_list!(special.tagged)
Base.push!(special.array.tags, TAG)
Base.push!(special.array.tags, special.tag - firstindex(special.array.contents))
Base.push!(special.array.index, tmp)
layout
special
end

function end_record!(
special::Specialization{TAG,ARRAY,TAGGED},
) where {TAG<:Int64,ARRAY<:UnionArray,TAGGED<:Content}
special::Specialization{ARRAY,TAGGED},
) where {ARRAY<:UnionArray,TAGGED<:Content}
tmp = length(special.tagged)
end_record!(special.tagged)
Base.push!(special.array.tags, TAG)
Base.push!(special.array.tags, special.tag - firstindex(special.array.contents))
Base.push!(special.array.index, tmp)
layout
special
end

function end_tuple!(
special::Specialization{TAG,ARRAY,TAGGED},
) where {TAG<:Int64,ARRAY<:UnionArray,TAGGED<:Content}
special::Specialization{ARRAY,TAGGED},
) where {ARRAY<:UnionArray,TAGGED<:Content}
tmp = length(special.tagged)
end_tuple!(special.tagged)
Base.push!(special.array.tags, TAG)
Base.push!(special.array.tags, special.tag - firstindex(special.array.contents))
Base.push!(special.array.index, tmp)
layout
special
end

function push_null!(
special::Specialization{TAG,ARRAY,TAGGED},
) where {TAG<:Int64,ARRAY<:UnionArray,TAGGED<:OptionType}
special::Specialization{ARRAY,TAGGED},
) where {ARRAY<:UnionArray,TAGGED<:OptionType}
tmp = length(special.tagged)
push_null!(special.tagged)
Base.push!(special.array.tags, TAG)
Base.push!(special.array.tags, special.tag - firstindex(special.array.contents))
Base.push!(special.array.index, tmp)
layout
special
end

end # module AwkwardArray
177 changes: 177 additions & 0 deletions test/runtests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -760,6 +760,57 @@ using Test
@test layout_2 == AwkwardArray.copy(layout_3, length = 2)
end

begin
layout = AwkwardArray.RecordArray{
NamedTuple{
(:a, :b),
Tuple{
AwkwardArray.PrimitiveArray{Int64},
AwkwardArray.ListOffsetArray{
AwkwardArray.Index64,
AwkwardArray.PrimitiveArray{Float64},
},
},
},
}()
@test AwkwardArray.is_valid(layout)
@test length(layout) == 0

a_layout = layout.contents[:a]
b_layout = layout.contents[:b]
b_sublayout = b_layout.content

AwkwardArray.push!(a_layout, 1)
AwkwardArray.push!(b_sublayout, 1.1)
AwkwardArray.push!(b_sublayout, 2.2)
AwkwardArray.push!(b_sublayout, 3.3)
AwkwardArray.end_list!(b_layout)
AwkwardArray.end_record!(layout)
@test length(layout) == 1

AwkwardArray.push!(a_layout, 2)
AwkwardArray.end_list!(b_layout)
AwkwardArray.end_record!(layout)
@test length(layout) == 2

AwkwardArray.push!(a_layout, 3)
AwkwardArray.push!(b_sublayout, 4.4)
AwkwardArray.push!(b_sublayout, 5.5)
AwkwardArray.end_list!(b_layout)
AwkwardArray.end_record!(layout)
@test length(layout) == 3

@test layout == AwkwardArray.RecordArray(
NamedTuple{(:a, :b)}((
AwkwardArray.PrimitiveArray([1, 2, 3]),
AwkwardArray.ListOffsetArray(
[0, 3, 3, 5],
AwkwardArray.PrimitiveArray([1.1, 2.2, 3.3, 4.4, 5.5]),
),
)),
)
end

### TupleArray ##########################################################

begin
Expand Down Expand Up @@ -914,6 +965,52 @@ using Test
@test layout_2 == AwkwardArray.copy(layout_3, length = 2)
end

begin
layout = AwkwardArray.TupleArray{
Tuple{
AwkwardArray.PrimitiveArray{Int64},
AwkwardArray.ListOffsetArray{
AwkwardArray.Index64,
AwkwardArray.PrimitiveArray{Float64},
},
},
}()
@test AwkwardArray.is_valid(layout)
@test length(layout) == 0

a_layout = layout.contents[1]
b_layout = layout.contents[2]
b_sublayout = b_layout.content

AwkwardArray.push!(a_layout, 1)
AwkwardArray.push!(b_sublayout, 1.1)
AwkwardArray.push!(b_sublayout, 2.2)
AwkwardArray.push!(b_sublayout, 3.3)
AwkwardArray.end_list!(b_layout)
AwkwardArray.end_tuple!(layout)
@test length(layout) == 1

AwkwardArray.push!(a_layout, 2)
AwkwardArray.end_list!(b_layout)
AwkwardArray.end_tuple!(layout)
@test length(layout) == 2

AwkwardArray.push!(a_layout, 3)
AwkwardArray.push!(b_sublayout, 4.4)
AwkwardArray.push!(b_sublayout, 5.5)
AwkwardArray.end_list!(b_layout)
AwkwardArray.end_tuple!(layout)
@test length(layout) == 3

@test layout == AwkwardArray.TupleArray((
AwkwardArray.PrimitiveArray([1, 2, 3]),
AwkwardArray.ListOffsetArray(
[0, 3, 3, 5],
AwkwardArray.PrimitiveArray([1.1, 2.2, 3.3, 4.4, 5.5]),
),
),)
end

### IndexedArray #########################################################

begin
Expand Down Expand Up @@ -1321,4 +1418,84 @@ using Test

### UnionArray ###########################################################

begin
layout = AwkwardArray.UnionArray(
Vector{Int8}([0, 0, 0, 1]),
[0, 1, 2, 0],
(
AwkwardArray.PrimitiveArray([1.1, 2.2, 3.3]),
AwkwardArray.ListOffsetArray(
[0, 2],
AwkwardArray.PrimitiveArray([4.4, 5.5]),
),
),
)
@test AwkwardArray.is_valid(layout)
@test length(layout) == 4
@test layout[1] == 1.1
@test layout[2] == 2.2
@test layout[3] == 3.3
@test layout[4] == AwkwardArray.PrimitiveArray([4.4, 5.5])

tmp = 0.0
for x in layout
if isa(x, AwkwardArray.PrimitiveArray)
for y in x
@test y < 6
tmp += y
end
else
@test x < 6
tmp += x
end
end
@test tmp == 16.5

@test layout == layout
end

begin
layout = AwkwardArray.UnionArray{
AwkwardArray.Index8,
AwkwardArray.Index64,
Tuple{
AwkwardArray.PrimitiveArray{Float64},
AwkwardArray.ListOffsetArray{
AwkwardArray.Index64,
AwkwardArray.PrimitiveArray{Float64},
},
},
}()
@test AwkwardArray.is_valid(layout)
@test length(layout) == 0

special1 = AwkwardArray.specialization(layout, 1)
special2 = AwkwardArray.specialization(layout, 2)
subspecial2 = special2.tagged.content

AwkwardArray.push!(special1, 1.1)
@test length(layout) == 1
@test layout[1] == 1.1

AwkwardArray.push!(subspecial2, 2.2)
AwkwardArray.push!(subspecial2, 3.3)
AwkwardArray.end_list!(special2)
@test length(layout) == 2
@test layout[2][1] == 2.2
@test layout[2][2] == 3.3

@test layout == AwkwardArray.UnionArray(
Vector{Int8}([0, 1]),
[0, 0],
(
AwkwardArray.PrimitiveArray([1.1]),
AwkwardArray.ListOffsetArray(
[0, 2],
AwkwardArray.PrimitiveArray([2.2, 3.3]),
),
),
)

end

end # @testset "AwkwardArray.jl"

0 comments on commit 2c68bed

Please sign in to comment.