diff --git a/versioneer.py b/versioneer.py index 64fea1c..2b54540 100644 --- a/versioneer.py +++ b/versioneer.py @@ -1,4 +1,3 @@ - # Version: 0.18 """The Versioneer - like a rocketeer, but for versions. @@ -277,6 +276,7 @@ """ from __future__ import print_function + try: import configparser except ImportError: @@ -308,11 +308,13 @@ def get_root(): setup_py = os.path.join(root, "setup.py") versioneer_py = os.path.join(root, "versioneer.py") if not (os.path.exists(setup_py) or os.path.exists(versioneer_py)): - err = ("Versioneer was unable to run the project root directory. " - "Versioneer requires setup.py to be executed from " - "its immediate directory (like 'python setup.py COMMAND'), " - "or in a way that lets it use sys.argv[0] to find the root " - "(like 'python path/to/setup.py COMMAND').") + err = ( + "Versioneer was unable to run the project root directory. " + "Versioneer requires setup.py to be executed from " + "its immediate directory (like 'python setup.py COMMAND'), " + "or in a way that lets it use sys.argv[0] to find the root " + "(like 'python path/to/setup.py COMMAND')." + ) raise VersioneerBadRootError(err) try: # Certain runtime workflows (setup.py install/develop in a setuptools @@ -325,8 +327,10 @@ def get_root(): me_dir = os.path.normcase(os.path.splitext(me)[0]) vsr_dir = os.path.normcase(os.path.splitext(versioneer_py)[0]) if me_dir != vsr_dir: - print("Warning: build in %s is using versioneer.py from %s" - % (os.path.dirname(me), versioneer_py)) + print( + "Warning: build in %s is using versioneer.py from %s" + % (os.path.dirname(me), versioneer_py) + ) except NameError: pass return root @@ -348,6 +352,7 @@ def get(parser, name): if parser.has_option("versioneer", name): return parser.get("versioneer", name) return None + cfg = VersioneerConfig() cfg.VCS = VCS cfg.style = get(parser, "style") or "" @@ -372,17 +377,18 @@ class NotThisMethod(Exception): def register_vcs_handler(vcs, method): # decorator """Decorator to mark a method as the handler for a particular VCS.""" + def decorate(f): """Store f in HANDLERS[vcs][method].""" if vcs not in HANDLERS: HANDLERS[vcs] = {} HANDLERS[vcs][method] = f return f + return decorate -def run_command(commands, args, cwd=None, verbose=False, hide_stderr=False, - env=None): +def run_command(commands, args, cwd=None, verbose=False, hide_stderr=False, env=None): """Call the given command(s).""" assert isinstance(commands, list) p = None @@ -390,10 +396,13 @@ def run_command(commands, args, cwd=None, verbose=False, hide_stderr=False, try: dispcmd = str([c] + args) # remember shell=False, so use git.cmd on windows, not just git - p = subprocess.Popen([c] + args, cwd=cwd, env=env, - stdout=subprocess.PIPE, - stderr=(subprocess.PIPE if hide_stderr - else None)) + p = subprocess.Popen( + [c] + args, + cwd=cwd, + env=env, + stdout=subprocess.PIPE, + stderr=(subprocess.PIPE if hide_stderr else None), + ) break except EnvironmentError: e = sys.exc_info()[1] @@ -418,7 +427,9 @@ def run_command(commands, args, cwd=None, verbose=False, hide_stderr=False, return stdout, p.returncode -LONG_VERSION_PY['git'] = ''' +LONG_VERSION_PY[ + "git" +] = ''' # This file helps to compute a version number in source trees obtained from # git-archive tarball (such as those provided by githubs download-from-tag # feature). Distribution tarballs (built by setup.py sdist) and build @@ -993,7 +1004,7 @@ def git_versions_from_keywords(keywords, tag_prefix, verbose): # starting in git-1.8.3, tags are listed as "tag: foo-1.0" instead of # just "foo-1.0". If we see a "tag: " prefix, prefer those. TAG = "tag: " - tags = set([r[len(TAG):] for r in refs if r.startswith(TAG)]) + tags = set([r[len(TAG) :] for r in refs if r.startswith(TAG)]) if not tags: # Either we're using git < 1.8.3, or there really are no tags. We use # a heuristic: assume all version tags have a digit. The old git %d @@ -1002,7 +1013,7 @@ def git_versions_from_keywords(keywords, tag_prefix, verbose): # between branches and tags. By ignoring refnames without digits, we # filter out many common branch names like "release" and # "stabilization", as well as "HEAD" and "master". - tags = set([r for r in refs if re.search(r'\d', r)]) + tags = set([r for r in refs if re.search(r"\d", r)]) if verbose: print("discarding '%s', no digits" % ",".join(refs - tags)) if verbose: @@ -1010,19 +1021,26 @@ def git_versions_from_keywords(keywords, tag_prefix, verbose): for ref in sorted(tags): # sorting will prefer e.g. "2.0" over "2.0rc1" if ref.startswith(tag_prefix): - r = ref[len(tag_prefix):] + r = ref[len(tag_prefix) :] if verbose: print("picking %s" % r) - return {"version": r, - "full-revisionid": keywords["full"].strip(), - "dirty": False, "error": None, - "date": date} + return { + "version": r, + "full-revisionid": keywords["full"].strip(), + "dirty": False, + "error": None, + "date": date, + } # no suitable tags, so version is "0+unknown", but full hex is still there if verbose: print("no suitable tags, using unknown + full revision id") - return {"version": "0+unknown", - "full-revisionid": keywords["full"].strip(), - "dirty": False, "error": "no suitable tags", "date": None} + return { + "version": "0+unknown", + "full-revisionid": keywords["full"].strip(), + "dirty": False, + "error": "no suitable tags", + "date": None, + } @register_vcs_handler("git", "pieces_from_vcs") @@ -1037,8 +1055,7 @@ def git_pieces_from_vcs(tag_prefix, root, verbose, run_command=run_command): if sys.platform == "win32": GITS = ["git.cmd", "git.exe"] - out, rc = run_command(GITS, ["rev-parse", "--git-dir"], cwd=root, - hide_stderr=True) + out, rc = run_command(GITS, ["rev-parse", "--git-dir"], cwd=root, hide_stderr=True) if rc != 0: if verbose: print("Directory %s not under git control" % root) @@ -1046,10 +1063,19 @@ def git_pieces_from_vcs(tag_prefix, root, verbose, run_command=run_command): # if there is a tag matching tag_prefix, this yields TAG-NUM-gHEX[-dirty] # if there isn't one, this yields HEX[-dirty] (no NUM) - describe_out, rc = run_command(GITS, ["describe", "--tags", "--dirty", - "--always", "--long", - "--match", "%s*" % tag_prefix], - cwd=root) + describe_out, rc = run_command( + GITS, + [ + "describe", + "--tags", + "--dirty", + "--always", + "--long", + "--match", + "%s*" % tag_prefix, + ], + cwd=root, + ) # --long was added in git-1.5.5 if describe_out is None: raise NotThisMethod("'git describe' failed") @@ -1072,17 +1098,16 @@ def git_pieces_from_vcs(tag_prefix, root, verbose, run_command=run_command): dirty = git_describe.endswith("-dirty") pieces["dirty"] = dirty if dirty: - git_describe = git_describe[:git_describe.rindex("-dirty")] + git_describe = git_describe[: git_describe.rindex("-dirty")] # now we have TAG-NUM-gHEX or HEX if "-" in git_describe: # TAG-NUM-gHEX - mo = re.search(r'^(.+)-(\d+)-g([0-9a-f]+)$', git_describe) + mo = re.search(r"^(.+)-(\d+)-g([0-9a-f]+)$", git_describe) if not mo: # unparseable. Maybe git-describe is misbehaving? - pieces["error"] = ("unable to parse git-describe output: '%s'" - % describe_out) + pieces["error"] = "unable to parse git-describe output: '%s'" % describe_out return pieces # tag @@ -1091,10 +1116,12 @@ def git_pieces_from_vcs(tag_prefix, root, verbose, run_command=run_command): if verbose: fmt = "tag '%s' doesn't start with prefix '%s'" print(fmt % (full_tag, tag_prefix)) - pieces["error"] = ("tag '%s' doesn't start with prefix '%s'" - % (full_tag, tag_prefix)) + pieces["error"] = "tag '%s' doesn't start with prefix '%s'" % ( + full_tag, + tag_prefix, + ) return pieces - pieces["closest-tag"] = full_tag[len(tag_prefix):] + pieces["closest-tag"] = full_tag[len(tag_prefix) :] # distance: number of commits since tag pieces["distance"] = int(mo.group(2)) @@ -1105,13 +1132,13 @@ def git_pieces_from_vcs(tag_prefix, root, verbose, run_command=run_command): else: # HEX: no tags pieces["closest-tag"] = None - count_out, rc = run_command(GITS, ["rev-list", "HEAD", "--count"], - cwd=root) + count_out, rc = run_command(GITS, ["rev-list", "HEAD", "--count"], cwd=root) pieces["distance"] = int(count_out) # total number of commits # commit date: see ISO-8601 comment in git_versions_from_keywords() - date = run_command(GITS, ["show", "-s", "--format=%ci", "HEAD"], - cwd=root)[0].strip() + date = run_command(GITS, ["show", "-s", "--format=%ci", "HEAD"], cwd=root)[ + 0 + ].strip() pieces["date"] = date.strip().replace(" ", "T", 1).replace(" ", "", 1) return pieces @@ -1167,16 +1194,22 @@ def versions_from_parentdir(parentdir_prefix, root, verbose): for i in range(3): dirname = os.path.basename(root) if dirname.startswith(parentdir_prefix): - return {"version": dirname[len(parentdir_prefix):], - "full-revisionid": None, - "dirty": False, "error": None, "date": None} + return { + "version": dirname[len(parentdir_prefix) :], + "full-revisionid": None, + "dirty": False, + "error": None, + "date": None, + } else: rootdirs.append(root) root = os.path.dirname(root) # up a level if verbose: - print("Tried directories %s but none started with prefix %s" % - (str(rootdirs), parentdir_prefix)) + print( + "Tried directories %s but none started with prefix %s" + % (str(rootdirs), parentdir_prefix) + ) raise NotThisMethod("rootdir doesn't start with parentdir_prefix") @@ -1205,11 +1238,13 @@ def versions_from_file(filename): contents = f.read() except EnvironmentError: raise NotThisMethod("unable to read _version.py") - mo = re.search(r"version_json = '''\n(.*)''' # END VERSION_JSON", - contents, re.M | re.S) + mo = re.search( + r"version_json = '''\n(.*)''' # END VERSION_JSON", contents, re.M | re.S + ) if not mo: - mo = re.search(r"version_json = '''\r\n(.*)''' # END VERSION_JSON", - contents, re.M | re.S) + mo = re.search( + r"version_json = '''\r\n(.*)''' # END VERSION_JSON", contents, re.M | re.S + ) if not mo: raise NotThisMethod("no version_json in _version.py") return json.loads(mo.group(1)) @@ -1218,8 +1253,7 @@ def versions_from_file(filename): def write_to_version_file(filename, versions): """Write the given version number to the given _version.py file.""" os.unlink(filename) - contents = json.dumps(versions, sort_keys=True, - indent=1, separators=(",", ": ")) + contents = json.dumps(versions, sort_keys=True, indent=1, separators=(",", ": ")) with open(filename, "w") as f: f.write(SHORT_VERSION_PY % contents) @@ -1251,8 +1285,7 @@ def render_pep440(pieces): rendered += ".dirty" else: # exception #1 - rendered = "0+untagged.%d.g%s" % (pieces["distance"], - pieces["short"]) + rendered = "0+untagged.%d.g%s" % (pieces["distance"], pieces["short"]) if pieces["dirty"]: rendered += ".dirty" return rendered @@ -1366,11 +1399,13 @@ def render_git_describe_long(pieces): def render(pieces, style): """Render the given version pieces into the requested style.""" if pieces["error"]: - return {"version": "unknown", - "full-revisionid": pieces.get("long"), - "dirty": None, - "error": pieces["error"], - "date": None} + return { + "version": "unknown", + "full-revisionid": pieces.get("long"), + "dirty": None, + "error": pieces["error"], + "date": None, + } if not style or style == "default": style = "pep440" # the default @@ -1390,9 +1425,13 @@ def render(pieces, style): else: raise ValueError("unknown style '%s'" % style) - return {"version": rendered, "full-revisionid": pieces["long"], - "dirty": pieces["dirty"], "error": None, - "date": pieces.get("date")} + return { + "version": rendered, + "full-revisionid": pieces["long"], + "dirty": pieces["dirty"], + "error": None, + "date": pieces.get("date"), + } class VersioneerBadRootError(Exception): @@ -1415,8 +1454,9 @@ def get_versions(verbose=False): handlers = HANDLERS.get(cfg.VCS) assert handlers, "unrecognized VCS '%s'" % cfg.VCS verbose = verbose or cfg.verbose - assert cfg.versionfile_source is not None, \ - "please set versioneer.versionfile_source" + assert ( + cfg.versionfile_source is not None + ), "please set versioneer.versionfile_source" assert cfg.tag_prefix is not None, "please set versioneer.tag_prefix" versionfile_abs = os.path.join(root, cfg.versionfile_source) @@ -1470,9 +1510,13 @@ def get_versions(verbose=False): if verbose: print("unable to compute version") - return {"version": "0+unknown", "full-revisionid": None, - "dirty": None, "error": "unable to compute version", - "date": None} + return { + "version": "0+unknown", + "full-revisionid": None, + "dirty": None, + "error": "unable to compute version", + "date": None, + } def get_version(): @@ -1521,6 +1565,7 @@ def run(self): print(" date: %s" % vers.get("date")) if vers["error"]: print(" error: %s" % vers["error"]) + cmds["version"] = cmd_version # we override "build_py" in both distutils and setuptools @@ -1553,14 +1598,15 @@ def run(self): # now locate _version.py in the new build/ directory and replace # it with an updated value if cfg.versionfile_build: - target_versionfile = os.path.join(self.build_lib, - cfg.versionfile_build) + target_versionfile = os.path.join(self.build_lib, cfg.versionfile_build) print("UPDATING %s" % target_versionfile) write_to_version_file(target_versionfile, versions) + cmds["build_py"] = cmd_build_py if "cx_Freeze" in sys.modules: # cx_freeze enabled? from cx_Freeze.dist import build_exe as _build_exe + # nczeczulin reports that py2exe won't like the pep440-style string # as FILEVERSION, but it can be used for PRODUCTVERSION, e.g. # setup(console=[{ @@ -1581,17 +1627,21 @@ def run(self): os.unlink(target_versionfile) with open(cfg.versionfile_source, "w") as f: LONG = LONG_VERSION_PY[cfg.VCS] - f.write(LONG % - {"DOLLAR": "$", - "STYLE": cfg.style, - "TAG_PREFIX": cfg.tag_prefix, - "PARENTDIR_PREFIX": cfg.parentdir_prefix, - "VERSIONFILE_SOURCE": cfg.versionfile_source, - }) + f.write( + LONG + % { + "DOLLAR": "$", + "STYLE": cfg.style, + "TAG_PREFIX": cfg.tag_prefix, + "PARENTDIR_PREFIX": cfg.parentdir_prefix, + "VERSIONFILE_SOURCE": cfg.versionfile_source, + } + ) + cmds["build_exe"] = cmd_build_exe del cmds["build_py"] - if 'py2exe' in sys.modules: # py2exe enabled? + if "py2exe" in sys.modules: # py2exe enabled? try: from py2exe.distutils_buildexe import py2exe as _py2exe # py3 except ImportError: @@ -1610,13 +1660,17 @@ def run(self): os.unlink(target_versionfile) with open(cfg.versionfile_source, "w") as f: LONG = LONG_VERSION_PY[cfg.VCS] - f.write(LONG % - {"DOLLAR": "$", - "STYLE": cfg.style, - "TAG_PREFIX": cfg.tag_prefix, - "PARENTDIR_PREFIX": cfg.parentdir_prefix, - "VERSIONFILE_SOURCE": cfg.versionfile_source, - }) + f.write( + LONG + % { + "DOLLAR": "$", + "STYLE": cfg.style, + "TAG_PREFIX": cfg.tag_prefix, + "PARENTDIR_PREFIX": cfg.parentdir_prefix, + "VERSIONFILE_SOURCE": cfg.versionfile_source, + } + ) + cmds["py2exe"] = cmd_py2exe # we override different "sdist" commands for both environments @@ -1643,8 +1697,10 @@ def make_release_tree(self, base_dir, files): # updated value target_versionfile = os.path.join(base_dir, cfg.versionfile_source) print("UPDATING %s" % target_versionfile) - write_to_version_file(target_versionfile, - self._versioneer_generated_versions) + write_to_version_file( + target_versionfile, self._versioneer_generated_versions + ) + cmds["sdist"] = cmd_sdist return cmds @@ -1699,11 +1755,13 @@ def do_setup(): root = get_root() try: cfg = get_config_from_root(root) - except (EnvironmentError, configparser.NoSectionError, - configparser.NoOptionError) as e: + except ( + EnvironmentError, + configparser.NoSectionError, + configparser.NoOptionError, + ) as e: if isinstance(e, (EnvironmentError, configparser.NoSectionError)): - print("Adding sample versioneer config to setup.cfg", - file=sys.stderr) + print("Adding sample versioneer config to setup.cfg", file=sys.stderr) with open(os.path.join(root, "setup.cfg"), "a") as f: f.write(SAMPLE_CONFIG) print(CONFIG_ERROR, file=sys.stderr) @@ -1712,15 +1770,18 @@ def do_setup(): print(" creating %s" % cfg.versionfile_source) with open(cfg.versionfile_source, "w") as f: LONG = LONG_VERSION_PY[cfg.VCS] - f.write(LONG % {"DOLLAR": "$", - "STYLE": cfg.style, - "TAG_PREFIX": cfg.tag_prefix, - "PARENTDIR_PREFIX": cfg.parentdir_prefix, - "VERSIONFILE_SOURCE": cfg.versionfile_source, - }) - - ipy = os.path.join(os.path.dirname(cfg.versionfile_source), - "__init__.py") + f.write( + LONG + % { + "DOLLAR": "$", + "STYLE": cfg.style, + "TAG_PREFIX": cfg.tag_prefix, + "PARENTDIR_PREFIX": cfg.parentdir_prefix, + "VERSIONFILE_SOURCE": cfg.versionfile_source, + } + ) + + ipy = os.path.join(os.path.dirname(cfg.versionfile_source), "__init__.py") if os.path.exists(ipy): try: with open(ipy, "r") as f: @@ -1762,8 +1823,10 @@ def do_setup(): else: print(" 'versioneer.py' already in MANIFEST.in") if cfg.versionfile_source not in simple_includes: - print(" appending versionfile_source ('%s') to MANIFEST.in" % - cfg.versionfile_source) + print( + " appending versionfile_source ('%s') to MANIFEST.in" + % cfg.versionfile_source + ) with open(manifest_in, "a") as f: f.write("include %s\n" % cfg.versionfile_source) else: diff --git a/xhistogram/core.py b/xhistogram/core.py index 3470506..97a1c7e 100644 --- a/xhistogram/core.py +++ b/xhistogram/core.py @@ -7,7 +7,7 @@ import numpy as np from functools import reduce from collections.abc import Iterable -from .duck_array_ops import ( +from numpy import ( digitize, bincount, reshape, @@ -19,6 +19,20 @@ # range is a keyword so save the builtin so they can use it. _range = range +try: + import dask.array as dsa + + has_dask = True +except ImportError: + has_dask = False + + +def _any_dask_array(*args): + if not has_dask: + return False + else: + return any(isinstance(a, dsa.core.Array) for a in args) + def _ensure_correctly_formatted_bins(bins, N_expected): # TODO: This could be done better / more robustly @@ -120,7 +134,7 @@ def _dispatch_bincount(bin_indices, weights, N, hist_shapes, block_size=None): return _bincount_loop(bin_indices, weights, N, hist_shapes, block_chunks) -def _histogram_2d_vectorized( +def _bincount_2d_vectorized( *args, bins=None, weights=None, density=False, right=False, block_size=None ): """Calculate the histogram independently on each row of a 2D array""" @@ -178,6 +192,57 @@ def _histogram_2d_vectorized( return bin_counts +def _bincount(*all_arrays, weights=False, axis=None, bins=None, density=None): + + # is this necessary? + all_arrays_broadcast = broadcast_arrays(*all_arrays) + + a0 = all_arrays_broadcast[0] + + do_full_array = (axis is None) or (set(axis) == set(_range(a0.ndim))) + + if do_full_array: + kept_axes_shape = (1,) * a0.ndim + else: + kept_axes_shape = tuple( + [a0.shape[i] if i not in axis else 1 for i in _range(a0.ndim)] + ) + + def reshape_input(a): + if do_full_array: + d = a.ravel()[None, :] + else: + # reshape the array to 2D + # axis 0: preserved axis after histogram + # axis 1: calculate histogram along this axis + new_pos = tuple(_range(-len(axis), 0)) + c = np.moveaxis(a, axis, new_pos) + split_idx = c.ndim - len(axis) + dims_0 = c.shape[:split_idx] + # assert dims_0 == kept_axes_shape + dims_1 = c.shape[split_idx:] + new_dim_0 = np.prod(dims_0) + new_dim_1 = np.prod(dims_1) + d = reshape(c, (new_dim_0, new_dim_1)) + return d + + all_arrays_reshaped = [reshape_input(a) for a in all_arrays_broadcast] + + if weights: + weights_array = all_arrays_reshaped.pop() + else: + weights_array = None + + bin_counts = _bincount_2d_vectorized( + *all_arrays_reshaped, bins=bins, weights=weights_array, density=density + ) + + final_shape = kept_axes_shape + bin_counts.shape[1:] + bin_counts = reshape(bin_counts, final_shape) + + return bin_counts + + def histogram( *args, bins=None, @@ -280,43 +345,24 @@ def histogram( ax_positive = ndim + ax assert ax_positive < ndim, "axis must be less than ndim" axis_normed.append(ax_positive) - axis = np.atleast_1d(axis_normed) + axis = [int(i) for i in axis_normed] - do_full_array = (axis is None) or (set(axis) == set(_range(a0.ndim))) - if do_full_array: - kept_axes_shape = None - else: - kept_axes_shape = tuple([a0.shape[i] for i in _range(a0.ndim) if i not in axis]) + all_arrays = list(args) + n_inputs = len(all_arrays) - all_args = list(args) if weights is not None: - all_args += [weights] - all_args_broadcast = broadcast_arrays(*all_args) - - def reshape_input(a): - if do_full_array: - d = a.ravel()[None, :] - else: - # reshape the array to 2D - # axis 0: preserved axis after histogram - # axis 1: calculate histogram along this axis - new_pos = tuple(_range(-len(axis), 0)) - c = np.moveaxis(a, axis, new_pos) - split_idx = c.ndim - len(axis) - dims_0 = c.shape[:split_idx] - assert dims_0 == kept_axes_shape - dims_1 = c.shape[split_idx:] - new_dim_0 = np.prod(dims_0) - new_dim_1 = np.prod(dims_1) - d = reshape(c, (new_dim_0, new_dim_1)) - return d + all_arrays.append(weights) + has_weights = True + else: + has_weights = False - all_args_reshaped = [reshape_input(a) for a in all_args_broadcast] + dtype = "i8" if not has_weights else weights.dtype - if weights is not None: - weights_reshaped = all_args_reshaped.pop() - else: - weights_reshaped = None + # here I am assuming all the arrays have the same shape + # probably needs to be generalized + input_indexes = [tuple(_range(a.ndim)) for a in all_arrays] + input_index = input_indexes[0] + assert all([ii == input_index for ii in input_indexes]) # Some sanity checks and format bins and range correctly bins = _ensure_correctly_formatted_bins(bins, n_inputs) @@ -332,17 +378,58 @@ def reshape_input(a): ) else: bins = [ - np.histogram_bin_edges(a, b, r, weights_reshaped) - for a, b, r in zip(all_args_reshaped, bins, range) + np.histogram_bin_edges(a, b, r) for a, b, r in zip(all_arrays, bins, range) ] + bincount_kwargs = dict(weights=has_weights, axis=axis, bins=bins, density=density) - bin_counts = _histogram_2d_vectorized( - *all_args_reshaped, - bins=bins, - weights=weights_reshaped, - density=density, - block_size=block_size, - ) + # keep these axes in the inputs + if axis is not None: + drop_axes = tuple([ii for ii in input_index if ii in axis]) + else: + drop_axes = input_index + + if _any_dask_array(weights, *all_arrays): + # We should be able to just apply the bin_count function to every + # block and then sum over all blocks to get the total bin count. + # The main challenge is to figure out the chunk shape that will come + # out of _bincount. We might also need to add dummy dimensions to sum + # over in the _bincount function + import dask.array as dsa + + # Important note from blockwise docs + # > Any index, like i missing from the output index is interpreted as a contraction... + # > In the case of a contraction the passed function should expect an iterable of blocks + # > on any array that holds that index. + # This means that we need to have all the input indexes present in the output index + # However, they will be reduced to singleton (len 1) dimensions + + adjust_chunks = {i: (lambda x: 1) for i in drop_axes} + + new_axes = { + max(input_index) + 1 + i: axis_len + for i, axis_len in enumerate([len(bin) - 1 for bin in bins]) + } + out_index = input_index + tuple(new_axes) + + blockwise_args = [] + for arg in all_arrays: + blockwise_args.append(arg) + blockwise_args.append(input_index) + + bin_counts = dsa.blockwise( + _bincount, + out_index, + *blockwise_args, + new_axes=new_axes, + adjust_chunks=adjust_chunks, + meta=np.array((), dtype), + **bincount_kwargs, + ) + # sum over the block dims + bin_counts = bin_counts.sum(drop_axes) + else: + # drop the extra axis used for summing over blocks + bin_counts = _bincount(*all_arrays, **bincount_kwargs).squeeze(drop_axes) if density: # Normalise by dividing by bin counts and areas such that all the @@ -360,11 +447,4 @@ def reshape_input(a): else: h = bin_counts - if h.shape[0] == 1: - assert do_full_array - h = h.squeeze() - else: - final_shape = kept_axes_shape + h.shape[1:] - h = reshape(h, final_shape) - return h, bins diff --git a/xhistogram/duck_array_ops.py b/xhistogram/duck_array_ops.py deleted file mode 100644 index b9e632e..0000000 --- a/xhistogram/duck_array_ops.py +++ /dev/null @@ -1,39 +0,0 @@ -"""Compatibility module defining operations on duck numpy-arrays. -Shamelessly copied from xarray.""" - -import numpy as np - -try: - import dask.array as dsa - - has_dask = True -except ImportError: - has_dask = False - - -def _dask_or_eager_func(name, eager_module=np, list_of_args=False, n_array_args=1): - """Create a function that dispatches to dask for dask array inputs.""" - if has_dask: - - def f(*args, **kwargs): - dispatch_args = args[0] if list_of_args else args - if any(isinstance(a, dsa.Array) for a in dispatch_args[:n_array_args]): - module = dsa - else: - module = eager_module - return getattr(module, name)(*args, **kwargs) - - else: - - def f(*args, **kwargs): - return getattr(eager_module, name)(*args, **kwargs) - - return f - - -digitize = _dask_or_eager_func("digitize") -bincount = _dask_or_eager_func("bincount") -reshape = _dask_or_eager_func("reshape") -concatenate = _dask_or_eager_func("concatenate", list_of_args=True) -broadcast_arrays = _dask_or_eager_func("broadcast_arrays") -ravel_multi_index = _dask_or_eager_func("ravel_multi_index") diff --git a/xhistogram/test/test_duck_array_ops.py b/xhistogram/test/test_duck_array_ops.py deleted file mode 100644 index db3cbd6..0000000 --- a/xhistogram/test/test_duck_array_ops.py +++ /dev/null @@ -1,80 +0,0 @@ -import numpy as np -import dask.array as dsa -from ..duck_array_ops import ( - digitize, - bincount, - reshape, - ravel_multi_index, - broadcast_arrays, -) -from .fixtures import empty_dask_array -import pytest - - -@pytest.mark.parametrize( - "function, args", - [ - (digitize, [np.random.rand(5, 12), np.linspace(0, 1, 7)]), - (bincount, [np.arange(10)]), - ], -) -def test_eager(function, args): - a = function(*args) - assert isinstance(a, np.ndarray) - - -@pytest.mark.parametrize( - "function, args, kwargs", - [ - (digitize, [empty_dask_array((5, 12)), np.linspace(0, 1, 7)], {}), - (bincount, [empty_dask_array((10,))], {"minlength": 5}), - (reshape, [empty_dask_array((10, 5)), (5, 10)], {}), - (ravel_multi_index, (empty_dask_array((10,)), empty_dask_array((10,))), {}), - ], -) -def test_lazy(function, args, kwargs): - # make sure nothing computes - a = function(*args, **kwargs) - assert isinstance(a, dsa.core.Array) - - -@pytest.mark.parametrize("chunks", [(5, 12), (1, 12), (5, 1)]) -def test_digitize_dask_correct(chunks): - a = np.random.rand(5, 12) - da = dsa.from_array(a, chunks=chunks) - bins = np.linspace(0, 1, 7) - d = digitize(a, bins) - dd = digitize(da, bins) - np.testing.assert_array_equal(d, dd.compute()) - - -def test_ravel_multi_index_correct(): - arr = np.array([[3, 6, 6], [4, 5, 1]]) - expected = np.ravel_multi_index(arr, (7, 6)) - actual = ravel_multi_index(arr, (7, 6)) - np.testing.assert_array_equal(expected, actual) - - expected = np.ravel_multi_index(arr, (7, 6), order="F") - actual = ravel_multi_index(arr, (7, 6), order="F") - np.testing.assert_array_equal(expected, actual) - - -def test_broadcast_arrays_numpy(): - a1 = np.empty((1, 5, 25)) - a2 = np.empty((4, 1, 1)) - - a1b, a2b = broadcast_arrays(a1, a2) - assert a1b.shape == (4, 5, 25) - assert a2b.shape == (4, 5, 25) - - -@pytest.mark.parametrize("d1_chunks", [(5 * (1,), (25,)), ((2, 3), (25,))]) -def test_broadcast_arrays_dask(d1_chunks): - d1 = dsa.empty((5, 25), chunks=d1_chunks) - d2 = dsa.empty((1, 25), chunks=(1, 25)) - - d1b, d2b = broadcast_arrays(d1, d2) - assert d1b.shape == (5, 25) - assert d2b.shape == (5, 25) - assert d1b.chunks == d1_chunks - assert d2b.chunks == d1_chunks