Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

ak.concatenate (mergemany) should preserve regular-type. #1604

Merged
merged 2 commits into from
Aug 18, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion src/awkward/_v2/contents/indexedarray.py
Original file line number Diff line number Diff line change
Expand Up @@ -563,9 +563,10 @@ def mergemany(self, others):
nextindex = ak._v2.index.Index64.empty(total_length, self._nplike)
parameters = self._parameters

parameters = self._parameters
for array in head:
parameters = ak._v2._util.merge_parameters(
self._parameters, array._parameters, True
parameters, array._parameters, True
)

if isinstance(
Expand Down
3 changes: 2 additions & 1 deletion src/awkward/_v2/contents/indexedoptionarray.py
Original file line number Diff line number Diff line change
Expand Up @@ -667,9 +667,10 @@ def mergemany(self, others):
nextindex = ak._v2.index.Index64.empty(total_length, self._nplike)
parameters = self._parameters

parameters = self._parameters
for array in head:
parameters = ak._v2._util.merge_parameters(
self._parameters, array._parameters, True
parameters, array._parameters, True
)

if isinstance(
Expand Down
3 changes: 2 additions & 1 deletion src/awkward/_v2/contents/listarray.py
Original file line number Diff line number Diff line change
Expand Up @@ -948,9 +948,10 @@ def mergemany(self, others):

contents = []

parameters = self._parameters
for array in head:
parameters = ak._v2._util.merge_parameters(
self._parameters, array._parameters, True
parameters, array._parameters, True
)

if isinstance(
Expand Down
3 changes: 2 additions & 1 deletion src/awkward/_v2/contents/numpyarray.py
Original file line number Diff line number Diff line change
Expand Up @@ -482,9 +482,10 @@ def mergemany(self, others):

contiguous_arrays = []

parameters = self._parameters
for array in head:
parameters = ak._v2._util.merge_parameters(
self._parameters, array._parameters, True
parameters, array._parameters, True
)
if isinstance(array, ak._v2.contents.emptyarray.EmptyArray):
pass
Expand Down
8 changes: 7 additions & 1 deletion src/awkward/_v2/contents/recordarray.py
Original file line number Diff line number Diff line change
Expand Up @@ -570,9 +570,10 @@ def mergemany(self, others):
for_each_field.append([field])

if self.is_tuple:
parameters = self._parameters
for array in headless:
parameters = ak._v2._util.merge_parameters(
self._parameters, array._parameters, True
parameters, array._parameters, True
)

if isinstance(array, ak._v2.contents.recordarray.RecordArray):
Expand Down Expand Up @@ -607,7 +608,12 @@ def mergemany(self, others):
these_fields = self._fields.copy()
these_fields.sort()

parameters = self._parameters
for array in headless:
parameters = ak._v2._util.merge_parameters(
parameters, array._parameters, True
)

if isinstance(array, ak._v2.contents.recordarray.RecordArray):
if not array.is_tuple:
those_fields = array._fields.copy()
Expand Down
26 changes: 25 additions & 1 deletion src/awkward/_v2/contents/regulararray.py
Original file line number Diff line number Diff line change
Expand Up @@ -682,7 +682,31 @@ def mergeable(self, other, mergebool):
def mergemany(self, others):
if len(others) == 0:
return self
return self.toListOffsetArray64(True).mergemany(others)

if any(x.is_OptionType for x in others):
return ak._v2.contents.UnmaskedArray(self).mergemany(others)

elif all(x.is_RegularType and x.size == self.size for x in others):
parameters = self._parameters
tail_contents = []
zeros_length = 0
for x in others:
parameters = ak._v2._util.merge_parameters(
parameters, x._parameters, True
)
tail_contents.append(x._content[: x._length * x._size])
zeros_length += x._length

return RegularArray(
self._content.mergemany(tail_contents),
self._size,
zeros_length,
None,
parameters,
)

else:
return self.toListOffsetArray64(True).mergemany(others)

def fill_none(self, value):
return RegularArray(
Expand Down
3 changes: 2 additions & 1 deletion src/awkward/_v2/contents/unionarray.py
Original file line number Diff line number Diff line change
Expand Up @@ -883,9 +883,10 @@ def mergemany(self, others):
length_so_far = 0
parameters = self._parameters

parameters = self._parameters
for array in head:
parameters = ak._v2._util.merge_parameters(
self._parameters, array._parameters, True
parameters, array._parameters, True
)
if isinstance(array, ak._v2.contents.unionarray.UnionArray):
union_tags = ak._v2.index.Index(array.tags)
Expand Down
78 changes: 78 additions & 0 deletions tests/v2/test_1586-concatenate-should-preserve-regulararray.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,78 @@
# BSD 3-Clause License; see https://github.com/scikit-hep/awkward-1.0/blob/main/LICENSE

import pytest # noqa: F401
import numpy as np # noqa: F401
import awkward as ak # noqa: F401

from awkward._v2.types import ArrayType, RegularType, OptionType, NumpyType


def test_simple():
a = ak._v2.from_numpy(np.array([[1, 2], [3, 4], [5, 6]]), regulararray=True)
b = ak._v2.from_numpy(np.array([[7, 8], [9, 10]]), regulararray=True)
c = a.layout.merge(b.layout)
assert isinstance(c, ak._v2.contents.RegularArray)
assert c.size == 2
assert ak._v2.operations.to_list(c) == [[1, 2], [3, 4], [5, 6], [7, 8], [9, 10]]


def test_regular_regular():
a1 = ak._v2.from_json("[[0.0, 1.1], [2.2, 3.3]]")
a2 = ak._v2.from_json("[[4.4, 5.5], [6.6, 7.7], [8.8, 9.9]]")
a1 = ak._v2.to_regular(a1, axis=1)
a2 = ak._v2.to_regular(a2, axis=1)
c = ak._v2.concatenate([a1, a2])
assert c.tolist() == [[0.0, 1.1], [2.2, 3.3], [4.4, 5.5], [6.6, 7.7], [8.8, 9.9]]
assert c.type == ArrayType(RegularType(NumpyType("float64"), 2), 5)


def test_regular_option():
a1 = ak._v2.from_json("[[0.0, 1.1], [2.2, 3.3]]")
a2 = ak._v2.from_json("[[4.4, 5.5], [6.6, 7.7], null, [8.8, 9.9]]")
a1 = ak._v2.to_regular(a1, axis=1)
a2 = ak._v2.to_regular(a2, axis=1)
c = ak._v2.concatenate([a1, a2])
assert c.tolist() == [
[0.0, 1.1],
[2.2, 3.3],
[4.4, 5.5],
[6.6, 7.7],
None,
[8.8, 9.9],
]
assert c.type == ArrayType(OptionType(RegularType(NumpyType("float64"), 2)), 6)


def test_option_regular():
a1 = ak._v2.from_json("[[0.0, 1.1], null, [2.2, 3.3]]")
a2 = ak._v2.from_json("[[4.4, 5.5], [6.6, 7.7], [8.8, 9.9]]")
a1 = ak._v2.to_regular(a1, axis=1)
a2 = ak._v2.to_regular(a2, axis=1)
c = ak._v2.concatenate([a1, a2])
assert c.tolist() == [
[0.0, 1.1],
None,
[2.2, 3.3],
[4.4, 5.5],
[6.6, 7.7],
[8.8, 9.9],
]
assert c.type == ArrayType(OptionType(RegularType(NumpyType("float64"), 2)), 6)


def test_option_option():
a1 = ak._v2.from_json("[[0.0, 1.1], null, [2.2, 3.3]]")
a2 = ak._v2.from_json("[[4.4, 5.5], [6.6, 7.7], null, [8.8, 9.9]]")
a1 = ak._v2.to_regular(a1, axis=1)
a2 = ak._v2.to_regular(a2, axis=1)
c = ak._v2.concatenate([a1, a2])
assert c.tolist() == [
[0.0, 1.1],
None,
[2.2, 3.3],
[4.4, 5.5],
[6.6, 7.7],
None,
[8.8, 9.9],
]
assert c.type == ArrayType(OptionType(RegularType(NumpyType("float64"), 2)), 7)