import itertools
import re
from collections import namedtuple


# NOTE: ``result_name``, ``deep_align``, ``merge_coords_without_align``,
# ``DataArray`` and ``_calculate_unified_dim_sizes`` are referenced below but
# defined (or expected to be imported) elsewhere in this module.


def apply_dataarray(func, args, join='inner', gufunc_signature=None,
                    kwargs=None, combine_names=None):
    """Apply ``func`` to the variables of the given DataArray arguments.

    Parameters
    ----------
    func : callable
        Called as ``func(*variables, **kwargs)`` where ``variables`` are the
        ``.variable`` attributes of the aligned arguments.
    args : sequence
        Objects to align and operate on.
    join : str, optional
        Passed through to ``deep_align``.
    gufunc_signature : optional
        Currently unused (see TODO below).
    kwargs : dict, optional
        Extra keyword arguments for ``func``.
    combine_names : callable, optional
        Called with the aligned arguments to produce the result name.
        Defaults to ``result_name``.

    Returns
    -------
    DataArray wrapping the result of ``func``, the merged coordinates and
    the combined name.
    """
    if kwargs is None:
        kwargs = {}

    if combine_names is None:
        combine_names = result_name

    args = deep_align(*args, join=join, copy=False, raise_on_invalid=False)

    # NOTE(review): the inner ``getattr`` falls back to ``{}``, but the outer
    # one has no default, so an argument without ``coords`` still raises
    # AttributeError here -- confirm whether that is intended.
    coord_variables = [getattr(getattr(a, 'coords', {}), 'variables')
                       for a in args]
    coords = merge_coords_without_align(coord_variables)
    name = combine_names(args)

    data_vars = [getattr(a, 'variable') for a in args]
    # Bind the result under the same name that is returned below; the
    # previous code assigned ``variables`` but returned ``variable``.
    variable = func(*data_vars, **kwargs)

    # TODO handle gufunc_signature
    return DataArray(variable, coords, name=name, fastpath=True)


# Result of ``_build_and_check_signature``:
# - broadcast_dims: tuple of dimension names that are broadcast over.
# - output_dims: list with one tuple of dimension names per output.
_ElemwiseSignature = namedtuple(
    '_ElemwiseSignature', 'broadcast_dims, output_dims')


# Grammar for generalized ufunc signature strings such as
# ``(m,n),(n,p)->(m,p)``; whitespace is stripped before matching.
_DIMENSION_NAME = r'\w+'
_CORE_DIMENSION_LIST = r'(?:{0}(?:,{0})*)?'.format(_DIMENSION_NAME)
_ARGUMENT = r'\({}\)'.format(_CORE_DIMENSION_LIST)
_ARGUMENT_LIST = r'{0}(?:,{0})*'.format(_ARGUMENT)
_SIGNATURE = r'^{0}->{0}$'.format(_ARGUMENT_LIST)


class GUFuncSignature(object):
    """Core dimensions of the inputs and outputs of a generalized ufunc.

    Attributes
    ----------
    inputs : sequence of sequences of dimension names, one per input.
    outputs : sequence of sequences of dimension names, one per output.
    """

    def __init__(self, inputs, outputs):
        self.inputs = inputs
        self.outputs = outputs

    def __repr__(self):
        return '%s(%r, %r)' % (type(self).__name__, self.inputs, self.outputs)

    @classmethod
    def from_string(cls, string):
        """Build a signature from a string like ``'(m,n),(n,p)->(m,p)'``.

        Each parenthesized group lists the core dimensions of one argument;
        groups left of ``->`` are inputs and groups right of it are outputs.
        ``()`` denotes an argument without core dimensions.

        Raises
        ------
        ValueError
            If ``string`` does not match the signature grammar.
        """
        compact = re.sub(r'\s+', '', string)
        if not re.match(_SIGNATURE, compact):
            raise ValueError('not a valid gufunc signature: %r' % (string,))
        # The grammar guarantees exactly one '->' separator.
        inputs, outputs = [
            tuple(tuple(re.findall(_DIMENSION_NAME, arg))
                  for arg in re.findall(_ARGUMENT, arg_list))
            for arg_list in compact.split('->')]
        return cls(inputs, outputs)


def _build_and_check_signature(variables, gufunc_signature):
    """Work out broadcast and output dimensions for applying a function.

    Parameters
    ----------
    variables : passed to ``_calculate_unified_dim_sizes`` to obtain the
        dimensions present on the inputs.
    gufunc_signature : GUFuncSignature or None
        If None, every dimension is broadcast over and there is a single
        output carrying all of them.

    Returns
    -------
    _ElemwiseSignature
    """
    # core_dims are not broadcast over, and moved to the right with order
    # preserved.

    dim_sizes = _calculate_unified_dim_sizes(variables)

    if gufunc_signature is None:
        # broadcast everything, one output
        dims = tuple(dim_sizes)
        return _ElemwiseSignature(dims, [dims])

    core_dims = set(itertools.chain.from_iterable(
        itertools.chain(gufunc_signature.inputs, gufunc_signature.outputs)))
    broadcast_dims = tuple(d for d in dim_sizes if d not in core_dims)
    # ``tuple(out)`` so that signatures given as lists of lists also work.
    output_dims = [broadcast_dims + tuple(out)
                   for out in gufunc_signature.outputs]
    return _ElemwiseSignature(broadcast_dims, output_dims)
def apply_variable_ufunc(func, args, allow_dask=True, gufunc_signature=None,
                         combine_attrs=None, kwargs=None):
    """Apply ``func`` to the broadcast data of Variable arguments.

    Parameters
    ----------
    func : callable
        Called as ``func(*list_of_data, **kwargs)``. Variable arguments are
        replaced by their data broadcast to the signature's broadcast
        dimensions; other arguments are passed through unchanged.
    args : sequence
        Variables and/or other objects.
    allow_dask : bool, optional
        Passed through to ``_broadcast_variable_data_to``.
    gufunc_signature : optional
        Passed to ``_build_and_check_signature``; determines the number of
        outputs and their dimensions.
    combine_attrs : callable, optional
        Called as ``combine_attrs(input_attrs, func, n)`` once per output
        index ``n``. By default every output gets ``None`` as attrs.
    kwargs : dict, optional
        Extra keyword arguments for ``func``.

    Returns
    -------
    A single Variable when the signature has one output, otherwise a tuple
    of Variables (one per output).
    """
    if kwargs is None:
        kwargs = {}

    if combine_attrs is None:
        # Must accept the same three arguments used at the call site below.
        def combine_attrs(attrs, func, n):
            return None

    # NOTE(review): ``variables`` was referenced but never defined in the
    # visible code; assumes the signature should be built from the Variable
    # arguments only -- confirm against ``_calculate_unified_dim_sizes``.
    variables = [a for a in args if isinstance(a, Variable)]
    sig = _build_and_check_signature(variables, gufunc_signature)

    n_out = len(sig.output_dims)
    input_attrs = [getattr(a, 'attrs', {}) for a in args]
    result_attrs = [combine_attrs(input_attrs, func, n) for n in range(n_out)]

    list_of_data = []
    for arg in args:
        if isinstance(arg, Variable):
            data = _broadcast_variable_data_to(arg, sig.broadcast_dims,
                                               allow_dask=allow_dask)
        else:
            data = arg
        list_of_data.append(data)

    result_data = func(*list_of_data, **kwargs)

    if n_out > 1:
        return tuple(
            Variable(dims, data, attrs)
            for dims, data, attrs in zip(
                sig.output_dims, result_data, result_attrs))

    # Single output: ``result_data`` is the data itself, while the dims and
    # attrs are one-element lists that need unpacking.
    dims, = sig.output_dims
    attrs, = result_attrs
    return Variable(dims, result_data, attrs)