Skip to content

Commit

Permalink
Rewrite _build_and_check_signature
Browse files Browse the repository at this point in the history
  • Loading branch information
shoyer committed Aug 14, 2016
1 parent c27e25e commit d8786da
Showing 1 changed file with 51 additions and 81 deletions.
132 changes: 51 additions & 81 deletions xarray/core/computation.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
from collections import NamedTuple
import itertools
from collections import namedtuple


def result_name(objects):
Expand All @@ -13,19 +14,25 @@ def result_name(objects):
return name


def apply_dataarray(func, args, join='inner', gufunc_signature=None,
                    kwargs=None, combine_names=None):
    """Apply ``func`` to the variables of ``args`` and wrap the result.

    Parameters
    ----------
    func : callable
        Called as ``func(*variables, **kwargs)`` where ``variables`` are the
        ``.variable`` attributes of the aligned arguments.
    args : sequence
        Input objects. As written, every argument must expose ``.variable``
        and ``.coords.variables``.
    join : str, optional
        Forwarded to ``deep_align``.
    gufunc_signature : optional
        Accepted but not used yet (see the TODO below).
    kwargs : dict, optional
        Extra keyword arguments for ``func``.
    combine_names : callable, optional
        Maps the aligned arguments to the name of the result. Defaults to
        ``result_name``.

    Returns
    -------
    DataArray
        Built from the output of ``func``, the merged coordinates and the
        combined name.
    """
    if kwargs is None:
        kwargs = {}

    if combine_names is None:
        combine_names = result_name

    args = deep_align(*args, join=join, copy=False, raise_on_invalid=False)

    # NOTE(review): the inner getattr falls back to ``{}``, but the outer one
    # has no default, so an argument without ``coords`` raises AttributeError
    # here -- confirm whether such arguments are meant to be supported.
    coord_variables = [getattr(getattr(a, 'coords', {}), 'variables')
                       for a in args]
    coords = merge_coords_without_align(coord_variables)
    name = combine_names(args)

    data_vars = [getattr(a, 'variable') for a in args]
    # Bind the result under the same name that is passed to DataArray below;
    # binding it as ``variables`` left ``variable`` undefined (NameError).
    variable = func(*data_vars, **kwargs)

    # TODO handle gufunc_signature

    return DataArray(variable, coords, name=name, fastpath=True)

Expand Down Expand Up @@ -111,76 +118,35 @@ def _as_sequence(arg, cls):
return cls(arg)


_ElemwiseSignature = NamedTuple(
'_ElemwiseSignature', 'broadcast_dims, core_dims, output_dims, axis')
_ElemwiseSignature = namedtuple(
'_ElemwiseSignature', 'broadcast_dims, output_dims')

class GUFuncSignature(object):
    """Core dimensions of the inputs and outputs of a generalized ufunc.

    Attributes
    ----------
    inputs : tuple of tuple
        Core dimension names of each input argument, in order.
    outputs : tuple of tuple
        Core dimension names of each output, in order.
    """

    def __init__(self, inputs, outputs):
        # Normalize to tuples of tuples so that the dimensions can be
        # concatenated onto a tuple of broadcast dimensions regardless of
        # whether the caller supplied lists or tuples.
        self.inputs = tuple(tuple(dims) for dims in inputs)
        self.outputs = tuple(tuple(dims) for dims in outputs)

    @classmethod
    def from_string(cls, string):
        """Build a signature from a NumPy-style gufunc signature string.

        Parameters
        ----------
        string : str
            For example ``'(m,n),(n,p)->(m,p)'``. Whitespace is ignored.

        Returns
        -------
        GUFuncSignature

        Raises
        ------
        ValueError
            If ``string`` does not follow the gufunc signature grammar.
        """
        import re  # local import: only needed by this constructor

        dimension = r'\w+'
        dimension_list = '(?:{0}(?:,{0})*)?'.format(dimension)
        argument = r'\({0}\)'.format(dimension_list)
        argument_list = '{0}(?:,{0})*'.format(argument)
        signature = '^{0}->{0}$'.format(argument_list)

        compact = re.sub(r'\s+', '', string)
        if not re.match(signature, compact):
            raise ValueError('not a valid gufunc signature: %r' % (string,))

        # A valid signature contains exactly one '->', so this yields the
        # input side followed by the output side.
        inputs, outputs = [
            [tuple(re.findall(dimension, arg))
             for arg in re.findall(argument, side)]
            for side in compact.split('->')]
        return cls(inputs, outputs)

    def __repr__(self):
        return '%s(%r, %r)' % (
            type(self).__name__, list(self.inputs), list(self.outputs))
def _build_and_check_signature(variables, gufunc_signature):
    """Resolve the broadcast and output dimensions for ``variables``.

    Parameters
    ----------
    variables : sequence
        Inputs passed to ``_calculate_unified_dim_sizes``.
    gufunc_signature : object or None
        Object with ``inputs`` and ``outputs`` attributes, each a sequence of
        core-dimension sequences. ``None`` means every dimension is
        broadcast and there is a single output.

    Returns
    -------
    _ElemwiseSignature
        ``broadcast_dims`` is a tuple of the non-core dimensions;
        ``output_dims`` is a list with one dimension tuple per output.

    Raises
    ------
    ValueError
        If an input core dimension is not found on any input variable.
    """
    # core_dims are not broadcast over, and moved to the right with order
    # preserved.
    dim_sizes = _calculate_unified_dim_sizes(variables)

    if gufunc_signature is None:
        # broadcast everything, one output
        dims = tuple(dim_sizes)
        return _ElemwiseSignature(dims, [dims])

    input_core_dims = set(
        itertools.chain.from_iterable(gufunc_signature.inputs))
    output_core_dims = set(
        itertools.chain.from_iterable(gufunc_signature.outputs))
    core_dims = input_core_dims | output_core_dims
    broadcast_dims = tuple(d for d in dim_sizes if d not in core_dims)

    # Only input core dimensions have to exist on the inputs; an output may
    # introduce a core dimension that no input carries.
    invalid = input_core_dims - set(dim_sizes)
    if invalid:
        raise ValueError('some `core_dims` not found on any input variables: '
                         '%r' % list(invalid))

    # tuple() lets signatures that store their dimensions as lists work too.
    output_dims = [broadcast_dims + tuple(out)
                   for out in gufunc_signature.outputs]
    return _ElemwiseSignature(broadcast_dims, output_dims)


def _broadcast_variable_data_to(variable, broadcast_dims, allow_dask=True):
Expand Down Expand Up @@ -208,8 +174,7 @@ def _broadcast_variable_data_to(variable, broadcast_dims, allow_dask=True):
return data


def apply_variable_ufunc(func, args, allow_dask=True, core_dims=None,
axis_dims=None, drop_dims=None, new_dims=None,
def apply_variable_ufunc(func, args, allow_dask=True, gufunc_signature=None,
combine_attrs=None, kwargs=None):

if kwargs is None:
Expand All @@ -218,29 +183,34 @@ def apply_variable_ufunc(func, args, allow_dask=True, core_dims=None,
if combine_attrs is None:
combine_attrs = lambda func, attrs: None

result_attrs = combine_attrs(func, [getattr(a, 'attrs', {}) for a in args])

sig = _build_and_check_dims_signature(
variables, core_dims, axis_dims, drop_dims, new_dims)
sig = _build_and_check_signature(variables, gufunc_signature)

if sig.axis:
if 'axis' in kwargs:
raise ValueError('axis is already set in kwargs')
kwargs = dict(kwargs)
kwargs['axis'] = sig.axis
n_out = len(sig.output_dims)
input_attrs = [getattr(a, 'attrs', {}) for a in args]
result_attrs = [combine_attrs(input_attrs, func, n) for n in range(n_out)]

list_of_data = []
for arg in args:
if isinstance(arg, Variable):
data = _broadcast_variable_data_to(arg, sig.broadcast_dims,
allow_dask=allow_dask)
list_of_data.append(data)
else:
list_of_data.append(arg)
data = arg
list_of_data.append(data)

result_data = func(*list_of_data, **kwargs)

return Variable(sig.output_dims, result_data, result_attrs)
if n_out > 1:
output = []
for dims, data, attrs in zip(
sig.output_dims, result_data, result_attrs):
output.append(Variable(dims, data, attrs))
return tuple(output)
else:
dims, = sig.output_dims
data, = result_data
attrs = result_attrs
return Variable(dims, data, attrs)


def apply_ufunc(func, args, join='inner', allow_dask=True, kwargs=None,
Expand Down

0 comments on commit d8786da

Please sign in to comment.