Skip to content

Commit

Permalink
fix: prevent reducers like ak.sum on records (v2) (scikit-hep#1607)
Browse files Browse the repository at this point in the history
* policy: Prevent reducers like ak.sum on records (v2)

* Pass behavior on the rest of the reducers

* Made a util function for checking if a reducer function is overloaded for records

* Switch from using .split to using .name

* Add `highlevel_function()` utility
  • Loading branch information
ioanaif authored and Saransh-cpp committed Aug 29, 2022
1 parent 475ac10 commit b6b6f05
Show file tree
Hide file tree
Showing 27 changed files with 258 additions and 69 deletions.
4 changes: 4 additions & 0 deletions src/awkward/_v2/_reducers.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,10 @@
class Reducer:
needs_position = False

@classmethod
def highlevel_function(cls):
    """Return the attribute of ``ak._v2.operations`` named by ``cls.name``.

    The lookup is a plain ``getattr``, so an ``AttributeError`` propagates
    if ``ak._v2.operations`` has no attribute with that name.
    """
    operations = ak._v2.operations
    return getattr(operations, cls.name)

@classmethod
def return_dtype(cls, given_dtype):
if given_dtype in (np.bool_, np.int8, np.int16, np.int32):
Expand Down
7 changes: 7 additions & 0 deletions src/awkward/_v2/_util.py
Original file line number Diff line number Diff line change
Expand Up @@ -597,6 +597,13 @@ def recordclass(layout, behavior):
return ak._v2.highlevel.Record


def reducer_recordclass(reducer, layout, behavior):
    """Look up a reducer overload registered for a named record type.

    The user-supplied ``behavior`` is layered with ``ak._v2.behavior`` through
    ``Behavior``, and the layout's ``"__record__"`` parameter is read.

    Args:
        reducer: object providing ``highlevel_function()``; its result is the
            first element of the lookup key.
        layout: object providing ``parameter("__record__")``.
        behavior: user-supplied behavior passed to ``Behavior`` together with
            ``ak._v2.behavior``.

    Returns:
        The entry stored under ``(reducer.highlevel_function(), record_name)``
        when the ``"__record__"`` parameter is a string; otherwise ``None``.
        NOTE(review): what the ``Behavior`` lookup yields for an unregistered
        key is not visible here; confirm against its ``__getitem__``.
    """
    merged = Behavior(ak._v2.behavior, behavior)
    record_name = layout.parameter("__record__")

    # Only records that carry a string name can have an overload registered.
    if not isstr(record_name):
        return None

    lookup_key = (reducer.highlevel_function(), record_name)
    return merged[lookup_key]


def typestrs(behavior):
behavior = Behavior(ak._v2.behavior, behavior)
out = {}
Expand Down
2 changes: 2 additions & 0 deletions src/awkward/_v2/contents/bitmaskedarray.py
Original file line number Diff line number Diff line change
Expand Up @@ -559,6 +559,7 @@ def _reduce_next(
outlength,
mask,
keepdims,
behavior,
):
return self.toByteMaskedArray()._reduce_next(
reducer,
Expand All @@ -569,6 +570,7 @@ def _reduce_next(
outlength,
mask,
keepdims,
behavior,
)

def _validity_error(self, path):
Expand Down
2 changes: 2 additions & 0 deletions src/awkward/_v2/contents/bytemaskedarray.py
Original file line number Diff line number Diff line change
Expand Up @@ -795,6 +795,7 @@ def _reduce_next(
outlength,
mask,
keepdims,
behavior,
):
mask_length = self._mask.length

Expand Down Expand Up @@ -899,6 +900,7 @@ def _reduce_next(
outlength,
mask,
keepdims,
behavior,
)

if not branch and negaxis == depth:
Expand Down
53 changes: 32 additions & 21 deletions src/awkward/_v2/contents/content.py
Original file line number Diff line number Diff line change
Expand Up @@ -814,7 +814,7 @@ def dummy(self):
def local_index(self, axis):
return self._local_index(axis, 0)

def _reduce(self, reducer, axis=-1, mask=True, keepdims=False):
def _reduce(self, reducer, axis=-1, mask=True, keepdims=False, behavior=None):
if axis is None:
raise ak._v2._util.error(NotImplementedError)

Expand Down Expand Up @@ -861,39 +861,50 @@ def _reduce(self, reducer, axis=-1, mask=True, keepdims=False):
1,
mask,
keepdims,
behavior,
)

return next[0]

def argmin(self, axis=-1, mask=True, keepdims=False):
return self._reduce(awkward._v2._reducers.ArgMin, axis, mask, keepdims)
def argmin(self, axis=-1, mask=True, keepdims=False, behavior=None):
return self._reduce(
awkward._v2._reducers.ArgMin, axis, mask, keepdims, behavior
)

def argmax(self, axis=-1, mask=True, keepdims=False):
return self._reduce(awkward._v2._reducers.ArgMax, axis, mask, keepdims)
def argmax(self, axis=-1, mask=True, keepdims=False, behavior=None):
return self._reduce(
awkward._v2._reducers.ArgMax, axis, mask, keepdims, behavior
)

def count(self, axis=-1, mask=False, keepdims=False):
return self._reduce(awkward._v2._reducers.Count, axis, mask, keepdims)
def count(self, axis=-1, mask=False, keepdims=False, behavior=None):
return self._reduce(awkward._v2._reducers.Count, axis, mask, keepdims, behavior)

def count_nonzero(self, axis=-1, mask=False, keepdims=False):
return self._reduce(awkward._v2._reducers.CountNonzero, axis, mask, keepdims)
def count_nonzero(self, axis=-1, mask=False, keepdims=False, behavior=None):
return self._reduce(
awkward._v2._reducers.CountNonzero, axis, mask, keepdims, behavior
)

def sum(self, axis=-1, mask=False, keepdims=False):
return self._reduce(awkward._v2._reducers.Sum, axis, mask, keepdims)
def sum(self, axis=-1, mask=False, keepdims=False, behavior=None):
return self._reduce(awkward._v2._reducers.Sum, axis, mask, keepdims, behavior)

def prod(self, axis=-1, mask=False, keepdims=False):
return self._reduce(awkward._v2._reducers.Prod, axis, mask, keepdims)
def prod(self, axis=-1, mask=False, keepdims=False, behavior=None):
return self._reduce(awkward._v2._reducers.Prod, axis, mask, keepdims, behavior)

def any(self, axis=-1, mask=False, keepdims=False):
return self._reduce(awkward._v2._reducers.Any, axis, mask, keepdims)
def any(self, axis=-1, mask=False, keepdims=False, behavior=None):
return self._reduce(awkward._v2._reducers.Any, axis, mask, keepdims, behavior)

def all(self, axis=-1, mask=False, keepdims=False):
return self._reduce(awkward._v2._reducers.All, axis, mask, keepdims)
def all(self, axis=-1, mask=False, keepdims=False, behavior=None):
return self._reduce(awkward._v2._reducers.All, axis, mask, keepdims, behavior)

def min(self, axis=-1, mask=True, keepdims=False, initial=None):
return self._reduce(awkward._v2._reducers.Min(initial), axis, mask, keepdims)
def min(self, axis=-1, mask=True, keepdims=False, initial=None, behavior=None):
return self._reduce(
awkward._v2._reducers.Min(initial), axis, mask, keepdims, behavior
)

def max(self, axis=-1, mask=True, keepdims=False, initial=None):
return self._reduce(awkward._v2._reducers.Max(initial), axis, mask, keepdims)
def max(self, axis=-1, mask=True, keepdims=False, initial=None, behavior=None):
return self._reduce(
awkward._v2._reducers.Max(initial), axis, mask, keepdims, behavior
)

def argsort(self, axis=-1, ascending=True, stable=False, kind=None, order=None):
negaxis = -axis
Expand Down
2 changes: 2 additions & 0 deletions src/awkward/_v2/contents/emptyarray.py
Original file line number Diff line number Diff line change
Expand Up @@ -261,6 +261,7 @@ def _reduce_next(
outlength,
mask,
keepdims,
behavior,
):
as_numpy = self.toNumpyArray(reducer.preferred_dtype)
return as_numpy._reduce_next(
Expand All @@ -272,6 +273,7 @@ def _reduce_next(
outlength,
mask,
keepdims,
behavior,
)

def _validity_error(self, path):
Expand Down
2 changes: 2 additions & 0 deletions src/awkward/_v2/contents/indexedarray.py
Original file line number Diff line number Diff line change
Expand Up @@ -970,6 +970,7 @@ def _reduce_next(
outlength,
mask,
keepdims,
behavior,
):
branch, depth = self.branch_depth

Expand Down Expand Up @@ -1013,6 +1014,7 @@ def _reduce_next(
outlength,
mask,
keepdims,
behavior,
)

# If we are reducing the contents of this layout,
Expand Down
2 changes: 2 additions & 0 deletions src/awkward/_v2/contents/indexedoptionarray.py
Original file line number Diff line number Diff line change
Expand Up @@ -1368,6 +1368,7 @@ def _reduce_next(
outlength,
mask,
keepdims,
behavior,
):
branch, depth = self.branch_depth

Expand All @@ -1390,6 +1391,7 @@ def _reduce_next(
outlength,
mask,
keepdims,
behavior,
)

# If we are reducing the contents of this layout,
Expand Down
2 changes: 2 additions & 0 deletions src/awkward/_v2/contents/listarray.py
Original file line number Diff line number Diff line change
Expand Up @@ -1234,6 +1234,7 @@ def _reduce_next(
outlength,
mask,
keepdims,
behavior,
):
return self.toListOffsetArray64(True)._reduce_next(
reducer,
Expand All @@ -1244,6 +1245,7 @@ def _reduce_next(
outlength,
mask,
keepdims,
behavior,
)

def _validity_error(self, path):
Expand Down
4 changes: 4 additions & 0 deletions src/awkward/_v2/contents/listoffsetarray.py
Original file line number Diff line number Diff line change
Expand Up @@ -1467,6 +1467,7 @@ def _reduce_next(
outlength,
mask,
keepdims,
behavior,
):
if self._offsets.dtype != np.dtype(np.int64) or (
self._offsets.nplike.known_data and self._offsets[0] != 0
Expand All @@ -1481,6 +1482,7 @@ def _reduce_next(
outlength,
mask,
keepdims,
behavior,
)

branch, depth = self.branch_depth
Expand Down Expand Up @@ -1586,6 +1588,7 @@ def _reduce_next(
maxnextparents[0] + 1,
mask,
False,
behavior,
)

out = ak._v2.contents.ListArray(
Expand Down Expand Up @@ -1641,6 +1644,7 @@ def _reduce_next(
globalstarts_length,
mask,
keepdims,
behavior,
)

outoffsets = ak._v2.index.Index64.empty(outlength + 1, self._nplike)
Expand Down
2 changes: 2 additions & 0 deletions src/awkward/_v2/contents/numpyarray.py
Original file line number Diff line number Diff line change
Expand Up @@ -1077,6 +1077,7 @@ def _reduce_next(
outlength,
mask,
keepdims,
behavior,
):
if len(self._data.shape) != 1 or not self.is_contiguous:
return self.toRegularArray()._reduce_next(
Expand All @@ -1088,6 +1089,7 @@ def _reduce_next(
outlength,
mask,
keepdims,
behavior,
)

if isinstance(self.nplike, ak.nplike.Jax):
Expand Down
35 changes: 14 additions & 21 deletions src/awkward/_v2/contents/recordarray.py
Original file line number Diff line number Diff line change
Expand Up @@ -824,30 +824,23 @@ def _reduce_next(
outlength,
mask,
keepdims,
behavior,
):
contents = []
for content in self._contents:
contents.append(
content[: self._length]._reduce_next(
reducer,
negaxis,
starts,
shifts,
parents,
outlength,
mask,
keepdims,
reducer_recordclass = ak._v2._util.reducer_recordclass(reducer, self, behavior)
if reducer_recordclass is None:
raise ak._v2._util.error(
TypeError(
"no ak.{} overloads for custom types: {}".format(
reducer.name, ", ".join(self._fields)
)
)
)
else:
raise ak._v2._util.error(
NotImplementedError(
"overloading reducers for RecordArrays has not been implemented yet"
)
)

return ak._v2.contents.RecordArray(
contents,
self._fields,
outlength,
None,
None,
self._nplike,
)

def _validity_error(self, path):
for i, cont in enumerate(self.contents):
Expand Down
2 changes: 2 additions & 0 deletions src/awkward/_v2/contents/regulararray.py
Original file line number Diff line number Diff line change
Expand Up @@ -985,6 +985,7 @@ def _reduce_next(
outlength,
mask,
keepdims,
behavior,
):
out = self.toListOffsetArray64(True)._reduce_next(
reducer,
Expand All @@ -995,6 +996,7 @@ def _reduce_next(
outlength,
mask,
keepdims,
behavior,
)

if not self._content.dimension_optiontype:
Expand Down
2 changes: 2 additions & 0 deletions src/awkward/_v2/contents/unionarray.py
Original file line number Diff line number Diff line change
Expand Up @@ -1132,6 +1132,7 @@ def _reduce_next(
outlength,
mask,
keepdims,
behavior,
):
simplified = self.simplify_uniontype(mergebool=True)
if isinstance(simplified, UnionArray):
Expand All @@ -1150,6 +1151,7 @@ def _reduce_next(
outlength,
mask,
keepdims,
behavior,
)

def _validity_error(self, path):
Expand Down
2 changes: 2 additions & 0 deletions src/awkward/_v2/contents/unmaskedarray.py
Original file line number Diff line number Diff line change
Expand Up @@ -488,6 +488,7 @@ def _reduce_next(
outlength,
mask,
keepdims,
behavior,
):
next = self._content
if isinstance(next, ak._v2.contents.RegularArray):
Expand All @@ -502,6 +503,7 @@ def _reduce_next(
outlength,
mask,
keepdims,
behavior,
)

def _validity_error(self, path):
Expand Down
4 changes: 3 additions & 1 deletion src/awkward/_v2/operations/ak_all.py
Original file line number Diff line number Diff line change
Expand Up @@ -84,7 +84,9 @@ def reduce(xs):

else:
behavior = ak._v2._util.behavior_of(array)
out = layout.all(axis=axis, mask=mask_identity, keepdims=keepdims)
out = layout.all(
axis=axis, mask=mask_identity, keepdims=keepdims, behavior=behavior
)
if isinstance(out, (ak._v2.contents.Content, ak._v2.record.Record)):
return ak._v2._util.wrap(out, behavior)
else:
Expand Down
4 changes: 3 additions & 1 deletion src/awkward/_v2/operations/ak_any.py
Original file line number Diff line number Diff line change
Expand Up @@ -84,7 +84,9 @@ def reduce(xs):

else:
behavior = ak._v2._util.behavior_of(array)
out = layout.any(axis=axis, mask=mask_identity, keepdims=keepdims)
out = layout.any(
axis=axis, mask=mask_identity, keepdims=keepdims, behavior=behavior
)
if isinstance(out, (ak._v2.contents.Content, ak._v2.record.Record)):
return ak._v2._util.wrap(out, behavior)
else:
Expand Down
4 changes: 3 additions & 1 deletion src/awkward/_v2/operations/ak_argmax.py
Original file line number Diff line number Diff line change
Expand Up @@ -132,7 +132,9 @@ def _impl(array, axis, keepdims, mask_identity, flatten_records):

else:
behavior = ak._v2._util.behavior_of(array)
out = layout.argmax(axis=axis, mask=mask_identity, keepdims=keepdims)
out = layout.argmax(
axis=axis, mask=mask_identity, keepdims=keepdims, behavior=behavior
)
if isinstance(out, (ak._v2.contents.Content, ak._v2.record.Record)):
return ak._v2._util.wrap(out, behavior)
else:
Expand Down
4 changes: 3 additions & 1 deletion src/awkward/_v2/operations/ak_argmin.py
Original file line number Diff line number Diff line change
Expand Up @@ -132,7 +132,9 @@ def _impl(array, axis, keepdims, mask_identity, flatten_records):

else:
behavior = ak._v2._util.behavior_of(array)
out = layout.argmin(axis=axis, mask=mask_identity, keepdims=keepdims)
out = layout.argmin(
axis=axis, mask=mask_identity, keepdims=keepdims, behavior=behavior
)
if isinstance(out, (ak._v2.contents.Content, ak._v2.record.Record)):
return ak._v2._util.wrap(out, behavior)
else:
Expand Down
4 changes: 3 additions & 1 deletion src/awkward/_v2/operations/ak_count.py
Original file line number Diff line number Diff line change
Expand Up @@ -126,7 +126,9 @@ def reduce(xs):

else:
behavior = ak._v2._util.behavior_of(array)
out = layout.count(axis=axis, mask=mask_identity, keepdims=keepdims)
out = layout.count(
axis=axis, mask=mask_identity, keepdims=keepdims, behavior=behavior
)
if isinstance(out, (ak._v2.contents.Content, ak._v2.record.Record)):
return ak._v2._util.wrap(out, behavior)
else:
Expand Down
Loading

0 comments on commit b6b6f05

Please sign in to comment.