From 4b3ae356a327d69dbc3bb246b21299be22a096a7 Mon Sep 17 00:00:00 2001 From: Stephan Hoyer Date: Sun, 2 Nov 2014 03:01:15 -0500 Subject: [PATCH 1/7] API/ENH: IntervalIndex Fixes GH7640, GH8625 --- pandas/core/algorithms.py | 8 +- pandas/core/api.py | 1 + pandas/core/common.py | 6 +- pandas/core/index.py | 16 +- pandas/core/interval.py | 521 ++++++++++ pandas/hashtable.pxd | 14 + pandas/hashtable.pyx | 12 +- pandas/lib.pyx | 2 + pandas/src/generate_intervaltree.py | 401 +++++++ pandas/src/inference.pyx | 17 + pandas/src/interval.pyx | 131 +++ pandas/src/intervaltree.pyx | 1492 +++++++++++++++++++++++++++ pandas/tests/test_algos.py | 16 +- pandas/tests/test_base.py | 12 +- pandas/tests/test_categorical.py | 14 +- pandas/tests/test_frame.py | 11 + pandas/tests/test_index.py | 23 +- pandas/tests/test_indexing.py | 28 +- pandas/tests/test_interval.py | 537 ++++++++++ pandas/tools/tests/test_tile.py | 116 +-- pandas/tools/tile.py | 124 ++- pandas/util/testing.py | 8 +- 22 files changed, 3365 insertions(+), 145 deletions(-) create mode 100644 pandas/core/interval.py create mode 100644 pandas/src/generate_intervaltree.py create mode 100644 pandas/src/interval.pyx create mode 100644 pandas/src/intervaltree.pyx create mode 100644 pandas/tests/test_interval.py diff --git a/pandas/core/algorithms.py b/pandas/core/algorithms.py index e5347f03b5462..5ff780b5e5593 100644 --- a/pandas/core/algorithms.py +++ b/pandas/core/algorithms.py @@ -262,10 +262,9 @@ def value_counts(values, sort=True, ascending=False, normalize=False, if bins is not None: try: - cat, bins = cut(values, bins, retbins=True) + values, bins = cut(values, bins, retbins=True) except TypeError: raise TypeError("bins argument only works with numeric data.") - values = cat.codes if com.is_categorical_dtype(values.dtype): result = values.value_counts(dropna) @@ -320,11 +319,6 @@ def value_counts(values, sort=True, ascending=False, normalize=False, keys = Index(keys) result = Series(counts, index=keys, name=name) - if 
bins is not None: - # TODO: This next line should be more efficient - result = result.reindex(np.arange(len(cat.categories)), fill_value=0) - result.index = bins[:-1] - if sort: result = result.sort_values(ascending=ascending) diff --git a/pandas/core/api.py b/pandas/core/api.py index e2ac57e37cba6..8c446b0922e72 100644 --- a/pandas/core/api.py +++ b/pandas/core/api.py @@ -9,6 +9,7 @@ from pandas.core.groupby import Grouper from pandas.core.format import set_eng_float_format from pandas.core.index import Index, CategoricalIndex, Int64Index, Float64Index, MultiIndex +from pandas.core.interval import Interval, IntervalIndex from pandas.core.series import Series, TimeSeries from pandas.core.frame import DataFrame diff --git a/pandas/core/common.py b/pandas/core/common.py index e81b58a3f7eef..ddc74b7b34899 100644 --- a/pandas/core/common.py +++ b/pandas/core/common.py @@ -83,6 +83,7 @@ def _check(cls, inst): ABCTimedeltaIndex = create_pandas_abc_type("ABCTimedeltaIndex", "_typ", ("timedeltaindex",)) ABCPeriodIndex = create_pandas_abc_type("ABCPeriodIndex", "_typ", ("periodindex",)) ABCCategoricalIndex = create_pandas_abc_type("ABCCategoricalIndex", "_typ", ("categoricalindex",)) +ABCIntervalIndex = create_pandas_abc_type("ABCIntervalIndex", "_typ", ("intervalindex",)) ABCIndexClass = create_pandas_abc_type("ABCIndexClass", "_typ", ("index", "int64index", "float64index", @@ -90,8 +91,9 @@ def _check(cls, inst): "datetimeindex", "timedeltaindex", "periodindex", - "categoricalindex")) - + "categoricalindex", + "intervalindex")) +ABCInterval = create_pandas_abc_type("ABCInterval", "_typ", ("interval",)) ABCSeries = create_pandas_abc_type("ABCSeries", "_typ", ("series",)) ABCDataFrame = create_pandas_abc_type("ABCDataFrame", "_typ", ("dataframe",)) ABCPanel = create_pandas_abc_type("ABCPanel", "_typ", ("panel",)) diff --git a/pandas/core/index.py b/pandas/core/index.py index 1433d755d294d..26a0cd8e4f0b5 100644 --- a/pandas/core/index.py +++ b/pandas/core/index.py @@ -173,6 
+173,9 @@ def __new__(cls, data=None, dtype=None, copy=False, name=None, fastpath=False, return Int64Index(subarr.astype('i8'), copy=copy, name=name) elif inferred in ['floating', 'mixed-integer-float']: return Float64Index(subarr, copy=copy, name=name) + elif inferred == 'interval': + from pandas.core.interval import IntervalIndex + return IntervalIndex.from_intervals(subarr, name=name) elif inferred == 'boolean': # don't support boolean explicity ATM pass @@ -829,7 +832,7 @@ def _mpl_repr(self): @property def is_monotonic(self): """ alias for is_monotonic_increasing (deprecated) """ - return self._engine.is_monotonic_increasing + return self.is_monotonic_increasing @property def is_monotonic_increasing(self): @@ -1633,7 +1636,7 @@ def union(self, other): def _wrap_union_result(self, other, result): name = self.name if self.name == other.name else None - return self.__class__(data=result, name=name) + return self._constructor(data=result, name=name) def intersection(self, other): """ @@ -2671,6 +2674,13 @@ def _searchsorted_monotonic(self, label, side='left'): raise ValueError('index must be monotonic increasing or decreasing') + def _get_loc_only_exact_matches(self, key): + """ + This is overridden on subclasses (namely, IntervalIndex) to control + get_slice_bound. + """ + return self.get_loc(key) + def get_slice_bound(self, label, side, kind): """ Calculate slice bound that corresponds to given label.
@@ -2698,7 +2708,7 @@ def get_slice_bound(self, label, side, kind): # we need to look up the label try: - slc = self.get_loc(label) + slc = self._get_loc_only_exact_matches(label) except KeyError as err: try: return self._searchsorted_monotonic(label, side) diff --git a/pandas/core/interval.py b/pandas/core/interval.py new file mode 100644 index 0000000000000..68e07f21367a0 --- /dev/null +++ b/pandas/core/interval.py @@ -0,0 +1,521 @@ +import operator + +import numpy as np +import pandas as pd + +from pandas.core.base import PandasObject, IndexOpsMixin +from pandas.core.common import (_values_from_object, _ensure_platform_int, + notnull, is_datetime_or_timedelta_dtype, + is_integer_dtype, is_float_dtype) +from pandas.core.index import (Index, _ensure_index, default_pprint, + InvalidIndexError, MultiIndex) +from pandas.lib import (Interval, IntervalMixin, IntervalTree, + interval_bounds_to_intervals, + intervals_to_interval_bounds) +from pandas.util.decorators import cache_readonly +import pandas.core.common as com + + +_VALID_CLOSED = set(['left', 'right', 'both', 'neither']) + + +def _get_next_label(label): + dtype = getattr(label, 'dtype', type(label)) + if isinstance(label, (pd.Timestamp, pd.Timedelta)): + dtype = 'datetime64' + if is_datetime_or_timedelta_dtype(dtype): + return label + np.timedelta64(1, 'ns') + elif is_integer_dtype(dtype): + return label + 1 + elif is_float_dtype(dtype): + return np.nextafter(label, np.infty) + else: + raise TypeError('cannot determine next label for type %r' + % type(label)) + + +def _get_prev_label(label): + dtype = getattr(label, 'dtype', type(label)) + if isinstance(label, (pd.Timestamp, pd.Timedelta)): + dtype = 'datetime64' + if is_datetime_or_timedelta_dtype(dtype): + return label - np.timedelta64(1, 'ns') + elif is_integer_dtype(dtype): + return label - 1 + elif is_float_dtype(dtype): + return np.nextafter(label, -np.infty) + else: + raise TypeError('cannot determine next label for type %r' + % type(label)) + + +def 
_get_interval_closed_bounds(interval): + """ + Given an Interval or IntervalIndex, return the corresponding interval with + closed bounds. + """ + left, right = interval.left, interval.right + if interval.open_left: + left = _get_next_label(left) + if interval.open_right: + right = _get_prev_label(right) + return left, right + + +class IntervalIndex(IntervalMixin, Index): + """ + Immutable Index implementing an ordered, sliceable set. IntervalIndex + represents an Index of intervals that are all closed on the same side. + + .. versionadded:: 0.18 + + Properties + ---------- + left, right : array-like (1-dimensional) + Left and right bounds for each interval. + closed : {'left', 'right', 'both', 'neither'}, optional + Whether the intervals are closed on the left-side, right-side, both or + neither. Defaults to 'right'. + name : object, optional + Name to be stored in the index. + """ + _typ = 'intervalindex' + _comparables = ['name'] + _attributes = ['name', 'closed'] + _allow_index_ops = True + _engine = None # disable it + + def __new__(cls, left, right, closed='right', name=None, fastpath=False): + # TODO: validation + result = IntervalMixin.__new__(cls) + result._left = _ensure_index(left) + result._right = _ensure_index(right) + result._closed = closed + result.name = name + if not fastpath: + result._validate() + result._reset_identity() + return result + + def _validate(self): + """Verify that the IntervalIndex is valid. + """ + # TODO: exclude periods? 
+ if self.closed not in _VALID_CLOSED: + raise ValueError("invalid options for 'closed': %s" % self.closed) + if len(self.left) != len(self.right): + raise ValueError('left and right must have the same length') + left_valid = notnull(self.left) + right_valid = notnull(self.right) + if not (left_valid == right_valid).all(): + raise ValueError('missing values must be missing in the same ' + 'location both left and right sides') + if not (self.left[left_valid] <= self.right[left_valid]).all(): + raise ValueError('left side of interval must be <= right side') + + def _simple_new(cls, values, name=None, **kwargs): + # ensure we don't end up here (this is a superclass method) + raise NotImplementedError + + def _cleanup(self): + pass + + @property + def _engine(self): + raise NotImplementedError + + @cache_readonly + def _tree(self): + return IntervalTree(self.left, self.right, closed=self.closed) + + @property + def _constructor(self): + return type(self).from_intervals + + @classmethod + def from_breaks(cls, breaks, closed='right', name=None): + """ + Construct an IntervalIndex from an array of splits + + Parameters + ---------- + breaks : array-like (1-dimensional) + Left and right bounds for each interval. + closed : {'left', 'right', 'both', 'neither'}, optional + Whether the intervals are closed on the left-side, right-side, both + or neither. Defaults to 'right'. + name : object, optional + Name to be stored in the index. + + Examples + -------- + + >>> IntervalIndex.from_breaks([0, 1, 2, 3]) + IntervalIndex(left=[0, 1, 2], + right=[1, 2, 3], + closed='right') + """ + return cls(breaks[:-1], breaks[1:], closed, name) + + @classmethod + def from_intervals(cls, data, name=None): + """ + Construct an IntervalIndex from a 1d array of Interval objects + + Parameters + ---------- + data : array-like (1-dimensional) + Array of Interval objects. All intervals must be closed on the same + sides. + name : object, optional + Name to be stored in the index. 
+ + Examples + -------- + + >>> IntervalIndex.from_intervals([Interval(0, 1), Interval(1, 2)]) + IntervalIndex(left=[0, 1], + right=[1, 2], + closed='right') + + The generic Index constructor works identically when it infers an array + of all intervals: + + >>> Index([Interval(0, 1), Interval(1, 2)]) + IntervalIndex(left=[0, 1], + right=[1, 2], + closed='right') + """ + data = np.asarray(data) + left, right, closed = intervals_to_interval_bounds(data) + return cls(left, right, closed, name) + + @classmethod + def from_tuples(cls, data, closed='right', name=None): + left = [] + right = [] + for l, r in data: + left.append(l) + right.append(r) + return cls(np.array(left), np.array(right), closed, name) + + def to_tuples(self): + return Index(com._asarray_tuplesafe(zip(self.left, self.right))) + + @cache_readonly + def _multiindex(self): + return MultiIndex.from_arrays([self.left, self.right], + names=['left', 'right']) + + @property + def left(self): + return self._left + + @property + def right(self): + return self._right + + @property + def closed(self): + return self._closed + + def __len__(self): + return len(self.left) + + @cache_readonly + def values(self): + """Returns the IntervalIndex's data as a numpy array of Interval + objects (with dtype='object') + """ + left = np.asarray(self.left) + right = np.asarray(self.right) + return interval_bounds_to_intervals(left, right, self.closed) + + def __array__(self, result=None): + """ the array interface, return my values """ + return self.values + + def __array_wrap__(self, result, context=None): + # we don't want the superclass implementation + return result + + def _array_values(self): + return self.values + + def __reduce__(self): + return self.__class__, (self.left, self.right, self.closed, self.name) + + def _shallow_copy(self, values=None, name=None): + name = name if name is not None else self.name + if values is not None: + return type(self).from_intervals(values, name=name) + else: + return
self.copy(name=name) + + def copy(self, deep=False, name=None): + left = self.left.copy(deep=True) if deep else self.left + right = self.right.copy(deep=True) if deep else self.right + name = name if name is not None else self.name + return type(self)(left, right, closed=self.closed, name=name, + fastpath=True) + + @cache_readonly + def dtype(self): + return np.dtype('O') + + @cache_readonly + def mid(self): + """Returns the mid-point of each interval in the index as an array + """ + try: + return Index(0.5 * (self.left.values + self.right.values)) + except TypeError: + # datetime safe version + delta = self.right.values - self.left.values + return Index(self.left.values + 0.5 * delta) + + @cache_readonly + def is_monotonic_increasing(self): + return self._multiindex.is_monotonic_increasing + + @cache_readonly + def is_monotonic_decreasing(self): + return self._multiindex.is_monotonic_decreasing + + @cache_readonly + def is_unique(self): + return self._multiindex.is_unique + + @cache_readonly + def is_non_overlapping_monotonic(self): + # must be increasing (e.g., [0, 1), [1, 2), [2, 3), ... ) + # or decreasing (e.g., [-1, 0), [-2, -1), [-3, -2), ...) + # we already require left <= right + return ((self.right[:-1] <= self.left[1:]).all() or + (self.left[:-1] >= self.right[1:]).all()) + + def _convert_scalar_indexer(self, key, kind=None): + return key + + def _maybe_cast_slice_bound(self, label, side, kind): + return getattr(self, side)._maybe_cast_slice_bound(label, side, kind) + + def _convert_list_indexer(self, keyarr, kind=None): + """ + we are passed a list indexer. 
+ Return our indexer or raise if any of the values are not included in the index + """ + locs = self.get_indexer(keyarr) + # TODO: handle keyarr if it includes intervals + if (locs == -1).any(): + raise KeyError("a list-indexer must only include existing intervals") + + return locs + + def _check_method(self, method): + if method is not None: + raise NotImplementedError( + 'method %r not yet implemented for IntervalIndex' % method) + + def _searchsorted_monotonic(self, label, side, exclude_label=False): + if not self.is_non_overlapping_monotonic: + raise KeyError('can only get slices from an IntervalIndex if ' + 'bounds are non-overlapping and all monotonic ' + 'increasing or decreasing') + + if isinstance(label, IntervalMixin): + raise NotImplementedError + + if ((side == 'left' and self.left.is_monotonic_increasing) or + (side == 'right' and self.left.is_monotonic_decreasing)): + sub_idx = self.right + if self.open_right or exclude_label: + label = _get_next_label(label) + else: + sub_idx = self.left + if self.open_left or exclude_label: + label = _get_prev_label(label) + + return sub_idx._searchsorted_monotonic(label, side) + + def _get_loc_only_exact_matches(self, key): + return self._multiindex._tuple_index.get_loc(key) + + def _find_non_overlapping_monotonic_bounds(self, key): + if isinstance(key, IntervalMixin): + start = self._searchsorted_monotonic( + key.left, 'left', exclude_label=key.open_left) + stop = self._searchsorted_monotonic( + key.right, 'right', exclude_label=key.open_right) + else: + # scalar + start = self._searchsorted_monotonic(key, 'left') + stop = self._searchsorted_monotonic(key, 'right') + return start, stop + + def get_loc(self, key, method=None): + self._check_method(method) + + original_key = key + + if self.is_non_overlapping_monotonic: + if isinstance(key, Interval): + left = self._maybe_cast_slice_bound(key.left, 'left', None) + right = self._maybe_cast_slice_bound(key.right, 'right', None) + key = Interval(left, right,
key.closed) + else: + key = self._maybe_cast_slice_bound(key, 'left', None) + + start, stop = self._find_non_overlapping_monotonic_bounds(key) + + if start + 1 == stop: + return start + elif start < stop: + return slice(start, stop) + else: + raise KeyError(original_key) + + else: + # use the interval tree + if isinstance(key, Interval): + left, right = _get_interval_closed_bounds(key) + return self._tree.get_loc_interval(left, right) + else: + return self._tree.get_loc(key) + + def get_value(self, series, key): + # this method seems necessary for Series.__getitem__ but I have no idea + # what it should actually do here... + loc = self.get_loc(key) # nb. this can't handle slice objects + return series.iloc[loc] + + def get_indexer(self, target, method=None, limit=None, tolerance=None): + self._check_method(method) + target = _ensure_index(target) + + if self.is_non_overlapping_monotonic: + start, stop = self._find_non_overlapping_monotonic_bounds(target) + + start_plus_one = start + 1 + if (start_plus_one < stop).any(): + raise ValueError('indexer corresponds to non-unique elements') + return np.where(start_plus_one == stop, start, -1) + + else: + if isinstance(target, IntervalIndex): + raise NotImplementedError( + 'have not yet implemented get_indexer ' + 'for IntervalIndex indexers') + else: + return self._tree.get_indexer(target) + + def delete(self, loc): + new_left = self.left.delete(loc) + new_right = self.right.delete(loc) + return type(self)(new_left, new_right, self.closed, self.name, + fastpath=True) + + def insert(self, loc, item): + if not isinstance(item, Interval): + raise ValueError('can only insert Interval objects into an ' + 'IntervalIndex') + if not item.closed == self.closed: + raise ValueError('inserted item must be closed on the same side ' + 'as the index') + new_left = self.left.insert(loc, item.left) + new_right = self.right.insert(loc, item.right) + return type(self)(new_left, new_right, self.closed, self.name, + fastpath=True) + + def 
_as_like_interval_index(self, other, error_msg): + self._assert_can_do_setop(other) + other = _ensure_index(other) + if (not isinstance(other, IntervalIndex) or + self.closed != other.closed): + raise ValueError(error_msg) + return other + + def append(self, other): + msg = ('can only append two IntervalIndex objects that are closed on ' + 'the same side') + other = self._as_like_interval_index(other, msg) + new_left = self.left.append(other.left) + new_right = self.right.append(other.right) + if other.name is not None and other.name != self.name: + name = None + else: + name = self.name + return type(self)(new_left, new_right, self.closed, name, + fastpath=True) + + def take(self, indexer, axis=0): + indexer = com._ensure_platform_int(indexer) + new_left = self.left.take(indexer) + new_right = self.right.take(indexer) + return type(self)(new_left, new_right, self.closed, self.name, + fastpath=True) + + def __contains__(self, key): + try: + self.get_loc(key) + return True + except KeyError: + return False + + def __getitem__(self, value): + left = self.left[value] + right = self.right[value] + if not isinstance(left, Index): + return Interval(left, right, self.closed) + else: + return type(self)(left, right, self.closed, self.name) + + # __repr__ associated methods are based on MultiIndex + + def _format_attrs(self): + attrs = [('left', default_pprint(self.left)), + ('right', default_pprint(self.right)), + ('closed', repr(self.closed))] + if self.name is not None: + attrs.append(('name', default_pprint(self.name))) + return attrs + + def _format_space(self): + return "\n%s" % (' ' * (len(self.__class__.__name__) + 1)) + + def _format_data(self): + return None + + def argsort(self, *args, **kwargs): + return np.lexsort((self.right, self.left)) + + def equals(self, other): + if self.is_(other): + return True + try: + return (self.left.equals(other.left) + and self.right.equals(other.right) + and self.closed == other.closed) + except AttributeError: + return False + + 
def _setop(op_name): + def func(self, other): + msg = ('can only do set operations between two IntervalIndex ' + 'objects that are closed on the same side') + other = self._as_like_interval_index(other, msg) + result = getattr(self._multiindex, op_name)(other._multiindex) + result_name = self.name if self.name == other.name else None + return type(self).from_tuples(result.values, closed=self.closed, + name=result_name) + return func + + union = _setop('union') + intersection = _setop('intersection') + difference = _setop('difference') + sym_diff = _setop('sym_diff') + + # TODO: arithmetic operations + + +IntervalIndex._add_logical_methods_disabled() diff --git a/pandas/hashtable.pxd b/pandas/hashtable.pxd index 97b6687d061e9..593d8b8a1833a 100644 --- a/pandas/hashtable.pxd +++ b/pandas/hashtable.pxd @@ -1,4 +1,5 @@ from khash cimport kh_int64_t, kh_float64_t, kh_pymap_t, int64_t, float64_t +from numpy cimport ndarray # prototypes for sharing @@ -22,3 +23,16 @@ cdef class PyObjectHashTable(HashTable): cpdef get_item(self, object val) cpdef set_item(self, object key, Py_ssize_t val) + +cdef struct Int64VectorData: + int64_t *data + size_t n, m + +cdef class Int64Vector: + cdef Int64VectorData *data + cdef ndarray ao + + cdef resize(self) + cpdef to_array(self) + cdef inline void append(self, int64_t x) + cdef extend(self, int64_t[:] x) diff --git a/pandas/hashtable.pyx b/pandas/hashtable.pyx index dfa7930ada62f..c5fb65c1acee7 100644 --- a/pandas/hashtable.pyx +++ b/pandas/hashtable.pyx @@ -96,11 +96,8 @@ cdef void append_data(vector_data *data, sixty_four_bit_scalar x) nogil: data.data[data.n] = x data.n += 1 -cdef class Int64Vector: - cdef: - Int64VectorData *data - ndarray ao +cdef class Int64Vector: def __cinit__(self): self.data = PyMem_Malloc(sizeof(Int64VectorData)) @@ -122,7 +119,7 @@ cdef class Int64Vector: def __len__(self): return self.data.n - def to_array(self): + cpdef to_array(self): self.ao.resize(self.data.n) self.data.m = self.data.n return self.ao 
@@ -134,6 +131,11 @@ cdef class Int64Vector: append_data(self.data, x) + cdef extend(self, int64_t[:] x): + for i in range(len(x)): + self.append(x[i]) + + cdef class Float64Vector: cdef: diff --git a/pandas/lib.pyx b/pandas/lib.pyx index f7978c4791538..e0390eeb4e1f7 100644 --- a/pandas/lib.pyx +++ b/pandas/lib.pyx @@ -1896,4 +1896,6 @@ cdef class BlockPlacement: include "reduce.pyx" include "properties.pyx" +include "interval.pyx" +include "intervaltree.pyx" include "inference.pyx" diff --git a/pandas/src/generate_intervaltree.py b/pandas/src/generate_intervaltree.py new file mode 100644 index 0000000000000..c2dfac86f0ad2 --- /dev/null +++ b/pandas/src/generate_intervaltree.py @@ -0,0 +1,401 @@ +""" +This file generates `intervaltree.pyx` which is then included in `../lib.pyx` +during building. To regenerate `intervaltree.pyx`, just run: + + `python generate_intervaltree.py`. +""" +from __future__ import print_function +import os +from pandas.compat import StringIO +import numpy as np + + +warning_to_new_contributors = """ +# DO NOT EDIT THIS FILE: This file was autogenerated from +# generate_intervaltree.py, so please edit that file and then run +# `python2 generate_intervaltree.py` to re-generate this file. 
+""" + +header = r''' +from numpy cimport int64_t, float64_t +from numpy cimport ndarray, PyArray_ArgSort, NPY_QUICKSORT, PyArray_Take +import numpy as np + +cimport cython +cimport numpy as cnp +cnp.import_array() + +from hashtable cimport Int64Vector, Int64VectorData + + +ctypedef fused scalar64_t: + float64_t + int64_t + + +NODE_CLASSES = {} + + +cdef class IntervalTree(IntervalMixin): + """A centered interval tree + + Based off the algorithm described on Wikipedia: + http://en.wikipedia.org/wiki/Interval_tree + """ + cdef: + readonly object left, right, root + readonly str closed + object _left_sorter, _right_sorter + + def __init__(self, left, right, closed='right', leaf_size=100): + """ + Parameters + ---------- + left, right : np.ndarray[ndim=1] + Left and right bounds for each interval. Assumed to contain no + NaNs. + closed : {'left', 'right', 'both', 'neither'}, optional + Whether the intervals are closed on the left-side, right-side, both + or neither. Defaults to 'right'. + leaf_size : int, optional + Parameter that controls when the tree switches from creating nodes + to brute-force search. Tune this parameter to optimize query + performance. 
+ """ + if closed not in ['left', 'right', 'both', 'neither']: + raise ValueError("invalid option for 'closed': %s" % closed) + + left = np.asarray(left) + right = np.asarray(right) + dtype = np.result_type(left, right) + self.left = np.asarray(left, dtype=dtype) + self.right = np.asarray(right, dtype=dtype) + + indices = np.arange(len(left), dtype='int64') + + self.closed = closed + + node_cls = NODE_CLASSES[str(dtype), closed] + self.root = node_cls(self.left, self.right, indices, leaf_size) + + @property + def left_sorter(self): + """How to sort the left labels; this is used for binary search + """ + if self._left_sorter is None: + self._left_sorter = np.argsort(self.left) + return self._left_sorter + + @property + def right_sorter(self): + """How to sort the right labels + """ + if self._right_sorter is None: + self._right_sorter = np.argsort(self.right) + return self._right_sorter + + def get_loc(self, scalar64_t key): + """Return all positions corresponding to intervals that overlap with + the given scalar key + """ + result = Int64Vector() + self.root.query(result, key) + if not result.data.n: + raise KeyError(key) + return result.to_array() + + def _get_partial_overlap(self, key_left, key_right, side): + """Return all positions corresponding to intervals with the given side + falling between the left and right bounds of an interval query + """ + if side == 'left': + values = self.left + sorter = self.left_sorter + else: + values = self.right + sorter = self.right_sorter + key = [key_left, key_right] + i, j = values.searchsorted(key, sorter=sorter) + return sorter[i:j] + + def get_loc_interval(self, key_left, key_right): + """Lookup the intervals enclosed in the given interval bounds + + The given interval is presumed to have closed bounds. 
+ """ + import pandas as pd + left_overlap = self._get_partial_overlap(key_left, key_right, 'left') + right_overlap = self._get_partial_overlap(key_left, key_right, 'right') + enclosing = self.get_loc(0.5 * (key_left + key_right)) + combined = np.concatenate([left_overlap, right_overlap, enclosing]) + uniques = pd.unique(combined) + return uniques + + def get_indexer(self, scalar64_t[:] target): + """Return the positions corresponding to unique intervals that overlap + with the given array of scalar targets. + """ + # TODO: write get_indexer_intervals + cdef: + int64_t old_len, i + Int64Vector result + + result = Int64Vector() + old_len = 0 + for i in range(len(target)): + self.root.query(result, target[i]) + if result.data.n == old_len: + result.append(-1) + elif result.data.n > old_len + 1: + raise KeyError( + 'indexer does not intersect a unique set of intervals') + old_len = result.data.n + return result.to_array() + + def get_indexer_non_unique(self, scalar64_t[:] target): + """Return the positions corresponding to intervals that overlap with + the given array of scalar targets. Non-unique positions are repeated. 
+ """ + cdef: + int64_t old_len, i + Int64Vector result, missing + + result = Int64Vector() + missing = Int64Vector() + old_len = 0 + for i in range(len(target)): + self.root.query(result, target[i]) + if result.data.n == old_len: + result.append(-1) + missing.append(i) + old_len = result.data.n + return result.to_array(), missing.to_array() + + def __repr__(self): + return ('<IntervalTree: %s elements>' + % self.root.n_elements) + + +cdef take(ndarray source, ndarray indices): + """Take the given positions from a 1D ndarray + """ + return PyArray_Take(source, indices, 0) + + +cdef sort_values_and_indices(all_values, all_indices, subset): + indices = take(all_indices, subset) + values = take(all_values, subset) + sorter = PyArray_ArgSort(values, 0, NPY_QUICKSORT) + sorted_values = take(values, sorter) + sorted_indices = take(indices, sorter) + return sorted_values, sorted_indices +''' + +# we need specialized nodes and leaves to optimize for different dtype and +# closed values +# unfortunately, fused dtypes can't parameterize attributes on extension types, +# so we're stuck using template generation. + +node_template = r''' +cdef class {dtype_title}Closed{closed_title}IntervalNode: + """Non-terminal node for an IntervalTree + + Categorizes intervals by those that fall to the left, those that fall to + the right, and those that overlap with the pivot.
+ """ + cdef: + readonly left_node, right_node + readonly {dtype}_t[:] center_left_values, center_right_values + readonly int64_t[:] center_left_indices, center_right_indices + readonly {dtype}_t pivot + readonly int64_t n_elements, leaf_size + + def __init__(self, + ndarray[{dtype}_t, ndim=1] left, + ndarray[{dtype}_t, ndim=1] right, + ndarray[int64_t, ndim=1] indices, + int64_t leaf_size): + + self.pivot = np.median(left + right) / 2 + self.n_elements = len(left) + self.leaf_size = leaf_size + + left_set, right_set, center_set = self.classify_intervals(left, right) + + self.left_node = self.new_child_node(left, right, indices, left_set) + self.right_node = self.new_child_node(left, right, indices, right_set) + + self.center_left_values, self.center_left_indices = \ + sort_values_and_indices(left, indices, center_set) + self.center_right_values, self.center_right_indices = \ + sort_values_and_indices(right, indices, center_set) + + @cython.wraparound(False) + @cython.boundscheck(False) + cdef classify_intervals(self, {dtype}_t[:] left, {dtype}_t[:] right): + """Classify the given intervals based upon whether they fall to the + left, right, or overlap with this node's pivot. + """ + cdef: + int i + Int64Vector left_ind, right_ind, overlapping_ind + + left_ind = Int64Vector() + right_ind = Int64Vector() + overlapping_ind = Int64Vector() + + for i in range(len(left)): + if right[i] {cmp_right_converse} self.pivot: + left_ind.append(i) + elif self.pivot {cmp_left_converse} left[i]: + right_ind.append(i) + else: + overlapping_ind.append(i) + + return (left_ind.to_array(), + right_ind.to_array(), + overlapping_ind.to_array()) + + cdef new_child_node(self, + ndarray[{dtype}_t, ndim=1] left, + ndarray[{dtype}_t, ndim=1] right, + ndarray[int64_t, ndim=1] indices, + ndarray[int64_t, ndim=1] subset): + """Create a new child node. + + This should be a terminal leaf node if the number of indices is smaller + than leaf_size. Otherwise it should be a non-terminal node. 
+ """ + + left = take(left, subset) + right = take(right, subset) + indices = take(indices, subset) + + if len(indices) <= self.leaf_size: + return {dtype_title}Closed{closed_title}IntervalLeaf( + left, right, indices) + else: + return {dtype_title}Closed{closed_title}IntervalNode( + left, right, indices, self.leaf_size) + + @cython.wraparound(False) + @cython.boundscheck(False) + cpdef query(self, Int64Vector result, scalar64_t point): + """Recursively query this node and its sub-nodes for intervals that + overlap with the query point. + """ + cdef: + int64_t[:] indices + {dtype}_t[:] values + int i + + if point < self.pivot: + values = self.center_left_values + indices = self.center_left_indices + for i in range(len(values)): + if not values[i] {cmp_left} point: + break + result.append(indices[i]) + self.left_node.query(result, point) + elif point > self.pivot: + values = self.center_right_values + indices = self.center_right_indices + for i in range(len(values) - 1, -1, -1): + if not point {cmp_right} values[i]: + break + result.append(indices[i]) + self.right_node.query(result, point) + else: + result.extend(self.center_left_indices) + + def __repr__(self): + return ('<{dtype_title}Closed{closed_title}IntervalNode: pivot %s, ' + '%s elements (%s left, %s right, %s overlapping)>' % + (self.pivot, self.n_elements, self.left_node.n_elements, + self.right_node.n_elements, len(self.center_left_indices))) + + def counts(self): + m = len(self.center_left_values) + l = self.left_node.counts() + r = self.right_node.counts() + return (m, (l, r)) + +NODE_CLASSES['{dtype}', '{closed}'] = {dtype_title}Closed{closed_title}IntervalNode + + +cdef class {dtype_title}Closed{closed_title}IntervalLeaf: + """Terminal node for an IntervalTree + + Once we get down to a certain size, it doesn't make sense to continue the + binary tree structure. Instead, we store interval bounds in 1d arrays and + use linear search.
+ """ + cdef: + readonly {dtype}_t[:] left, right + readonly int64_t[:] indices + + def __init__(self, + {dtype}_t[:] left, + {dtype}_t[:] right, + int64_t[:] indices): + self.left = left + self.right = right + self.indices = indices + + @cython.wraparound(False) + @cython.boundscheck(False) + cpdef query(self, Int64Vector result, scalar64_t point): + for i in range(len(self.left)): + if self.left[i] {cmp_left} point {cmp_right} self.right[i]: + result.append(self.indices[i]) + + def __repr__(self): + return ('<{dtype_title}Closed{closed_title}IntervalLeaf: %s elements>' + % self.n_elements) + + @property + def n_elements(self): + return len(self.left) + + def counts(self): + return self.n_elements +''' + + +def generate_node_template(): + output = StringIO() + for dtype in ['float64', 'int64']: + for closed, cmp_left, cmp_right in [ + ('left', '<=', '<'), + ('right', '<', '<='), + ('both', '<=', '<='), + ('neither', '<', '<')]: + cmp_left_converse = '<' if cmp_left == '<=' else '<=' + cmp_right_converse = '<' if cmp_right == '<=' else '<=' + classes = node_template.format(dtype=dtype, + dtype_title=dtype.title(), + closed=closed, + closed_title=closed.title(), + cmp_left=cmp_left, + cmp_right=cmp_right, + cmp_left_converse=cmp_left_converse, + cmp_right_converse=cmp_right_converse) + output.write(classes) + output.write("\n") + return output.getvalue() + + +def generate_cython_file(): + # Put `intervaltree.pyx` in the same directory as this file + directory = os.path.dirname(os.path.realpath(__file__)) + filename = 'intervaltree.pyx' + path = os.path.join(directory, filename) + + with open(path, 'w') as f: + print(warning_to_new_contributors, file=f) + print(header, file=f) + print(generate_node_template(), file=f) + + +if __name__ == '__main__': + generate_cython_file() diff --git a/pandas/src/inference.pyx b/pandas/src/inference.pyx index 1a5703eb91053..4b3716cbcb5ae 100644 --- a/pandas/src/inference.pyx +++ b/pandas/src/inference.pyx @@ -194,6 +194,10 @@ def 
infer_dtype(object _values):
         if is_period_array(values):
             return 'period'
 
+    elif is_interval(val):
+        if is_interval_array(values):
+            return 'interval'
+
     for i in range(n):
         val = util.get_value_1d(values, i)
         if util.is_integer_object(val):
@@ -539,6 +543,19 @@ def is_period_array(ndarray[object] values):
             return False
     return True
 
+cdef inline bint is_interval(object o):
+    return isinstance(o, Interval)
+
+def is_interval_array(ndarray[object] values):
+    cdef Py_ssize_t i, n = len(values)
+
+    if n == 0:
+        return False
+    for i in range(n):
+        if not is_interval(values[i]):
+            return False
+    return True
+
 cdef extern from "parse_helper.h":
     inline int floatify(object, double *result, int *maybe_int) except -1
 
diff --git a/pandas/src/interval.pyx b/pandas/src/interval.pyx
new file mode 100644
index 0000000000000..5e51e34a9b5d4
--- /dev/null
+++ b/pandas/src/interval.pyx
@@ -0,0 +1,148 @@
+cimport numpy as np
+import numpy as np
+import pandas as pd
+
+cimport cython
+import cython
+
+from cpython.object cimport (Py_EQ, Py_NE, Py_GT, Py_LT, Py_GE, Py_LE,
+                             PyObject_RichCompare)
+
+
+_VALID_CLOSED = frozenset(['left', 'right', 'both', 'neither'])
+
+
+cdef class IntervalMixin:
+    property closed_left:
+        def __get__(self):
+            return self.closed == 'left' or self.closed == 'both'
+
+    property closed_right:
+        def __get__(self):
+            return self.closed == 'right' or self.closed == 'both'
+
+    property open_left:
+        def __get__(self):
+            return not self.closed_left
+
+    property open_right:
+        def __get__(self):
+            return not self.closed_right
+
+    property mid:
+        def __get__(self):
+            try:
+                return 0.5 * (self.left + self.right)
+            except TypeError:
+                # datetime safe version
+                return self.left + 0.5 * (self.right - self.left)
+
+
+cdef _interval_like(other):
+    return (hasattr(other, 'left')
+            and hasattr(other, 'right')
+            and hasattr(other, 'closed'))
+
+
+cdef class Interval(IntervalMixin):
+    cdef readonly object left, right
+    cdef readonly str closed
+
+    def __init__(self, left, right, str closed='right'):
+        # note: it is faster to just do these checks than to use a special
+        # constructor (__cinit__/__new__) to avoid them
+        if closed not in _VALID_CLOSED:
+            raise ValueError("invalid option for 'closed': %s" % closed)
+        if not left <= right:
+            raise ValueError('left side of interval must be <= right side')
+        self.left = left
+        self.right = right
+        self.closed = closed
+
+    def __hash__(self):
+        return hash((self.left, self.right, self.closed))
+
+    def __contains__(self, key):
+        if _interval_like(key):
+            raise TypeError('__contains__ not defined for two intervals')
+        return ((self.left < key if self.open_left else self.left <= key) and
+                (key < self.right if self.open_right else key <= self.right))
+
+    def __richcmp__(self, other, int op):
+        if hasattr(other, 'ndim'):
+            # let numpy (or IntervalIndex) handle vectorization
+            return NotImplemented
+
+        if _interval_like(other):
+            self_tuple = (self.left, self.right, self.closed)
+            other_tuple = (other.left, other.right, other.closed)
+            return PyObject_RichCompare(self_tuple, other_tuple, op)
+
+        # nb. could just return NotImplemented now, but handling this
+        # explicitly allows us to opt into the Python 3 behavior, even on
+        # Python 2.
+        if op == Py_EQ or op == Py_NE:
+            return NotImplemented
+        else:
+            op_str = {Py_LT: '<', Py_LE: '<=', Py_GT: '>', Py_GE: '>='}[op]
+            raise TypeError('unorderable types: %s() %s %s()' %
+                            (type(self).__name__, op_str, type(other).__name__))
+
+    def __reduce__(self):
+        args = (self.left, self.right, self.closed)
+        return (type(self), args)
+
+    def __repr__(self):
+        return ('%s(%r, %r, closed=%r)' %
+                (type(self).__name__, self.left, self.right, self.closed))
+
+    def __str__(self):
+        start_symbol = '[' if self.closed_left else '('
+        end_symbol = ']' if self.closed_right else ')'
+        return '%s%s, %s%s' % (start_symbol, self.left, self.right, end_symbol)
+
+
+@cython.wraparound(False)
+@cython.boundscheck(False)
+cpdef interval_bounds_to_intervals(np.ndarray left, np.ndarray right,
+                                   str closed):
+    """Build an object array of Interval scalars from arrays of bounds.
+
+    Positions where either bound is null are filled with NaN instead.
+    """
+    result = np.empty(len(left), dtype=object)
+    nulls = pd.isnull(left) | pd.isnull(right)
+    result[nulls] = np.nan
+    for i in np.flatnonzero(~nulls):
+        result[i] = Interval(left[i], right[i], closed)
+    return result
+
+
+@cython.wraparound(False)
+@cython.boundscheck(False)
+cpdef intervals_to_interval_bounds(np.ndarray intervals):
+    """Split an object array of Interval scalars into their bounds.
+
+    Returns the tuple (left, right, closed). Null entries (such as the NaN
+    placeholders written by interval_bounds_to_intervals) produce NaN bounds
+    and are ignored when determining ``closed``.
+
+    Raises ValueError if the intervals are not all closed on the same side.
+    """
+    left = np.empty(len(intervals), dtype=object)
+    right = np.empty(len(intervals), dtype=object)
+    cdef str closed = None
+    for i in range(len(intervals)):
+        interval = intervals[i]
+        if pd.isnull(interval):
+            # a missing entry has no .left/.right/.closed attributes
+            left[i] = np.nan
+            right[i] = np.nan
+            continue
+        left[i] = interval.left
+        right[i] = interval.right
+        if closed is None:
+            closed = interval.closed
+        elif closed != interval.closed:
+            raise ValueError('intervals must all be closed on the same side')
+    return left, right, closed
diff --git a/pandas/src/intervaltree.pyx b/pandas/src/intervaltree.pyx
new file mode 100644
index 0000000000000..f3a7447bc09f8
--- /dev/null
+++ b/pandas/src/intervaltree.pyx
@@ -0,0 +1,1492 @@
+
+# DO NOT EDIT THIS FILE: This file was autogenerated from
+# generate_intervaltree.py, so please edit that file and then run
+# `python2 generate_intervaltree.py` to re-generate this file.
+ + +from numpy cimport int64_t, float64_t +from numpy cimport ndarray, PyArray_ArgSort, NPY_QUICKSORT, PyArray_Take +import numpy as np + +cimport cython +cimport numpy as cnp +cnp.import_array() + +from hashtable cimport Int64Vector, Int64VectorData + + +ctypedef fused scalar64_t: + float64_t + int64_t + + +NODE_CLASSES = {} + + +cdef class IntervalTree(IntervalMixin): + """A centered interval tree + + Based off the algorithm described on Wikipedia: + http://en.wikipedia.org/wiki/Interval_tree + """ + cdef: + readonly object left, right, root + readonly str closed + object _left_sorter, _right_sorter + + def __init__(self, left, right, closed='right', leaf_size=100): + """ + Parameters + ---------- + left, right : np.ndarray[ndim=1] + Left and right bounds for each interval. Assumed to contain no + NaNs. + closed : {'left', 'right', 'both', 'neither'}, optional + Whether the intervals are closed on the left-side, right-side, both + or neither. Defaults to 'right'. + leaf_size : int, optional + Parameter that controls when the tree switches from creating nodes + to brute-force search. Tune this parameter to optimize query + performance. 
+ """ + if closed not in ['left', 'right', 'both', 'neither']: + raise ValueError("invalid option for 'closed': %s" % closed) + + left = np.asarray(left) + right = np.asarray(right) + dtype = np.result_type(left, right) + self.left = np.asarray(left, dtype=dtype) + self.right = np.asarray(right, dtype=dtype) + + indices = np.arange(len(left), dtype='int64') + + self.closed = closed + + node_cls = NODE_CLASSES[str(dtype), closed] + self.root = node_cls(self.left, self.right, indices, leaf_size) + + @property + def left_sorter(self): + """How to sort the left labels; this is used for binary search + """ + if self._left_sorter is None: + self._left_sorter = np.argsort(self.left) + return self._left_sorter + + @property + def right_sorter(self): + """How to sort the right labels + """ + if self._right_sorter is None: + self._right_sorter = np.argsort(self.right) + return self._right_sorter + + def get_loc(self, scalar64_t key): + """Return all positions corresponding to intervals that overlap with + the given scalar key + """ + result = Int64Vector() + self.root.query(result, key) + if not result.data.n: + raise KeyError(key) + return result.to_array() + + def _get_partial_overlap(self, key_left, key_right, side): + """Return all positions corresponding to intervals with the given side + falling between the left and right bounds of an interval query + """ + if side == 'left': + values = self.left + sorter = self.left_sorter + else: + values = self.right + sorter = self.right_sorter + key = [key_left, key_right] + i, j = values.searchsorted(key, sorter=sorter) + return sorter[i:j] + + def get_loc_interval(self, key_left, key_right): + """Lookup the intervals enclosed in the given interval bounds + + The given interval is presumed to have closed bounds. 
+ """ + import pandas as pd + left_overlap = self._get_partial_overlap(key_left, key_right, 'left') + right_overlap = self._get_partial_overlap(key_left, key_right, 'right') + enclosing = self.get_loc(0.5 * (key_left + key_right)) + combined = np.concatenate([left_overlap, right_overlap, enclosing]) + uniques = pd.unique(combined) + return uniques + + def get_indexer(self, scalar64_t[:] target): + """Return the positions corresponding to unique intervals that overlap + with the given array of scalar targets. + """ + # TODO: write get_indexer_intervals + cdef: + int64_t old_len, i + Int64Vector result + + result = Int64Vector() + old_len = 0 + for i in range(len(target)): + self.root.query(result, target[i]) + if result.data.n == old_len: + result.append(-1) + elif result.data.n > old_len + 1: + raise KeyError( + 'indexer does not intersect a unique set of intervals') + old_len = result.data.n + return result.to_array() + + def get_indexer_non_unique(self, scalar64_t[:] target): + """Return the positions corresponding to intervals that overlap with + the given array of scalar targets. Non-unique positions are repeated. 
+        """
+        cdef:
+            int64_t old_len, i
+            Int64Vector result, missing
+
+        result = Int64Vector()
+        missing = Int64Vector()
+        old_len = 0
+        for i in range(len(target)):
+            self.root.query(result, target[i])
+            if result.data.n == old_len:
+                result.append(-1)
+                missing.append(i)
+            old_len = result.data.n
+        return result.to_array(), missing.to_array()
+
+    def __repr__(self):
+        return ('<IntervalTree: %s elements>'
+                % self.root.n_elements)
+
+
+cdef take(ndarray source, ndarray indices):
+    """Take the given positions from a 1D ndarray
+    """
+    return PyArray_Take(source, indices, 0)
+
+
+cdef sort_values_and_indices(all_values, all_indices, subset):
+    indices = take(all_indices, subset)
+    values = take(all_values, subset)
+    sorter = PyArray_ArgSort(values, 0, NPY_QUICKSORT)
+    sorted_values = take(values, sorter)
+    sorted_indices = take(indices, sorter)
+    return sorted_values, sorted_indices
+
+
+cdef class Float64ClosedLeftIntervalNode:
+    """Non-terminal node for an IntervalTree
+
+    Categorizes intervals by those that fall to the left, those that fall to
+    the right, and those that overlap with the pivot.
+ """ + cdef: + readonly left_node, right_node + readonly float64_t[:] center_left_values, center_right_values + readonly int64_t[:] center_left_indices, center_right_indices + readonly float64_t pivot + readonly int64_t n_elements, leaf_size + + def __init__(self, + ndarray[float64_t, ndim=1] left, + ndarray[float64_t, ndim=1] right, + ndarray[int64_t, ndim=1] indices, + int64_t leaf_size): + + self.pivot = np.median(left + right) / 2 + self.n_elements = len(left) + self.leaf_size = leaf_size + + left_set, right_set, center_set = self.classify_intervals(left, right) + + self.left_node = self.new_child_node(left, right, indices, left_set) + self.right_node = self.new_child_node(left, right, indices, right_set) + + self.center_left_values, self.center_left_indices = \ + sort_values_and_indices(left, indices, center_set) + self.center_right_values, self.center_right_indices = \ + sort_values_and_indices(right, indices, center_set) + + @cython.wraparound(False) + @cython.boundscheck(False) + cdef classify_intervals(self, float64_t[:] left, float64_t[:] right): + """Classify the given intervals based upon whether they fall to the + left, right, or overlap with this node's pivot. + """ + cdef: + int i + Int64Vector left_ind, right_ind, overlapping_ind + + left_ind = Int64Vector() + right_ind = Int64Vector() + overlapping_ind = Int64Vector() + + for i in range(len(left)): + if right[i] <= self.pivot: + left_ind.append(i) + elif self.pivot < left[i]: + right_ind.append(i) + else: + overlapping_ind.append(i) + + return (left_ind.to_array(), + right_ind.to_array(), + overlapping_ind.to_array()) + + cdef new_child_node(self, + ndarray[float64_t, ndim=1] left, + ndarray[float64_t, ndim=1] right, + ndarray[int64_t, ndim=1] indices, + ndarray[int64_t, ndim=1] subset): + """Create a new child node. + + This should be a terminal leaf node if the number of indices is smaller + than leaf_size. Otherwise it should be a non-terminal node. 
+        """
+
+        left = take(left, subset)
+        right = take(right, subset)
+        indices = take(indices, subset)
+
+        if len(indices) <= self.leaf_size:
+            return Float64ClosedLeftIntervalLeaf(
+                left, right, indices)
+        else:
+            return Float64ClosedLeftIntervalNode(
+                left, right, indices, self.leaf_size)
+
+    @cython.wraparound(False)
+    @cython.boundscheck(False)
+    cpdef query(self, Int64Vector result, scalar64_t point):
+        """Recursively query this node and its sub-nodes for intervals that
+        overlap with the query point.
+        """
+        cdef:
+            int64_t[:] indices
+            float64_t[:] values
+            int i
+
+        if point < self.pivot:
+            values = self.center_left_values
+            indices = self.center_left_indices
+            for i in range(len(values)):
+                if not values[i] <= point:
+                    break
+                result.append(indices[i])
+            self.left_node.query(result, point)
+        elif point > self.pivot:
+            values = self.center_right_values
+            indices = self.center_right_indices
+            for i in range(len(values) - 1, -1, -1):
+                if not point < values[i]:
+                    break
+                result.append(indices[i])
+            self.right_node.query(result, point)
+        else:
+            result.extend(self.center_left_indices)
+
+    def __repr__(self):
+        return ('<Float64ClosedLeftIntervalNode: pivot %s, '
+                '%s elements (%s left, %s right, %s overlapping)>' %
+                (self.pivot, self.n_elements, self.left_node.n_elements,
+                 self.right_node.n_elements, len(self.center_left_indices)))
+
+    def counts(self):
+        m = len(self.center_left_values)
+        l = self.left_node.counts()
+        r = self.right_node.counts()
+        return (m, (l, r))
+
+NODE_CLASSES['float64', 'left'] = Float64ClosedLeftIntervalNode
+
+
+cdef class Float64ClosedLeftIntervalLeaf:
+    """Terminal node for an IntervalTree
+
+    Once we get down to a certain size, it doesn't make sense to continue the
+    binary tree structure. Instead, we store interval bounds in 1d arrays and
+    use linear search.
+    """
+    cdef:
+        readonly float64_t[:] left, right
+        readonly int64_t[:] indices
+
+    def __init__(self,
+                 float64_t[:] left,
+                 float64_t[:] right,
+                 int64_t[:] indices):
+        self.left = left
+        self.right = right
+        self.indices = indices
+
+    @cython.wraparound(False)
+    @cython.boundscheck(False)
+    cpdef query(self, Int64Vector result, scalar64_t point):
+        for i in range(len(self.left)):
+            if self.left[i] <= point < self.right[i]:
+                result.append(self.indices[i])
+
+    def __repr__(self):
+        return ('<Float64ClosedLeftIntervalLeaf: %s elements>'
+                % self.n_elements)
+
+    @property
+    def n_elements(self):
+        return len(self.left)
+
+    def counts(self):
+        return self.n_elements
+
+
+cdef class Float64ClosedRightIntervalNode:
+    """Non-terminal node for an IntervalTree
+
+    Categorizes intervals by those that fall to the left, those that fall to
+    the right, and those that overlap with the pivot.
+    """
+    cdef:
+        readonly left_node, right_node
+        readonly float64_t[:] center_left_values, center_right_values
+        readonly int64_t[:] center_left_indices, center_right_indices
+        readonly float64_t pivot
+        readonly int64_t n_elements, leaf_size
+
+    def __init__(self,
+                 ndarray[float64_t, ndim=1] left,
+                 ndarray[float64_t, ndim=1] right,
+                 ndarray[int64_t, ndim=1] indices,
+                 int64_t leaf_size):
+
+        self.pivot = np.median(left + right) / 2
+        self.n_elements = len(left)
+        self.leaf_size = leaf_size
+
+        left_set, right_set, center_set = self.classify_intervals(left, right)
+
+        self.left_node = self.new_child_node(left, right, indices, left_set)
+        self.right_node = self.new_child_node(left, right, indices, right_set)
+
+        self.center_left_values, self.center_left_indices = \
+            sort_values_and_indices(left, indices, center_set)
+        self.center_right_values, self.center_right_indices = \
+            sort_values_and_indices(right, indices, center_set)
+
+    @cython.wraparound(False)
+    @cython.boundscheck(False)
+    cdef classify_intervals(self, float64_t[:] left, float64_t[:] right):
+        """Classify the given intervals based upon whether they fall to the
+
left, right, or overlap with this node's pivot. + """ + cdef: + int i + Int64Vector left_ind, right_ind, overlapping_ind + + left_ind = Int64Vector() + right_ind = Int64Vector() + overlapping_ind = Int64Vector() + + for i in range(len(left)): + if right[i] < self.pivot: + left_ind.append(i) + elif self.pivot <= left[i]: + right_ind.append(i) + else: + overlapping_ind.append(i) + + return (left_ind.to_array(), + right_ind.to_array(), + overlapping_ind.to_array()) + + cdef new_child_node(self, + ndarray[float64_t, ndim=1] left, + ndarray[float64_t, ndim=1] right, + ndarray[int64_t, ndim=1] indices, + ndarray[int64_t, ndim=1] subset): + """Create a new child node. + + This should be a terminal leaf node if the number of indices is smaller + than leaf_size. Otherwise it should be a non-terminal node. + """ + + left = take(left, subset) + right = take(right, subset) + indices = take(indices, subset) + + if len(indices) <= self.leaf_size: + return Float64ClosedRightIntervalLeaf( + left, right, indices) + else: + return Float64ClosedRightIntervalNode( + left, right, indices, self.leaf_size) + + @cython.wraparound(False) + @cython.boundscheck(False) + cpdef query(self, Int64Vector result, scalar64_t point): + """Recursively query this node and its sub-nodes for intervals that + overlap with the query point. 
+        """
+        cdef:
+            int64_t[:] indices
+            float64_t[:] values
+            int i
+
+        if point < self.pivot:
+            values = self.center_left_values
+            indices = self.center_left_indices
+            for i in range(len(values)):
+                if not values[i] < point:
+                    break
+                result.append(indices[i])
+            self.left_node.query(result, point)
+        elif point > self.pivot:
+            values = self.center_right_values
+            indices = self.center_right_indices
+            for i in range(len(values) - 1, -1, -1):
+                if not point <= values[i]:
+                    break
+                result.append(indices[i])
+            self.right_node.query(result, point)
+        else:
+            result.extend(self.center_left_indices)
+
+    def __repr__(self):
+        return ('<Float64ClosedRightIntervalNode: pivot %s, '
+                '%s elements (%s left, %s right, %s overlapping)>' %
+                (self.pivot, self.n_elements, self.left_node.n_elements,
+                 self.right_node.n_elements, len(self.center_left_indices)))
+
+    def counts(self):
+        m = len(self.center_left_values)
+        l = self.left_node.counts()
+        r = self.right_node.counts()
+        return (m, (l, r))
+
+NODE_CLASSES['float64', 'right'] = Float64ClosedRightIntervalNode
+
+
+cdef class Float64ClosedRightIntervalLeaf:
+    """Terminal node for an IntervalTree
+
+    Once we get down to a certain size, it doesn't make sense to continue the
+    binary tree structure. Instead, we store interval bounds in 1d arrays and
+    use linear search.
+    """
+    cdef:
+        readonly float64_t[:] left, right
+        readonly int64_t[:] indices
+
+    def __init__(self,
+                 float64_t[:] left,
+                 float64_t[:] right,
+                 int64_t[:] indices):
+        self.left = left
+        self.right = right
+        self.indices = indices
+
+    @cython.wraparound(False)
+    @cython.boundscheck(False)
+    cpdef query(self, Int64Vector result, scalar64_t point):
+        for i in range(len(self.left)):
+            if self.left[i] < point <= self.right[i]:
+                result.append(self.indices[i])
+
+    def __repr__(self):
+        return ('<Float64ClosedRightIntervalLeaf: %s elements>'
+                % self.n_elements)
+
+    @property
+    def n_elements(self):
+        return len(self.left)
+
+    def counts(self):
+        return self.n_elements
+
+
+cdef class Float64ClosedBothIntervalNode:
+    """Non-terminal node for an IntervalTree
+
+    Categorizes intervals by those that fall to the left, those that fall to
+    the right, and those that overlap with the pivot.
+    """
+    cdef:
+        readonly left_node, right_node
+        readonly float64_t[:] center_left_values, center_right_values
+        readonly int64_t[:] center_left_indices, center_right_indices
+        readonly float64_t pivot
+        readonly int64_t n_elements, leaf_size
+
+    def __init__(self,
+                 ndarray[float64_t, ndim=1] left,
+                 ndarray[float64_t, ndim=1] right,
+                 ndarray[int64_t, ndim=1] indices,
+                 int64_t leaf_size):
+
+        self.pivot = np.median(left + right) / 2
+        self.n_elements = len(left)
+        self.leaf_size = leaf_size
+
+        left_set, right_set, center_set = self.classify_intervals(left, right)
+
+        self.left_node = self.new_child_node(left, right, indices, left_set)
+        self.right_node = self.new_child_node(left, right, indices, right_set)
+
+        self.center_left_values, self.center_left_indices = \
+            sort_values_and_indices(left, indices, center_set)
+        self.center_right_values, self.center_right_indices = \
+            sort_values_and_indices(right, indices, center_set)
+
+    @cython.wraparound(False)
+    @cython.boundscheck(False)
+    cdef classify_intervals(self, float64_t[:] left, float64_t[:] right):
+        """Classify the given intervals based upon whether they fall to the
+
left, right, or overlap with this node's pivot. + """ + cdef: + int i + Int64Vector left_ind, right_ind, overlapping_ind + + left_ind = Int64Vector() + right_ind = Int64Vector() + overlapping_ind = Int64Vector() + + for i in range(len(left)): + if right[i] < self.pivot: + left_ind.append(i) + elif self.pivot < left[i]: + right_ind.append(i) + else: + overlapping_ind.append(i) + + return (left_ind.to_array(), + right_ind.to_array(), + overlapping_ind.to_array()) + + cdef new_child_node(self, + ndarray[float64_t, ndim=1] left, + ndarray[float64_t, ndim=1] right, + ndarray[int64_t, ndim=1] indices, + ndarray[int64_t, ndim=1] subset): + """Create a new child node. + + This should be a terminal leaf node if the number of indices is smaller + than leaf_size. Otherwise it should be a non-terminal node. + """ + + left = take(left, subset) + right = take(right, subset) + indices = take(indices, subset) + + if len(indices) <= self.leaf_size: + return Float64ClosedBothIntervalLeaf( + left, right, indices) + else: + return Float64ClosedBothIntervalNode( + left, right, indices, self.leaf_size) + + @cython.wraparound(False) + @cython.boundscheck(False) + cpdef query(self, Int64Vector result, scalar64_t point): + """Recursively query this node and its sub-nodes for intervals that + overlap with the query point. 
+        """
+        cdef:
+            int64_t[:] indices
+            float64_t[:] values
+            int i
+
+        if point < self.pivot:
+            values = self.center_left_values
+            indices = self.center_left_indices
+            for i in range(len(values)):
+                if not values[i] <= point:
+                    break
+                result.append(indices[i])
+            self.left_node.query(result, point)
+        elif point > self.pivot:
+            values = self.center_right_values
+            indices = self.center_right_indices
+            for i in range(len(values) - 1, -1, -1):
+                if not point <= values[i]:
+                    break
+                result.append(indices[i])
+            self.right_node.query(result, point)
+        else:
+            result.extend(self.center_left_indices)
+
+    def __repr__(self):
+        return ('<Float64ClosedBothIntervalNode: pivot %s, '
+                '%s elements (%s left, %s right, %s overlapping)>' %
+                (self.pivot, self.n_elements, self.left_node.n_elements,
+                 self.right_node.n_elements, len(self.center_left_indices)))
+
+    def counts(self):
+        m = len(self.center_left_values)
+        l = self.left_node.counts()
+        r = self.right_node.counts()
+        return (m, (l, r))
+
+NODE_CLASSES['float64', 'both'] = Float64ClosedBothIntervalNode
+
+
+cdef class Float64ClosedBothIntervalLeaf:
+    """Terminal node for an IntervalTree
+
+    Once we get down to a certain size, it doesn't make sense to continue the
+    binary tree structure. Instead, we store interval bounds in 1d arrays and
+    use linear search.
+    """
+    cdef:
+        readonly float64_t[:] left, right
+        readonly int64_t[:] indices
+
+    def __init__(self,
+                 float64_t[:] left,
+                 float64_t[:] right,
+                 int64_t[:] indices):
+        self.left = left
+        self.right = right
+        self.indices = indices
+
+    @cython.wraparound(False)
+    @cython.boundscheck(False)
+    cpdef query(self, Int64Vector result, scalar64_t point):
+        for i in range(len(self.left)):
+            if self.left[i] <= point <= self.right[i]:
+                result.append(self.indices[i])
+
+    def __repr__(self):
+        return ('<Float64ClosedBothIntervalLeaf: %s elements>'
+                % self.n_elements)
+
+    @property
+    def n_elements(self):
+        return len(self.left)
+
+    def counts(self):
+        return self.n_elements
+
+
+cdef class Float64ClosedNeitherIntervalNode:
+    """Non-terminal node for an IntervalTree
+
+    Categorizes intervals by those that fall to the left, those that fall to
+    the right, and those that overlap with the pivot.
+    """
+    cdef:
+        readonly left_node, right_node
+        readonly float64_t[:] center_left_values, center_right_values
+        readonly int64_t[:] center_left_indices, center_right_indices
+        readonly float64_t pivot
+        readonly int64_t n_elements, leaf_size
+
+    def __init__(self,
+                 ndarray[float64_t, ndim=1] left,
+                 ndarray[float64_t, ndim=1] right,
+                 ndarray[int64_t, ndim=1] indices,
+                 int64_t leaf_size):
+
+        self.pivot = np.median(left + right) / 2
+        self.n_elements = len(left)
+        self.leaf_size = leaf_size
+
+        left_set, right_set, center_set = self.classify_intervals(left, right)
+
+        self.left_node = self.new_child_node(left, right, indices, left_set)
+        self.right_node = self.new_child_node(left, right, indices, right_set)
+
+        self.center_left_values, self.center_left_indices = \
+            sort_values_and_indices(left, indices, center_set)
+        self.center_right_values, self.center_right_indices = \
+            sort_values_and_indices(right, indices, center_set)
+
+    @cython.wraparound(False)
+    @cython.boundscheck(False)
+    cdef classify_intervals(self, float64_t[:] left, float64_t[:] right):
+        """Classify the given intervals based upon whether they fall to the
+
left, right, or overlap with this node's pivot. + """ + cdef: + int i + Int64Vector left_ind, right_ind, overlapping_ind + + left_ind = Int64Vector() + right_ind = Int64Vector() + overlapping_ind = Int64Vector() + + for i in range(len(left)): + if right[i] <= self.pivot: + left_ind.append(i) + elif self.pivot <= left[i]: + right_ind.append(i) + else: + overlapping_ind.append(i) + + return (left_ind.to_array(), + right_ind.to_array(), + overlapping_ind.to_array()) + + cdef new_child_node(self, + ndarray[float64_t, ndim=1] left, + ndarray[float64_t, ndim=1] right, + ndarray[int64_t, ndim=1] indices, + ndarray[int64_t, ndim=1] subset): + """Create a new child node. + + This should be a terminal leaf node if the number of indices is smaller + than leaf_size. Otherwise it should be a non-terminal node. + """ + + left = take(left, subset) + right = take(right, subset) + indices = take(indices, subset) + + if len(indices) <= self.leaf_size: + return Float64ClosedNeitherIntervalLeaf( + left, right, indices) + else: + return Float64ClosedNeitherIntervalNode( + left, right, indices, self.leaf_size) + + @cython.wraparound(False) + @cython.boundscheck(False) + cpdef query(self, Int64Vector result, scalar64_t point): + """Recursively query this node and its sub-nodes for intervals that + overlap with the query point. 
+        """
+        cdef:
+            int64_t[:] indices
+            float64_t[:] values
+            int i
+
+        if point < self.pivot:
+            values = self.center_left_values
+            indices = self.center_left_indices
+            for i in range(len(values)):
+                if not values[i] < point:
+                    break
+                result.append(indices[i])
+            self.left_node.query(result, point)
+        elif point > self.pivot:
+            values = self.center_right_values
+            indices = self.center_right_indices
+            for i in range(len(values) - 1, -1, -1):
+                if not point < values[i]:
+                    break
+                result.append(indices[i])
+            self.right_node.query(result, point)
+        else:
+            result.extend(self.center_left_indices)
+
+    def __repr__(self):
+        return ('<Float64ClosedNeitherIntervalNode: pivot %s, '
+                '%s elements (%s left, %s right, %s overlapping)>' %
+                (self.pivot, self.n_elements, self.left_node.n_elements,
+                 self.right_node.n_elements, len(self.center_left_indices)))
+
+    def counts(self):
+        m = len(self.center_left_values)
+        l = self.left_node.counts()
+        r = self.right_node.counts()
+        return (m, (l, r))
+
+NODE_CLASSES['float64', 'neither'] = Float64ClosedNeitherIntervalNode
+
+
+cdef class Float64ClosedNeitherIntervalLeaf:
+    """Terminal node for an IntervalTree
+
+    Once we get down to a certain size, it doesn't make sense to continue the
+    binary tree structure. Instead, we store interval bounds in 1d arrays and
+    use linear search.
+    """
+    cdef:
+        readonly float64_t[:] left, right
+        readonly int64_t[:] indices
+
+    def __init__(self,
+                 float64_t[:] left,
+                 float64_t[:] right,
+                 int64_t[:] indices):
+        self.left = left
+        self.right = right
+        self.indices = indices
+
+    @cython.wraparound(False)
+    @cython.boundscheck(False)
+    cpdef query(self, Int64Vector result, scalar64_t point):
+        for i in range(len(self.left)):
+            if self.left[i] < point < self.right[i]:
+                result.append(self.indices[i])
+
+    def __repr__(self):
+        return ('<Float64ClosedNeitherIntervalLeaf: %s elements>'
+                % self.n_elements)
+
+    @property
+    def n_elements(self):
+        return len(self.left)
+
+    def counts(self):
+        return self.n_elements
+
+
+cdef class Int64ClosedLeftIntervalNode:
+    """Non-terminal node for an IntervalTree
+
+    Categorizes intervals by those that fall to the left, those that fall to
+    the right, and those that overlap with the pivot.
+    """
+    cdef:
+        readonly left_node, right_node
+        readonly int64_t[:] center_left_values, center_right_values
+        readonly int64_t[:] center_left_indices, center_right_indices
+        readonly int64_t pivot
+        readonly int64_t n_elements, leaf_size
+
+    def __init__(self,
+                 ndarray[int64_t, ndim=1] left,
+                 ndarray[int64_t, ndim=1] right,
+                 ndarray[int64_t, ndim=1] indices,
+                 int64_t leaf_size):
+
+        self.pivot = np.median(left + right) / 2
+        self.n_elements = len(left)
+        self.leaf_size = leaf_size
+
+        left_set, right_set, center_set = self.classify_intervals(left, right)
+
+        self.left_node = self.new_child_node(left, right, indices, left_set)
+        self.right_node = self.new_child_node(left, right, indices, right_set)
+
+        self.center_left_values, self.center_left_indices = \
+            sort_values_and_indices(left, indices, center_set)
+        self.center_right_values, self.center_right_indices = \
+            sort_values_and_indices(right, indices, center_set)
+
+    @cython.wraparound(False)
+    @cython.boundscheck(False)
+    cdef classify_intervals(self, int64_t[:] left, int64_t[:] right):
+        """Classify the given intervals based upon whether they fall to the
+        left, right, or
overlap with this node's pivot. + """ + cdef: + int i + Int64Vector left_ind, right_ind, overlapping_ind + + left_ind = Int64Vector() + right_ind = Int64Vector() + overlapping_ind = Int64Vector() + + for i in range(len(left)): + if right[i] <= self.pivot: + left_ind.append(i) + elif self.pivot < left[i]: + right_ind.append(i) + else: + overlapping_ind.append(i) + + return (left_ind.to_array(), + right_ind.to_array(), + overlapping_ind.to_array()) + + cdef new_child_node(self, + ndarray[int64_t, ndim=1] left, + ndarray[int64_t, ndim=1] right, + ndarray[int64_t, ndim=1] indices, + ndarray[int64_t, ndim=1] subset): + """Create a new child node. + + This should be a terminal leaf node if the number of indices is smaller + than leaf_size. Otherwise it should be a non-terminal node. + """ + + left = take(left, subset) + right = take(right, subset) + indices = take(indices, subset) + + if len(indices) <= self.leaf_size: + return Int64ClosedLeftIntervalLeaf( + left, right, indices) + else: + return Int64ClosedLeftIntervalNode( + left, right, indices, self.leaf_size) + + @cython.wraparound(False) + @cython.boundscheck(False) + cpdef query(self, Int64Vector result, scalar64_t point): + """Recursively query this node and its sub-nodes for intervals that + overlap with the query point. 
+ """ + cdef: + int64_t[:] indices + int64_t[:] values + int i + + if point < self.pivot: + values = self.center_left_values + indices = self.center_left_indices + for i in range(len(values)): + if not values[i] <= point: + break + result.append(indices[i]) + self.left_node.query(result, point) + elif point > self.pivot: + values = self.center_right_values + indices = self.center_right_indices + for i in range(len(values) - 1, -1, -1): + if not point < values[i]: + break + result.append(indices[i]) + self.right_node.query(result, point) + else: + result.extend(self.center_left_indices) + + def __repr__(self): + return ('' % + (self.pivot, self.n_elements, self.left_node.n_elements, + self.right_node.n_elements, len(self.center_left_indices))) + + def counts(self): + m = len(self.center_left_values) + l = self.left_node.counts() + r = self.right_node.counts() + return (m, (l, r)) + +NODE_CLASSES['int64', 'left'] = Int64ClosedLeftIntervalNode + + +cdef class Int64ClosedLeftIntervalLeaf: + """Terminal node for an IntervalTree + + Once we get down to a certain size, it doens't make sense to continue the + binary tree structure. Instead, we store interval bounds in 1d arrays use + linear search. 
+ """ + cdef: + readonly int64_t[:] left, right + readonly int64_t[:] indices + + def __init__(self, + int64_t[:] left, + int64_t[:] right, + int64_t[:] indices): + self.left = left + self.right = right + self.indices = indices + + @cython.wraparound(False) + @cython.boundscheck(False) + cpdef query(self, Int64Vector result, scalar64_t point): + for i in range(len(self.left)): + if self.left[i] <= point < self.right[i]: + result.append(self.indices[i]) + + def __repr__(self): + return ('' + % self.n_elements) + + @property + def n_elements(self): + return len(self.left) + + def counts(self): + return self.n_elements + + +cdef class Int64ClosedRightIntervalNode: + """Non-terminal node for an IntervalTree + + Categorizes intervals by those that fall to the left, those that fall to + the right, and those that overlap with the pivot. + """ + cdef: + readonly left_node, right_node + readonly int64_t[:] center_left_values, center_right_values + readonly int64_t[:] center_left_indices, center_right_indices + readonly int64_t pivot + readonly int64_t n_elements, leaf_size + + def __init__(self, + ndarray[int64_t, ndim=1] left, + ndarray[int64_t, ndim=1] right, + ndarray[int64_t, ndim=1] indices, + int64_t leaf_size): + + self.pivot = np.median(left + right) / 2 + self.n_elements = len(left) + self.leaf_size = leaf_size + + left_set, right_set, center_set = self.classify_intervals(left, right) + + self.left_node = self.new_child_node(left, right, indices, left_set) + self.right_node = self.new_child_node(left, right, indices, right_set) + + self.center_left_values, self.center_left_indices = \ + sort_values_and_indices(left, indices, center_set) + self.center_right_values, self.center_right_indices = \ + sort_values_and_indices(right, indices, center_set) + + @cython.wraparound(False) + @cython.boundscheck(False) + cdef classify_intervals(self, int64_t[:] left, int64_t[:] right): + """Classify the given intervals based upon whether they fall to the + left, right, or overlap 
with this node's pivot. + """ + cdef: + int i + Int64Vector left_ind, right_ind, overlapping_ind + + left_ind = Int64Vector() + right_ind = Int64Vector() + overlapping_ind = Int64Vector() + + for i in range(len(left)): + if right[i] < self.pivot: + left_ind.append(i) + elif self.pivot <= left[i]: + right_ind.append(i) + else: + overlapping_ind.append(i) + + return (left_ind.to_array(), + right_ind.to_array(), + overlapping_ind.to_array()) + + cdef new_child_node(self, + ndarray[int64_t, ndim=1] left, + ndarray[int64_t, ndim=1] right, + ndarray[int64_t, ndim=1] indices, + ndarray[int64_t, ndim=1] subset): + """Create a new child node. + + This should be a terminal leaf node if the number of indices is smaller + than leaf_size. Otherwise it should be a non-terminal node. + """ + + left = take(left, subset) + right = take(right, subset) + indices = take(indices, subset) + + if len(indices) <= self.leaf_size: + return Int64ClosedRightIntervalLeaf( + left, right, indices) + else: + return Int64ClosedRightIntervalNode( + left, right, indices, self.leaf_size) + + @cython.wraparound(False) + @cython.boundscheck(False) + cpdef query(self, Int64Vector result, scalar64_t point): + """Recursively query this node and its sub-nodes for intervals that + overlap with the query point. 
+ """ + cdef: + int64_t[:] indices + int64_t[:] values + int i + + if point < self.pivot: + values = self.center_left_values + indices = self.center_left_indices + for i in range(len(values)): + if not values[i] < point: + break + result.append(indices[i]) + self.left_node.query(result, point) + elif point > self.pivot: + values = self.center_right_values + indices = self.center_right_indices + for i in range(len(values) - 1, -1, -1): + if not point <= values[i]: + break + result.append(indices[i]) + self.right_node.query(result, point) + else: + result.extend(self.center_left_indices) + + def __repr__(self): + return ('' % + (self.pivot, self.n_elements, self.left_node.n_elements, + self.right_node.n_elements, len(self.center_left_indices))) + + def counts(self): + m = len(self.center_left_values) + l = self.left_node.counts() + r = self.right_node.counts() + return (m, (l, r)) + +NODE_CLASSES['int64', 'right'] = Int64ClosedRightIntervalNode + + +cdef class Int64ClosedRightIntervalLeaf: + """Terminal node for an IntervalTree + + Once we get down to a certain size, it doens't make sense to continue the + binary tree structure. Instead, we store interval bounds in 1d arrays use + linear search. 
+ """ + cdef: + readonly int64_t[:] left, right + readonly int64_t[:] indices + + def __init__(self, + int64_t[:] left, + int64_t[:] right, + int64_t[:] indices): + self.left = left + self.right = right + self.indices = indices + + @cython.wraparound(False) + @cython.boundscheck(False) + cpdef query(self, Int64Vector result, scalar64_t point): + for i in range(len(self.left)): + if self.left[i] < point <= self.right[i]: + result.append(self.indices[i]) + + def __repr__(self): + return ('' + % self.n_elements) + + @property + def n_elements(self): + return len(self.left) + + def counts(self): + return self.n_elements + + +cdef class Int64ClosedBothIntervalNode: + """Non-terminal node for an IntervalTree + + Categorizes intervals by those that fall to the left, those that fall to + the right, and those that overlap with the pivot. + """ + cdef: + readonly left_node, right_node + readonly int64_t[:] center_left_values, center_right_values + readonly int64_t[:] center_left_indices, center_right_indices + readonly int64_t pivot + readonly int64_t n_elements, leaf_size + + def __init__(self, + ndarray[int64_t, ndim=1] left, + ndarray[int64_t, ndim=1] right, + ndarray[int64_t, ndim=1] indices, + int64_t leaf_size): + + self.pivot = np.median(left + right) / 2 + self.n_elements = len(left) + self.leaf_size = leaf_size + + left_set, right_set, center_set = self.classify_intervals(left, right) + + self.left_node = self.new_child_node(left, right, indices, left_set) + self.right_node = self.new_child_node(left, right, indices, right_set) + + self.center_left_values, self.center_left_indices = \ + sort_values_and_indices(left, indices, center_set) + self.center_right_values, self.center_right_indices = \ + sort_values_and_indices(right, indices, center_set) + + @cython.wraparound(False) + @cython.boundscheck(False) + cdef classify_intervals(self, int64_t[:] left, int64_t[:] right): + """Classify the given intervals based upon whether they fall to the + left, right, or overlap 
with this node's pivot. + """ + cdef: + int i + Int64Vector left_ind, right_ind, overlapping_ind + + left_ind = Int64Vector() + right_ind = Int64Vector() + overlapping_ind = Int64Vector() + + for i in range(len(left)): + if right[i] < self.pivot: + left_ind.append(i) + elif self.pivot < left[i]: + right_ind.append(i) + else: + overlapping_ind.append(i) + + return (left_ind.to_array(), + right_ind.to_array(), + overlapping_ind.to_array()) + + cdef new_child_node(self, + ndarray[int64_t, ndim=1] left, + ndarray[int64_t, ndim=1] right, + ndarray[int64_t, ndim=1] indices, + ndarray[int64_t, ndim=1] subset): + """Create a new child node. + + This should be a terminal leaf node if the number of indices is smaller + than leaf_size. Otherwise it should be a non-terminal node. + """ + + left = take(left, subset) + right = take(right, subset) + indices = take(indices, subset) + + if len(indices) <= self.leaf_size: + return Int64ClosedBothIntervalLeaf( + left, right, indices) + else: + return Int64ClosedBothIntervalNode( + left, right, indices, self.leaf_size) + + @cython.wraparound(False) + @cython.boundscheck(False) + cpdef query(self, Int64Vector result, scalar64_t point): + """Recursively query this node and its sub-nodes for intervals that + overlap with the query point. 
+ """ + cdef: + int64_t[:] indices + int64_t[:] values + int i + + if point < self.pivot: + values = self.center_left_values + indices = self.center_left_indices + for i in range(len(values)): + if not values[i] <= point: + break + result.append(indices[i]) + self.left_node.query(result, point) + elif point > self.pivot: + values = self.center_right_values + indices = self.center_right_indices + for i in range(len(values) - 1, -1, -1): + if not point <= values[i]: + break + result.append(indices[i]) + self.right_node.query(result, point) + else: + result.extend(self.center_left_indices) + + def __repr__(self): + return ('' % + (self.pivot, self.n_elements, self.left_node.n_elements, + self.right_node.n_elements, len(self.center_left_indices))) + + def counts(self): + m = len(self.center_left_values) + l = self.left_node.counts() + r = self.right_node.counts() + return (m, (l, r)) + +NODE_CLASSES['int64', 'both'] = Int64ClosedBothIntervalNode + + +cdef class Int64ClosedBothIntervalLeaf: + """Terminal node for an IntervalTree + + Once we get down to a certain size, it doens't make sense to continue the + binary tree structure. Instead, we store interval bounds in 1d arrays use + linear search. 
+ """ + cdef: + readonly int64_t[:] left, right + readonly int64_t[:] indices + + def __init__(self, + int64_t[:] left, + int64_t[:] right, + int64_t[:] indices): + self.left = left + self.right = right + self.indices = indices + + @cython.wraparound(False) + @cython.boundscheck(False) + cpdef query(self, Int64Vector result, scalar64_t point): + for i in range(len(self.left)): + if self.left[i] <= point <= self.right[i]: + result.append(self.indices[i]) + + def __repr__(self): + return ('' + % self.n_elements) + + @property + def n_elements(self): + return len(self.left) + + def counts(self): + return self.n_elements + + +cdef class Int64ClosedNeitherIntervalNode: + """Non-terminal node for an IntervalTree + + Categorizes intervals by those that fall to the left, those that fall to + the right, and those that overlap with the pivot. + """ + cdef: + readonly left_node, right_node + readonly int64_t[:] center_left_values, center_right_values + readonly int64_t[:] center_left_indices, center_right_indices + readonly int64_t pivot + readonly int64_t n_elements, leaf_size + + def __init__(self, + ndarray[int64_t, ndim=1] left, + ndarray[int64_t, ndim=1] right, + ndarray[int64_t, ndim=1] indices, + int64_t leaf_size): + + self.pivot = np.median(left + right) / 2 + self.n_elements = len(left) + self.leaf_size = leaf_size + + left_set, right_set, center_set = self.classify_intervals(left, right) + + self.left_node = self.new_child_node(left, right, indices, left_set) + self.right_node = self.new_child_node(left, right, indices, right_set) + + self.center_left_values, self.center_left_indices = \ + sort_values_and_indices(left, indices, center_set) + self.center_right_values, self.center_right_indices = \ + sort_values_and_indices(right, indices, center_set) + + @cython.wraparound(False) + @cython.boundscheck(False) + cdef classify_intervals(self, int64_t[:] left, int64_t[:] right): + """Classify the given intervals based upon whether they fall to the + left, right, or 
overlap with this node's pivot. + """ + cdef: + int i + Int64Vector left_ind, right_ind, overlapping_ind + + left_ind = Int64Vector() + right_ind = Int64Vector() + overlapping_ind = Int64Vector() + + for i in range(len(left)): + if right[i] <= self.pivot: + left_ind.append(i) + elif self.pivot <= left[i]: + right_ind.append(i) + else: + overlapping_ind.append(i) + + return (left_ind.to_array(), + right_ind.to_array(), + overlapping_ind.to_array()) + + cdef new_child_node(self, + ndarray[int64_t, ndim=1] left, + ndarray[int64_t, ndim=1] right, + ndarray[int64_t, ndim=1] indices, + ndarray[int64_t, ndim=1] subset): + """Create a new child node. + + This should be a terminal leaf node if the number of indices is smaller + than leaf_size. Otherwise it should be a non-terminal node. + """ + + left = take(left, subset) + right = take(right, subset) + indices = take(indices, subset) + + if len(indices) <= self.leaf_size: + return Int64ClosedNeitherIntervalLeaf( + left, right, indices) + else: + return Int64ClosedNeitherIntervalNode( + left, right, indices, self.leaf_size) + + @cython.wraparound(False) + @cython.boundscheck(False) + cpdef query(self, Int64Vector result, scalar64_t point): + """Recursively query this node and its sub-nodes for intervals that + overlap with the query point. 
+ """ + cdef: + int64_t[:] indices + int64_t[:] values + int i + + if point < self.pivot: + values = self.center_left_values + indices = self.center_left_indices + for i in range(len(values)): + if not values[i] < point: + break + result.append(indices[i]) + self.left_node.query(result, point) + elif point > self.pivot: + values = self.center_right_values + indices = self.center_right_indices + for i in range(len(values) - 1, -1, -1): + if not point < values[i]: + break + result.append(indices[i]) + self.right_node.query(result, point) + else: + result.extend(self.center_left_indices) + + def __repr__(self): + return ('' % + (self.pivot, self.n_elements, self.left_node.n_elements, + self.right_node.n_elements, len(self.center_left_indices))) + + def counts(self): + m = len(self.center_left_values) + l = self.left_node.counts() + r = self.right_node.counts() + return (m, (l, r)) + +NODE_CLASSES['int64', 'neither'] = Int64ClosedNeitherIntervalNode + + +cdef class Int64ClosedNeitherIntervalLeaf: + """Terminal node for an IntervalTree + + Once we get down to a certain size, it doens't make sense to continue the + binary tree structure. Instead, we store interval bounds in 1d arrays use + linear search. 
+ """ + cdef: + readonly int64_t[:] left, right + readonly int64_t[:] indices + + def __init__(self, + int64_t[:] left, + int64_t[:] right, + int64_t[:] indices): + self.left = left + self.right = right + self.indices = indices + + @cython.wraparound(False) + @cython.boundscheck(False) + cpdef query(self, Int64Vector result, scalar64_t point): + for i in range(len(self.left)): + if self.left[i] < point < self.right[i]: + result.append(self.indices[i]) + + def __repr__(self): + return ('' + % self.n_elements) + + @property + def n_elements(self): + return len(self.left) + + def counts(self): + return self.n_elements + + diff --git a/pandas/tests/test_algos.py b/pandas/tests/test_algos.py index b18bd7b2b3978..eb0964392d20c 100644 --- a/pandas/tests/test_algos.py +++ b/pandas/tests/test_algos.py @@ -352,14 +352,10 @@ def test_value_counts(self): arr = np.random.randn(4) factor = cut(arr, 4) - tm.assertIsInstance(factor, Categorical) + # tm.assertIsInstance(factor, n) result = algos.value_counts(factor) - cats = ['(-1.194, -0.535]', - '(-0.535, 0.121]', - '(0.121, 0.777]', - '(0.777, 1.433]' - ] - expected_index = CategoricalIndex(cats, cats, ordered=True) + breaks = [-1.192, -0.535, 0.121, 0.777, 1.433] + expected_index = pd.IntervalIndex.from_breaks(breaks) expected = Series([1, 1, 1, 1], index=expected_index) tm.assert_series_equal(result.sort_index(), expected.sort_index()) @@ -368,12 +364,12 @@ def test_value_counts_bins(self): s = [1, 2, 3, 4] result = algos.value_counts(s, bins=1) self.assertEqual(result.tolist(), [4]) - self.assertEqual(result.index[0], 0.997) + self.assertEqual(result.index[0], pd.Interval(0.999, 4.0)) result = algos.value_counts(s, bins=2, sort=False) self.assertEqual(result.tolist(), [2, 2]) - self.assertEqual(result.index[0], 0.997) - self.assertEqual(result.index[1], 2.5) + self.assertEqual(result.index.min(), pd.Interval(0.999, 2.5)) + self.assertEqual(result.index.max(), pd.Interval(2.5, 4.0)) def test_value_counts_dtypes(self): result = 
algos.value_counts([1, 1.]) diff --git a/pandas/tests/test_base.py b/pandas/tests/test_base.py index 8eb5bd3a202a4..ca60c357b73c5 100644 --- a/pandas/tests/test_base.py +++ b/pandas/tests/test_base.py @@ -11,7 +11,8 @@ from pandas.tseries.base import DatetimeIndexOpsMixin from pandas.util.testing import assertRaisesRegexp, assertIsInstance from pandas.tseries.common import is_datetimelike -from pandas import Series, Index, Int64Index, DatetimeIndex, TimedeltaIndex, PeriodIndex, Timedelta +from pandas import (Series, Index, Int64Index, DatetimeIndex, TimedeltaIndex, + PeriodIndex, IntervalIndex, Timedelta, Interval) import pandas.tslib as tslib from pandas import _np_version_under1p9 import nose @@ -553,20 +554,21 @@ def test_value_counts_inferred(self): s1 = Series([1, 1, 2, 3]) res1 = s1.value_counts(bins=1) - exp1 = Series({0.998: 4}) + exp1 = Series({Interval(0.999, 3.0): 4}) tm.assert_series_equal(res1, exp1) res1n = s1.value_counts(bins=1, normalize=True) - exp1n = Series({0.998: 1.0}) + exp1n = Series({Interval(0.999, 3.0): 1.0}) tm.assert_series_equal(res1n, exp1n) self.assert_numpy_array_equal(s1.unique(), np.array([1, 2, 3])) self.assertEqual(s1.nunique(), 3) res4 = s1.value_counts(bins=4) - exp4 = Series({0.998: 2, 1.5: 1, 2.0: 0, 2.5: 1}, index=[0.998, 2.5, 1.5, 2.0]) + intervals = IntervalIndex.from_breaks([0.999, 1.5, 2.0, 2.5, 3.0]) + exp4 = Series([2, 1, 1], index=intervals.take([0, 3, 1])) tm.assert_series_equal(res4, exp4) res4n = s1.value_counts(bins=4, normalize=True) - exp4n = Series({0.998: 0.5, 1.5: 0.25, 2.0: 0.0, 2.5: 0.25}, index=[0.998, 2.5, 1.5, 2.0]) + exp4n = Series([0.5, 0.25, 0.25], index=intervals.take([0, 3, 1])) tm.assert_series_equal(res4n, exp4n) # handle NA's properly diff --git a/pandas/tests/test_categorical.py b/pandas/tests/test_categorical.py index 64908f96bfdd8..492f3eb79947e 100755 --- a/pandas/tests/test_categorical.py +++ b/pandas/tests/test_categorical.py @@ -11,7 +11,8 @@ import numpy as np import pandas as pd -from 
pandas import Categorical, Index, Series, DataFrame, PeriodIndex, Timestamp, CategoricalIndex +from pandas import (Categorical, Index, Series, DataFrame, PeriodIndex, + Timestamp, CategoricalIndex, Interval) from pandas.core.config import option_context import pandas.core.common as com @@ -1328,9 +1329,10 @@ def setUp(self): df = DataFrame({'value': np.random.randint(0, 10000, 100)}) labels = [ "{0} - {1}".format(i, i + 499) for i in range(0, 10000, 500) ] + cat_labels = Categorical(labels, labels) df = df.sort_values(by=['value'], ascending=True) - df['value_group'] = pd.cut(df.value, range(0, 10500, 500), right=False, labels=labels) + df['value_group'] = pd.cut(df.value, range(0, 10500, 500), right=False, labels=cat_labels) self.cat = df def test_dtypes(self): @@ -1727,7 +1729,7 @@ def test_series_functions_no_warnings(self): def test_assignment_to_dataframe(self): # assignment df = DataFrame({'value': np.array(np.random.randint(0, 10000, 100),dtype='int32')}) - labels = [ "{0} - {1}".format(i, i + 499) for i in range(0, 10000, 500) ] + labels = Categorical(["{0} - {1}".format(i, i + 499) for i in range(0, 10000, 500)]) df = df.sort_values(by=['value'], ascending=True) s = pd.cut(df.value, range(0, 10500, 500), right=False, labels=labels) @@ -2590,7 +2592,7 @@ def f(x): # GH 9603 df = pd.DataFrame({'a': [1, 0, 0, 0]}) - c = pd.cut(df.a, [0, 1, 2, 3, 4]) + c = pd.cut(df.a, [0, 1, 2, 3, 4], labels=pd.Categorical(list('abcd'))) result = df.groupby(c).apply(len) expected = pd.Series([1, 0, 0, 0], index=pd.CategoricalIndex(c.values.categories)) expected.index.name = 'a' @@ -2725,7 +2727,7 @@ def test_slicing(self): df = DataFrame({'value': (np.arange(100)+1).astype('int64')}) df['D'] = pd.cut(df.value, bins=[0,25,50,75,100]) - expected = Series([11,'(0, 25]'], index=['value','D'], name=10) + expected = Series([11, Interval(0, 25)], index=['value','D'], name=10) result = df.iloc[10] tm.assert_series_equal(result, expected) @@ -2735,7 +2737,7 @@ def test_slicing(self): 
result = df.iloc[10:20] tm.assert_frame_equal(result, expected) - expected = Series([9,'(0, 25]'],index=['value', 'D'], name=8) + expected = Series([9, Interval(0, 25)],index=['value', 'D'], name=8) result = df.loc[8] tm.assert_series_equal(result, expected) diff --git a/pandas/tests/test_frame.py b/pandas/tests/test_frame.py index 09de3bf4a8046..f3af12b68cb47 100644 --- a/pandas/tests/test_frame.py +++ b/pandas/tests/test_frame.py @@ -14465,6 +14465,17 @@ def test_reset_index_with_datetimeindex_cols(self): datetime(2013, 1, 2)]) assert_frame_equal(result, expected) + def test_reset_index_with_intervals(self): + idx = pd.IntervalIndex.from_breaks(np.arange(11), name='x') + original = pd.DataFrame({'x': idx, 'y': np.arange(10)})[['x', 'y']] + + result = original.set_index('x') + expected = pd.DataFrame({'y': np.arange(10)}, index=idx) + assert_frame_equal(result, expected) + + result2 = result.reset_index() + assert_frame_equal(result2, original) + #---------------------------------------------------------------------- # Tests to cope with refactored internals def test_as_matrix_numeric_cols(self): diff --git a/pandas/tests/test_index.py b/pandas/tests/test_index.py index e2fa6a90429dc..2a7138c9bfdec 100644 --- a/pandas/tests/test_index.py +++ b/pandas/tests/test_index.py @@ -13,7 +13,8 @@ from pandas import (period_range, date_range, Categorical, Series, Index, Float64Index, Int64Index, MultiIndex, - CategoricalIndex, DatetimeIndex, TimedeltaIndex, PeriodIndex) + CategoricalIndex, IntervalIndex, DatetimeIndex, + TimedeltaIndex, PeriodIndex) from pandas.core.index import InvalidIndexError, NumericIndex from pandas.util.testing import (assert_almost_equal, assertRaisesRegexp, assert_copy) @@ -109,9 +110,6 @@ def test_reindex_base(self): actual = idx.get_indexer(idx) tm.assert_numpy_array_equal(expected, actual) - with tm.assertRaisesRegexp(ValueError, 'Invalid fill method'): - idx.get_indexer(idx, method='invalid') - def test_ndarray_compat_properties(self): idx = 
self.create_index() @@ -222,7 +220,7 @@ def test_duplicates(self): if not len(ind): continue - if isinstance(ind, MultiIndex): + if isinstance(ind, (MultiIndex, IntervalIndex)): continue idx = self._holder([ind[0]]*5) self.assertFalse(idx.is_unique) @@ -1410,6 +1408,9 @@ def test_get_indexer_invalid(self): with tm.assertRaisesRegexp(ValueError, 'limit argument'): idx.get_indexer([1, 0], limit=1) + with tm.assertRaisesRegexp(ValueError, 'Invalid fill method'): + idx.get_indexer(idx, method='invalid') + def test_get_indexer_nearest(self): idx = Index(np.arange(10)) @@ -2615,6 +2616,18 @@ def test_fillna_categorical(self): idx.fillna(2.0) +class TestIntervalIndex(Base, tm.TestCase): + # see test_interval for more extensive tests + _holder = IntervalIndex + + def setUp(self): + self.indices = dict(intvIndex = tm.makeIntervalIndex(100)) + self.setup_indices() + + def create_index(self): + return IntervalIndex.from_breaks(np.arange(0, 100, 10)) + + class Numeric(Base): def test_numeric_compat(self): diff --git a/pandas/tests/test_indexing.py b/pandas/tests/test_indexing.py index c6d80a08ad61a..a5e54d58b2559 100644 --- a/pandas/tests/test_indexing.py +++ b/pandas/tests/test_indexing.py @@ -17,7 +17,8 @@ from pandas import option_context from pandas.core.indexing import _non_reducing_slice, _maybe_numeric_slice from pandas.core.api import (DataFrame, Index, Series, Panel, isnull, - MultiIndex, Float64Index, Timestamp, Timedelta) + MultiIndex, Float64Index, IntervalIndex, + Timestamp, Timedelta) from pandas.util.testing import (assert_almost_equal, assert_series_equal, assert_frame_equal, assert_panel_equal, assert_attr_equal) @@ -4345,6 +4346,31 @@ def test_floating_index(self): assert_series_equal(result1, result3) assert_series_equal(result1, Series([1],index=[2.5])) + def test_interval_index(self): + s = Series(np.arange(5), IntervalIndex.from_breaks(np.arange(6))) + + expected = 0 + self.assertEqual(expected, s.loc[0.5]) + self.assertEqual(expected, s.loc[1]) + 
self.assertEqual(expected, s.loc[pd.Interval(0, 1)]) + self.assertRaises(KeyError, s.loc.__getitem__, 0) + + expected = s.iloc[:3] + assert_series_equal(expected, s.loc[:3]) + assert_series_equal(expected, s.loc[:2.5]) + assert_series_equal(expected, s.loc[0.1:2.5]) + assert_series_equal(expected, s.loc[-1:3]) + + def _assert_expected_loc_array_indexer(expected, original, indexer): + expected = pd.Series(expected, indexer) + actual = original.loc[indexer] + assert_series_equal(expected, actual) + + expected = s.iloc[1:4] + assert_series_equal(expected, s.loc[[1.5, 2.5, 3.5]]) + assert_series_equal(expected, s.loc[[2, 3, 4]]) + assert_series_equal(expected, s.loc[[1.5, 3, 4]]) + def test_scalar_indexer(self): # float indexing checked above diff --git a/pandas/tests/test_interval.py b/pandas/tests/test_interval.py new file mode 100644 index 0000000000000..4be7bc6175cbf --- /dev/null +++ b/pandas/tests/test_interval.py @@ -0,0 +1,537 @@ +import numpy as np + +from pandas.core.interval import Interval, IntervalIndex +from pandas.core.index import Index +from pandas.lib import IntervalTree + +import pandas.util.testing as tm +import pandas as pd + + +class TestInterval(tm.TestCase): + def setUp(self): + self.interval = Interval(0, 1) + + def test_properties(self): + self.assertEqual(self.interval.closed, 'right') + self.assertEqual(self.interval.left, 0) + self.assertEqual(self.interval.right, 1) + self.assertEqual(self.interval.mid, 0.5) + + def test_repr(self): + self.assertEqual(repr(self.interval), + "Interval(0, 1, closed='right')") + self.assertEqual(str(self.interval), "(0, 1]") + + interval_left = Interval(0, 1, closed='left') + self.assertEqual(repr(interval_left), + "Interval(0, 1, closed='left')") + self.assertEqual(str(interval_left), "[0, 1)") + + def test_contains(self): + self.assertIn(0.5, self.interval) + self.assertIn(1, self.interval) + self.assertNotIn(0, self.interval) + self.assertRaises(TypeError, lambda: self.interval in self.interval) + + 
interval = Interval(0, 1, closed='both') + self.assertIn(0, interval) + self.assertIn(1, interval) + + interval = Interval(0, 1, closed='neither') + self.assertNotIn(0, interval) + self.assertIn(0.5, interval) + self.assertNotIn(1, interval) + + def test_equal(self): + self.assertEqual(Interval(0, 1), Interval(0, 1, closed='right')) + self.assertNotEqual(Interval(0, 1), Interval(0, 1, closed='left')) + self.assertNotEqual(Interval(0, 1), 0) + + def test_comparison(self): + with self.assertRaisesRegexp(TypeError, 'unorderable types'): + Interval(0, 1) < 2 + + self.assertTrue(Interval(0, 1) < Interval(1, 2)) + self.assertTrue(Interval(0, 1) < Interval(0, 2)) + self.assertTrue(Interval(0, 1) < Interval(0.5, 1.5)) + self.assertTrue(Interval(0, 1) <= Interval(0, 1)) + self.assertTrue(Interval(0, 1) > Interval(-1, 2)) + self.assertTrue(Interval(0, 1) >= Interval(0, 1)) + + def test_hash(self): + # should not raise + hash(self.interval) + + # def test_math(self): + # expected = Interval(1, 2) + # actual = self.interval + 1 + # self.assertEqual(expected, actual) + + +class TestIntervalTree(tm.TestCase): + def setUp(self): + self.tree = IntervalTree(np.arange(5), np.arange(5) + 2) + + def test_get_loc(self): + self.assert_numpy_array_equal(self.tree.get_loc(1), [0]) + self.assert_numpy_array_equal(np.sort(self.tree.get_loc(2)), [0, 1]) + with self.assertRaises(KeyError): + self.tree.get_loc(-1) + + def test_get_indexer(self): + self.assert_numpy_array_equal( + self.tree.get_indexer(np.array([1.0, 5.5, 6.5])), [0, 4, -1]) + with self.assertRaises(KeyError): + self.tree.get_indexer(np.array([3.0])) + + def test_get_indexer_non_unique(self): + indexer, missing = self.tree.get_indexer_non_unique( + np.array([1.0, 2.0, 6.5])) + self.assert_numpy_array_equal(indexer[:1], [0]) + self.assert_numpy_array_equal(np.sort(indexer[1:3]), [0, 1]) + self.assert_numpy_array_equal(np.sort(indexer[3:]), [-1]) + self.assert_numpy_array_equal(missing, [2]) + + def test_duplicates(self): + tree 
= IntervalTree([0, 0, 0], [1, 1, 1]) + self.assert_numpy_array_equal(np.sort(tree.get_loc(0.5)), [0, 1, 2]) + + with self.assertRaises(KeyError): + tree.get_indexer(np.array([0.5])) + + indexer, missing = tree.get_indexer_non_unique(np.array([0.5])) + self.assert_numpy_array_equal(np.sort(indexer), [0, 1, 2]) + self.assert_numpy_array_equal(missing, []) + + def test_get_loc_closed(self): + for closed in ['left', 'right', 'both', 'neither']: + tree = IntervalTree([0], [1], closed=closed) + for p, errors in [(0, tree.open_left), + (1, tree.open_right)]: + if errors: + with self.assertRaises(KeyError): + tree.get_loc(p) + else: + self.assert_numpy_array_equal(tree.get_loc(p), + np.array([0])) + + def test_get_indexer_closed(self): + x = np.arange(1000) + found = x + not_found = -np.ones(1000) + for closed in ['left', 'right', 'both', 'neither']: + tree = IntervalTree(x, x + 0.5, closed=closed) + self.assert_numpy_array_equal(found, tree.get_indexer(x + 0.25)) + + expected = found if tree.closed_left else not_found + self.assert_numpy_array_equal(expected, tree.get_indexer(x + 0.0)) + + expected = found if tree.closed_right else not_found + self.assert_numpy_array_equal(expected, tree.get_indexer(x + 0.5)) + + +class TestIntervalIndex(tm.TestCase): + def setUp(self): + self.index = IntervalIndex([0, 1], [1, 2]) + + def test_constructors(self): + expected = self.index + actual = IntervalIndex.from_breaks(np.arange(3), closed='right') + self.assertTrue(expected.equals(actual)) + + alternate = IntervalIndex.from_breaks(np.arange(3), closed='left') + self.assertFalse(expected.equals(alternate)) + + actual = IntervalIndex.from_intervals([Interval(0, 1), Interval(1, 2)]) + self.assertTrue(expected.equals(actual)) + + self.assertRaises(ValueError, IntervalIndex, [0], [1], closed='invalid') + + # TODO: fix all these commented out tests (here and below) + + intervals = [Interval(0, 1), Interval(1, 2, closed='left')] + with self.assertRaises(ValueError): + 
IntervalIndex.from_intervals(intervals) + + with self.assertRaises(ValueError): + IntervalIndex([0, 10], [3, 5]) + + actual = Index([Interval(0, 1), Interval(1, 2)]) + self.assertIsInstance(actual, IntervalIndex) + self.assertTrue(expected.equals(actual)) + + actual = Index(expected) + self.assertIsInstance(actual, IntervalIndex) + self.assertTrue(expected.equals(actual)) + + # no point in nesting periods in an IntervalIndex + # self.assertRaises(ValueError, IntervalIndex.from_breaks, + # pd.period_range('2000-01-01', periods=3)) + + def test_properties(self): + self.assertEqual(len(self.index), 2) + self.assertEqual(self.index.size, 2) + + self.assert_numpy_array_equal(self.index.left, [0, 1]) + self.assertIsInstance(self.index.left, Index) + + self.assert_numpy_array_equal(self.index.right, [1, 2]) + self.assertIsInstance(self.index.right, Index) + + self.assert_numpy_array_equal(self.index.mid, [0.5, 1.5]) + self.assertIsInstance(self.index.mid, Index) + + self.assertEqual(self.index.closed, 'right') + + expected = np.array([Interval(0, 1), Interval(1, 2)], dtype=object) + self.assert_numpy_array_equal(np.asarray(self.index), expected) + self.assert_numpy_array_equal(self.index.values, expected) + + def test_copy(self): + actual = self.index.copy() + self.assertTrue(actual.equals(self.index)) + + actual = self.index.copy(deep=True) + self.assertTrue(actual.equals(self.index)) + self.assertIsNot(actual.left, self.index.left) + + def test_delete(self): + expected = IntervalIndex.from_breaks([1, 2]) + actual = self.index.delete(0) + self.assertTrue(expected.equals(actual)) + + def test_insert(self): + expected = IntervalIndex.from_breaks(range(4)) + actual = self.index.insert(2, Interval(2, 3)) + self.assertTrue(expected.equals(actual)) + + self.assertRaises(ValueError, self.index.insert, 0, 1) + self.assertRaises(ValueError, self.index.insert, 0, + Interval(2, 3, closed='left')) + + def test_take(self): + actual = self.index.take([0, 1]) + 
self.assertTrue(self.index.equals(actual)) + + expected = IntervalIndex([0, 0, 1], [1, 1, 2]) + actual = self.index.take([0, 0, 1]) + self.assertTrue(expected.equals(actual)) + + def test_monotonic_and_unique(self): + self.assertTrue(self.index.is_monotonic) + self.assertTrue(self.index.is_unique) + + idx = IntervalIndex.from_tuples([(0, 1), (0.5, 1.5)]) + self.assertTrue(idx.is_monotonic) + self.assertTrue(idx.is_unique) + + idx = IntervalIndex.from_tuples([(0, 1), (2, 3), (1, 2)]) + self.assertFalse(idx.is_monotonic) + self.assertTrue(idx.is_unique) + + idx = IntervalIndex.from_tuples([(0, 2), (0, 2)]) + self.assertFalse(idx.is_unique) + self.assertTrue(idx.is_monotonic) + + def test_repr(self): + expected = ("IntervalIndex(left=[0, 1],\n right=[1, 2]," + "\n closed='right')") + IntervalIndex((0, 1), (1, 2), closed='right') + self.assertEqual(repr(self.index), expected) + + def test_get_loc_value(self): + self.assertRaises(KeyError, self.index.get_loc, 0) + self.assertEqual(self.index.get_loc(0.5), 0) + self.assertEqual(self.index.get_loc(1), 0) + self.assertEqual(self.index.get_loc(1.5), 1) + self.assertEqual(self.index.get_loc(2), 1) + self.assertRaises(KeyError, self.index.get_loc, -1) + self.assertRaises(KeyError, self.index.get_loc, 3) + + idx = IntervalIndex.from_tuples([(0, 2), (1, 3)]) + self.assertEqual(idx.get_loc(0.5), 0) + self.assertEqual(idx.get_loc(1), 0) + self.assert_numpy_array_equal(idx.get_loc(1.5), [0, 1]) + self.assert_numpy_array_equal(np.sort(idx.get_loc(2)), [0, 1]) + self.assertEqual(idx.get_loc(3), 1) + self.assertRaises(KeyError, idx.get_loc, 3.5) + + idx = IntervalIndex([0, 2], [1, 3]) + self.assertRaises(KeyError, idx.get_loc, 1.5) + + def slice_locs_cases(self, breaks): + # TODO: same tests for more index types + index = IntervalIndex.from_breaks([0, 1, 2], closed='right') + self.assertEqual(index.slice_locs(), (0, 2)) + self.assertEqual(index.slice_locs(0, 1), (0, 1)) + self.assertEqual(index.slice_locs(1, 1), (0, 1)) + 
self.assertEqual(index.slice_locs(0, 2), (0, 2)) + self.assertEqual(index.slice_locs(0.5, 1.5), (0, 2)) + self.assertEqual(index.slice_locs(0, 0.5), (0, 1)) + self.assertEqual(index.slice_locs(start=1), (0, 2)) + self.assertEqual(index.slice_locs(start=1.2), (1, 2)) + self.assertEqual(index.slice_locs(end=1), (0, 1)) + self.assertEqual(index.slice_locs(end=1.1), (0, 2)) + self.assertEqual(index.slice_locs(end=1.0), (0, 1)) + self.assertEqual(*index.slice_locs(-1, -1)) + + index = IntervalIndex.from_breaks([0, 1, 2], closed='neither') + self.assertEqual(index.slice_locs(0, 1), (0, 1)) + self.assertEqual(index.slice_locs(0, 2), (0, 2)) + self.assertEqual(index.slice_locs(0.5, 1.5), (0, 2)) + self.assertEqual(index.slice_locs(1, 1), (1, 1)) + self.assertEqual(index.slice_locs(1, 2), (1, 2)) + + index = IntervalIndex.from_breaks([0, 1, 2], closed='both') + self.assertEqual(index.slice_locs(1, 1), (0, 2)) + self.assertEqual(index.slice_locs(1, 2), (0, 2)) + + def test_slice_locs_int64(self): + self.slice_locs_cases([0, 1, 2]) + + def test_slice_locs_float64(self): + self.slice_locs_cases([0.0, 1.0, 2.0]) + + def slice_locs_decreasing_cases(self, tuples): + index = IntervalIndex.from_tuples(tuples) + self.assertEqual(index.slice_locs(1.5, 0.5), (1, 3)) + self.assertEqual(index.slice_locs(2, 0), (1, 3)) + self.assertEqual(index.slice_locs(2, 1), (1, 3)) + self.assertEqual(index.slice_locs(3, 1.1), (0, 3)) + self.assertEqual(index.slice_locs(3, 3), (0, 2)) + self.assertEqual(index.slice_locs(3.5, 3.3), (0, 1)) + self.assertEqual(index.slice_locs(1, -3), (2, 3)) + self.assertEqual(*index.slice_locs(-1, -1)) + + def test_slice_locs_decreasing_int64(self): + self.slice_locs_cases([(2, 4), (1, 3), (0, 2)]) + + def test_slice_locs_decreasing_float64(self): + self.slice_locs_cases([(2., 4.), (1., 3.), (0., 2.)]) + + def test_slice_locs_fails(self): + index = IntervalIndex.from_tuples([(1, 2), (0, 1), (2, 3)]) + with self.assertRaises(KeyError): + index.slice_locs(1, 2) + + def 
test_get_loc_interval(self): + self.assertEqual(self.index.get_loc(Interval(0, 1)), 0) + self.assertEqual(self.index.get_loc(Interval(0, 0.5)), 0) + self.assertEqual(self.index.get_loc(Interval(0, 1, 'left')), 0) + self.assertRaises(KeyError, self.index.get_loc, Interval(2, 3)) + self.assertRaises(KeyError, self.index.get_loc, Interval(-1, 0, 'left')) + + def test_get_indexer(self): + actual = self.index.get_indexer([-1, 0, 0.5, 1, 1.5, 2, 3]) + expected = [-1, -1, 0, 0, 1, 1, -1] + self.assert_numpy_array_equal(actual, expected) + + actual = self.index.get_indexer(self.index) + expected = [0, 1] + self.assert_numpy_array_equal(actual, expected) + + index = IntervalIndex.from_breaks([0, 1, 2], closed='left') + actual = index.get_indexer([-1, 0, 0.5, 1, 1.5, 2, 3]) + expected = [-1, 0, 0, 1, 1, -1, -1] + self.assert_numpy_array_equal(actual, expected) + + actual = self.index.get_indexer(index[:1]) + expected = [0] + self.assert_numpy_array_equal(actual, expected) + + self.assertRaises(ValueError, self.index.get_indexer, index) + + def test_get_indexer_subintervals(self): + # return indexers for wholly contained subintervals + target = IntervalIndex.from_breaks(np.linspace(0, 2, 5)) + actual = self.index.get_indexer(target) + expected = [0, 0, 1, 1] + self.assert_numpy_array_equal(actual, expected) + + target = IntervalIndex.from_breaks([0, 0.67, 1.33, 2]) + self.assertRaises(ValueError, self.index.get_indexer, target) + + actual = self.index.get_indexer(target[[0, -1]]) + expected = [0, 1] + self.assert_numpy_array_equal(actual, expected) + + target = IntervalIndex.from_breaks([0, 0.33, 0.67, 1], closed='left') + actual = self.index.get_indexer(target) + expected = [0, 0, 0] + self.assert_numpy_array_equal(actual, expected) + + def test_contains(self): + self.assertNotIn(0, self.index) + self.assertIn(0.5, self.index) + self.assertIn(2, self.index) + + self.assertIn(Interval(0, 1), self.index) + self.assertIn(Interval(0, 2), self.index) + self.assertIn(Interval(0, 
0.5), self.index) + self.assertNotIn(Interval(3, 5), self.index) + self.assertNotIn(Interval(-1, 0, closed='left'), self.index) + + def test_non_contiguous(self): + index = IntervalIndex.from_tuples([(0, 1), (2, 3)]) + target = [0.5, 1.5, 2.5] + actual = index.get_indexer(target) + expected = [0, -1, 1] + self.assert_numpy_array_equal(actual, expected) + + self.assertNotIn(1.5, index) + + def test_union(self): + other = IntervalIndex([2], [3]) + expected = IntervalIndex(range(3), range(1, 4)) + actual = self.index.union(other) + self.assertTrue(expected.equals(actual)) + + actual = other.union(self.index) + self.assertTrue(expected.equals(actual)) + + self.assert_numpy_array_equal(self.index.union(self.index), self.index) + self.assert_numpy_array_equal(self.index.union(self.index[:1]), + self.index) + + def test_intersection(self): + other = IntervalIndex.from_breaks([1, 2, 3]) + expected = IntervalIndex.from_breaks([1, 2]) + actual = self.index.intersection(other) + self.assertTrue(expected.equals(actual)) + + self.assert_numpy_array_equal(self.index.intersection(self.index), + self.index) + + def test_difference(self): + self.assert_numpy_array_equal(self.index.difference(self.index[:1]), + self.index[1:]) + + def test_sym_diff(self): + self.assert_numpy_array_equal(self.index[:1].sym_diff(self.index[1:]), + self.index) + + def test_set_operation_errors(self): + self.assertRaises(ValueError, self.index.union, self.index.left) + + other = IntervalIndex.from_breaks([0, 1, 2], closed='neither') + self.assertRaises(ValueError, self.index.union, other) + + def test_isin(self): + actual = self.index.isin(self.index) + self.assert_numpy_array_equal([True, True], actual) + + actual = self.index.isin(self.index[:1]) + self.assert_numpy_array_equal([True, False], actual) + + def test_comparison(self): + actual = Interval(0, 1) < self.index + expected = [False, True] + self.assert_numpy_array_equal(actual, expected) + + actual = Interval(0.5, 1.5) < self.index + expected = 
[False, True] + self.assert_numpy_array_equal(actual, expected) + actual = self.index > Interval(0.5, 1.5) + self.assert_numpy_array_equal(actual, expected) + + actual = self.index == self.index + expected = [True, True] + self.assert_numpy_array_equal(actual, expected) + actual = self.index <= self.index + self.assert_numpy_array_equal(actual, expected) + actual = self.index >= self.index + self.assert_numpy_array_equal(actual, expected) + + actual = self.index < self.index + expected = [False, False] + self.assert_numpy_array_equal(actual, expected) + actual = self.index > self.index + self.assert_numpy_array_equal(actual, expected) + + actual = self.index == IntervalIndex.from_breaks([0, 1, 2], 'left') + self.assert_numpy_array_equal(actual, expected) + + actual = self.index == self.index.values + self.assert_numpy_array_equal(actual, [True, True]) + actual = self.index.values == self.index + self.assert_numpy_array_equal(actual, [True, True]) + actual = self.index <= self.index.values + self.assert_numpy_array_equal(actual, [True, True]) + actual = self.index != self.index.values + self.assert_numpy_array_equal(actual, [False, False]) + actual = self.index > self.index.values + self.assert_numpy_array_equal(actual, [False, False]) + actual = self.index.values > self.index + self.assert_numpy_array_equal(actual, [False, False]) + + # invalid comparisons + actual = self.index == 0 + self.assert_numpy_array_equal(actual, [False, False]) + actual = self.index == self.index.left + self.assert_numpy_array_equal(actual, [False, False]) + + with self.assertRaisesRegexp(TypeError, 'unorderable types'): + self.index > 0 + with self.assertRaisesRegexp(TypeError, 'unorderable types'): + self.index <= 0 + with self.assertRaises(TypeError): + self.index > np.arange(2) + with self.assertRaises(ValueError): + self.index > np.arange(3) + + def test_missing_values(self): + idx = pd.Index([np.nan, pd.Interval(0, 1), pd.Interval(1, 2)]) + idx2 = pd.IntervalIndex([np.nan, 0, 1], 
[np.nan, 1, 2]) + assert idx.equals(idx2) + + with tm.assertRaisesRegexp(ValueError, 'both left and right sides'): + pd.IntervalIndex([np.nan, 0, 1], [0, 1, 2]) + + self.assert_numpy_array_equal(pd.isnull(idx), [True, False, False]) + + def test_order(self): + expected = IntervalIndex.from_breaks([1, 2, 3, 4]) + actual = IntervalIndex.from_tuples([(3, 4), (1, 2), (2, 3)]).order() + self.assert_numpy_array_equal(expected, actual) + + def test_datetime(self): + dates = pd.date_range('2000', periods=3) + idx = IntervalIndex.from_breaks(dates) + + self.assert_numpy_array_equal(idx.left, dates[:2]) + self.assert_numpy_array_equal(idx.right, dates[-2:]) + + expected = pd.date_range('2000-01-01T12:00', periods=2) + self.assert_numpy_array_equal(idx.mid, expected) + + self.assertIn('2000-01-01T12', idx) + + target = pd.date_range('1999-12-31T12:00', periods=7, freq='12H') + actual = idx.get_indexer(target) + expected = [-1, -1, 0, 0, 1, 1, -1] + self.assert_numpy_array_equal(actual, expected) + + # def test_math(self): + # # add, subtract, multiply, divide with scalers should be OK + # actual = 2 * self.index + 1 + # expected = IntervalIndex.from_breaks((2 * np.arange(3) + 1)) + # self.assertTrue(expected.equals(actual)) + + # actual = self.index / 2.0 - 1 + # expected = IntervalIndex.from_breaks((np.arange(3) / 2.0 - 1)) + # self.assertTrue(expected.equals(actual)) + + # with self.assertRaises(TypeError): + # # doesn't make sense to add two IntervalIndex objects + # self.index + self.index + + # def test_datetime_math(self): + + # expected = IntervalIndex(pd.date_range('2000-01-02', periods=3)) + # actual = idx + pd.to_timedelta(1, unit='D') + # self.assertTrue(expected.equals(actual)) + + # TODO: other set operations (left join, right join, intersection), + # set operations with conflicting IntervalIndex objects or other dtypes, + # groupby, cut, reset_index... 
diff --git a/pandas/tools/tests/test_tile.py b/pandas/tools/tests/test_tile.py index eac6973bffb25..00ea8c155f558 100644 --- a/pandas/tools/tests/test_tile.py +++ b/pandas/tools/tests/test_tile.py @@ -4,12 +4,14 @@ import numpy as np from pandas.compat import zip -from pandas import DataFrame, Series, unique +from pandas import DataFrame, Series, unique, isnull import pandas.util.testing as tm from pandas.util.testing import assertRaisesRegexp import pandas.core.common as com from pandas.core.algorithms import quantile +from pandas.core.categorical import Categorical +from pandas.core.interval import Interval, IntervalIndex from pandas.tools.tile import cut, qcut import pandas.tools.tile as tmod @@ -25,26 +27,30 @@ def test_simple(self): def test_bins(self): data = np.array([.2, 1.4, 2.5, 6.2, 9.7, 2.1]) result, bins = cut(data, 3, retbins=True) - tm.assert_numpy_array_equal(result.codes, [0, 0, 0, 1, 2, 0]) - tm.assert_almost_equal(bins, [0.1905, 3.36666667, 6.53333333, 9.7]) + intervals = IntervalIndex.from_breaks(bins.round(3)) + tm.assert_numpy_array_equal(result, intervals.take([0, 0, 0, 1, 2, 0])) + tm.assert_almost_equal(bins, [0.199, 3.36666667, 6.53333333, 9.7]) def test_right(self): data = np.array([.2, 1.4, 2.5, 6.2, 9.7, 2.1, 2.575]) result, bins = cut(data, 4, right=True, retbins=True) - tm.assert_numpy_array_equal(result.codes, [0, 0, 0, 2, 3, 0, 0]) - tm.assert_almost_equal(bins, [0.1905, 2.575, 4.95, 7.325, 9.7]) + intervals = IntervalIndex.from_breaks(bins.round(3)) + tm.assert_numpy_array_equal(result, intervals.take([0, 0, 0, 2, 3, 0, 0])) + tm.assert_almost_equal(bins, [0.199, 2.575, 4.95, 7.325, 9.7]) def test_noright(self): data = np.array([.2, 1.4, 2.5, 6.2, 9.7, 2.1, 2.575]) result, bins = cut(data, 4, right=False, retbins=True) - tm.assert_numpy_array_equal(result.codes, [0, 0, 0, 2, 3, 0, 1]) - tm.assert_almost_equal(bins, [0.2, 2.575, 4.95, 7.325, 9.7095]) + intervals = IntervalIndex.from_breaks(bins.round(3), closed='left') + 
tm.assert_numpy_array_equal(result, intervals.take([0, 0, 0, 2, 3, 0, 1])) + tm.assert_almost_equal(bins, [0.2, 2.575, 4.95, 7.325, 9.701]) def test_arraylike(self): data = [.2, 1.4, 2.5, 6.2, 9.7, 2.1] result, bins = cut(data, 3, retbins=True) - tm.assert_numpy_array_equal(result.codes, [0, 0, 0, 1, 2, 0]) - tm.assert_almost_equal(bins, [0.1905, 3.36666667, 6.53333333, 9.7]) + intervals = IntervalIndex.from_breaks(bins.round(3)) + tm.assert_numpy_array_equal(result, intervals.take([0, 0, 0, 1, 2, 0])) + tm.assert_almost_equal(bins, [0.199, 3.36666667, 6.53333333, 9.7]) def test_bins_not_monotonic(self): data = [.2, 1.4, 2.5, 6.2, 9.7, 2.1] @@ -72,14 +78,13 @@ def test_labels(self): arr = np.tile(np.arange(0, 1.01, 0.1), 4) result, bins = cut(arr, 4, retbins=True) - ex_levels = ['(-0.001, 0.25]', '(0.25, 0.5]', '(0.5, 0.75]', - '(0.75, 1]'] - self.assert_numpy_array_equal(result.categories, ex_levels) + ex_levels = IntervalIndex.from_breaks([-1e-3, 0.25, 0.5, 0.75, 1]) + self.assert_numpy_array_equal(unique(result), ex_levels) result, bins = cut(arr, 4, retbins=True, right=False) - ex_levels = ['[0, 0.25)', '[0.25, 0.5)', '[0.5, 0.75)', - '[0.75, 1.001)'] - self.assert_numpy_array_equal(result.categories, ex_levels) + ex_levels = IntervalIndex.from_breaks([0, 0.25, 0.5, 0.75, 1 + 1e-3], + closed='left') + self.assert_numpy_array_equal(unique(result), ex_levels) def test_cut_pass_series_name_to_factor(self): s = Series(np.random.randn(100), name='foo') @@ -91,9 +96,8 @@ def test_label_precision(self): arr = np.arange(0, 0.73, 0.01) result = cut(arr, 4, precision=2) - ex_levels = ['(-0.00072, 0.18]', '(0.18, 0.36]', '(0.36, 0.54]', - '(0.54, 0.72]'] - self.assert_numpy_array_equal(result.categories, ex_levels) + ex_levels = IntervalIndex.from_breaks([-0.01, 0.18, 0.36, 0.54, 0.72]) + self.assert_numpy_array_equal(unique(result), ex_levels) def test_na_handling(self): arr = np.arange(0, 0.75, 0.01) @@ -115,26 +119,26 @@ def test_inf_handling(self): data = np.arange(6) 
data_ser = Series(data,dtype='int64') - result = cut(data, [-np.inf, 2, 4, np.inf]) - result_ser = cut(data_ser, [-np.inf, 2, 4, np.inf]) + bins = [-np.inf, 2, 4, np.inf] + result = cut(data, bins) + result_ser = cut(data_ser, bins) - ex_categories = ['(-inf, 2]', '(2, 4]', '(4, inf]'] - - tm.assert_numpy_array_equal(result.categories, ex_categories) - tm.assert_numpy_array_equal(result_ser.cat.categories, ex_categories) - self.assertEqual(result[5], '(4, inf]') - self.assertEqual(result[0], '(-inf, 2]') - self.assertEqual(result_ser[5], '(4, inf]') - self.assertEqual(result_ser[0], '(-inf, 2]') + ex_uniques = IntervalIndex.from_breaks(bins).values + tm.assert_numpy_array_equal(unique(result), ex_uniques) + self.assertEqual(result[5], Interval(4, np.inf)) + self.assertEqual(result[0], Interval(-np.inf, 2)) + self.assertEqual(result_ser[5], Interval(4, np.inf)) + self.assertEqual(result_ser[0], Interval(-np.inf, 2)) def test_qcut(self): arr = np.random.randn(1000) labels, bins = qcut(arr, 4, retbins=True) ex_bins = quantile(arr, [0, .25, .5, .75, 1.]) + ex_bins[0] -= 0.001 tm.assert_almost_equal(bins, ex_bins) - ex_levels = cut(arr, ex_bins, include_lowest=True) + ex_levels = cut(arr, ex_bins) self.assert_numpy_array_equal(labels, ex_levels) def test_qcut_bounds(self): @@ -148,7 +152,7 @@ def test_qcut_specify_quantiles(self): factor = qcut(arr, [0, .25, .5, .75, 1.]) expected = qcut(arr, 4) - self.assertTrue(factor.equals(expected)) + self.assert_numpy_array_equal(factor, expected) def test_qcut_all_bins_same(self): assertRaisesRegexp(ValueError, "edges.*unique", qcut, [0,0,0,0,0,0,0,0,0,0], 3) @@ -158,7 +162,7 @@ def test_cut_out_of_bounds(self): result = cut(arr, [-1, 0, 1]) - mask = result.codes == -1 + mask = isnull(result) ex_mask = (arr < -1) | (arr > 1) self.assert_numpy_array_equal(mask, ex_mask) @@ -168,20 +172,13 @@ def test_cut_pass_labels(self): labels = ['Small', 'Medium', 'Large'] result = cut(arr, bins, labels=labels) + exp = ['Medium'] + 4 * 
['Small'] + ['Medium', 'Large'] + self.assert_numpy_array_equal(result, exp) - exp = cut(arr, bins) - exp.categories = labels - + result = cut(arr, bins, labels=Categorical.from_codes([0, 1, 2], labels)) + exp = Categorical.from_codes([1] + 4 * [0] + [1, 2], labels) self.assertTrue(result.equals(exp)) - def test_qcut_include_lowest(self): - values = np.arange(10) - - cats = qcut(values, 4) - - ex_levels = ['[0, 2.25]', '(2.25, 4.5]', '(4.5, 6.75]', '(6.75, 9]'] - self.assertTrue((cats.categories == ex_levels).all()) - def test_qcut_nas(self): arr = np.random.randn(100) arr[:20] = np.nan @@ -214,9 +211,9 @@ def test_qcut_binning_issues(self): starts = [] ends = [] - for lev in result.categories: - s, e = lev[1:-1].split(',') - + for lev in np.unique(result): + s = lev.left + e = lev.right self.assertTrue(s != e) starts.append(float(s)) @@ -228,34 +225,31 @@ def test_qcut_binning_issues(self): self.assertTrue(ep < en) self.assertTrue(ep <= sn) - def test_cut_return_categorical(self): - from pandas import Categorical + def test_cut_return_intervals(self): s = Series([0,1,2,3,4,5,6,7,8]) res = cut(s,3) - exp = Series(Categorical.from_codes([0,0,0,1,1,1,2,2,2], - ["(-0.008, 2.667]", "(2.667, 5.333]", "(5.333, 8]"], - ordered=True)) + exp_bins = np.linspace(0, 8, num=4).round(3) + exp_bins[0] -= 1e-3 + exp = Series(IntervalIndex.from_breaks(exp_bins).take([0,0,0,1,1,1,2,2,2])) tm.assert_series_equal(res, exp) - def test_qcut_return_categorical(self): - from pandas import Categorical + def test_qcut_return_intervals(self): s = Series([0,1,2,3,4,5,6,7,8]) res = qcut(s,[0,0.333,0.666,1]) - exp = Series(Categorical.from_codes([0,0,0,1,1,1,2,2,2], - ["[0, 2.664]", "(2.664, 5.328]", "(5.328, 8]"], - ordered=True)) + exp_levels = IntervalIndex.from_breaks([-0.001, 2.664, 5.328, 8]) + exp = Series(exp_levels.take([0,0,0,1,1,1,2,2,2])) tm.assert_series_equal(res, exp) def test_series_retbins(self): # GH 8589 s = Series(np.arange(4)) - result, bins = cut(s, 2, retbins=True) - 
tm.assert_numpy_array_equal(result.cat.codes.values, [0, 0, 1, 1]) - tm.assert_almost_equal(bins, [-0.003, 1.5, 3]) + result, bins = cut(s, 2, retbins=True, labels=[0, 1]) + tm.assert_numpy_array_equal(result, [0, 0, 1, 1]) + tm.assert_almost_equal(bins, [-0.001, 1.5, 3]) - result, bins = qcut(s, 2, retbins=True) - tm.assert_numpy_array_equal(result.cat.codes.values, [0, 0, 1, 1]) - tm.assert_almost_equal(bins, [0, 1.5, 3]) + result, bins = qcut(s, 2, retbins=True, labels=[0, 1]) + tm.assert_numpy_array_equal(result, [0, 0, 1, 1]) + tm.assert_almost_equal(bins, [-0.001, 1.5, 3]) def curpath(): diff --git a/pandas/tools/tile.py b/pandas/tools/tile.py index 416addfcf2ad5..4aa1b84793f04 100644 --- a/pandas/tools/tile.py +++ b/pandas/tools/tile.py @@ -5,6 +5,7 @@ from pandas.core.api import DataFrame, Series from pandas.core.categorical import Categorical from pandas.core.index import _ensure_index +from pandas.core.interval import IntervalIndex import pandas.core.algorithms as algos import pandas.core.common as com import pandas.core.nanops as nanops @@ -12,8 +13,10 @@ import numpy as np +import warnings -def cut(x, bins, right=True, labels=None, retbins=False, precision=3, + +def cut(x, bins, right=True, labels=None, retbins=False, precision=None, include_lowest=False): """ Return indices of half-open bins to which each value of `x` belongs. @@ -42,7 +45,7 @@ def cut(x, bins, right=True, labels=None, retbins=False, precision=3, precision : int The precision at which to store and display the bins labels include_lowest : bool - Whether the first interval should be left-inclusive or not. + Deprecated. Whether the first interval should be left-inclusive or not. 
Returns ------- @@ -75,6 +78,10 @@ def cut(x, bins, right=True, labels=None, retbins=False, precision=3, >>> pd.cut(np.ones(5), 4, labels=False) array([1, 1, 1, 1, 1], dtype=int64) """ + if include_lowest: + warnings.warn("include_lowest is deprecated and will be removed in a " + "future version", FutureWarning, stacklevel=2) + # NOTE: this binning code is changed a bit from histogram for var(x) == 0 if not np.iterable(bins): if np.isscalar(bins) and bins < 1: @@ -84,37 +91,45 @@ def cut(x, bins, right=True, labels=None, retbins=False, precision=3, except AttributeError: x = np.asarray(x) sz = x.size + if sz == 0: raise ValueError('Cannot cut empty array') - # handle empty arrays. Can't determine range, so use 0-1. - # rng = (0, 1) - else: - rng = (nanops.nanmin(x), nanops.nanmax(x)) + + rng = (nanops.nanmin(x), nanops.nanmax(x)) mn, mx = [mi + 0.0 for mi in rng] if mn == mx: # adjust end points before binning - mn -= .001 * mn - mx += .001 * mx + if precision is None: + precision = 3 + adj = 10 ** -precision + + mn -= adj + mx += adj bins = np.linspace(mn, mx, bins + 1, endpoint=True) else: # adjust end points after binning bins = np.linspace(mn, mx, bins + 1, endpoint=True) - adj = (mx - mn) * 0.001 # 0.1% of the range + + if precision is None: + precision = _infer_precision(bins) + + adj = 10 ** -precision if right: bins[0] -= adj else: bins[-1] += adj else: + # if isinstance(bins, class_or_type_or_tuple) + bins = np.asarray(bins) if (np.diff(bins) < 0).any(): raise ValueError('bins must increase monotonically.') - return _bins_to_cuts(x, bins, right=right, labels=labels,retbins=retbins, precision=precision, - include_lowest=include_lowest) - + return _bins_to_cuts(x, bins, right=right, labels=labels,retbins=retbins, + precision=precision, include_lowest=include_lowest) -def qcut(x, q, labels=None, retbins=False, precision=3): +def qcut(x, q, labels=None, retbins=False, precision=None): """ Quantile-based discretization function. 
Discretize variable into equal-sized buckets based on rank or based on sample quantiles. For example @@ -163,15 +178,20 @@ def qcut(x, q, labels=None, retbins=False, precision=3): if com.is_integer(q): quantiles = np.linspace(0, 1, q + 1) else: - quantiles = q + quantiles = np.asarray(q) bins = algos.quantile(x, quantiles) - return _bins_to_cuts(x, bins, labels=labels, retbins=retbins,precision=precision, - include_lowest=True) - + zero_q = (quantiles == 0) + if np.any(zero_q): + if precision is None: + precision = _infer_precision(bins) + adj = 10 ** -precision + bins = np.asarray(bins, dtype=np.float64) + bins[zero_q] -= adj + return cut(x, bins, labels=labels, retbins=retbins, precision=precision) def _bins_to_cuts(x, bins, right=True, labels=None, retbins=False, - precision=3, name=None, include_lowest=False): + precision=None, name=None, include_lowest=False): x_is_series = isinstance(x, Series) series_index = None @@ -196,42 +216,68 @@ def _bins_to_cuts(x, bins, right=True, labels=None, retbins=False, if labels is not False: if labels is None: - increases = 0 - while True: - try: - levels = _format_levels(bins, precision, right=right, - include_lowest=include_lowest) - except ValueError: - increases += 1 - precision += 1 - if increases >= 20: - raise - else: - break + if not include_lowest: + closed = 'right' if right else 'left' + + if precision is None: + precision = _infer_precision(bins) + + breaks = np.around(bins, precision) + labels = IntervalIndex.from_breaks(breaks, closed=closed) + + else: + # this code path is deprecated + if precision is None: + precision = 3 + + increases = 0 + while True: + try: + levels = _format_levels(bins, precision, right=right, + include_lowest=include_lowest) + except ValueError: + increases += 1 + precision += 1 + if increases >= 20: + raise + else: + break else: if len(labels) != len(bins) - 1: raise ValueError('Bin labels must be one fewer than ' 'the number of bin edges') - levels = labels - levels = 
np.asarray(levels, dtype=object) + if not com.is_categorical(labels): + labels = np.asarray(labels) + np.putmask(ids, na_mask, 0) - fac = Categorical(ids - 1, levels, ordered=True, fastpath=True) + result = com.take_nd(labels, ids - 1) + else: - fac = ids - 1 + result = ids - 1 if has_nas: - fac = fac.astype(np.float64) - np.putmask(fac, na_mask, np.nan) + result = result.astype(np.float64) + np.putmask(result, na_mask, np.nan) if x_is_series: - fac = Series(fac, index=series_index, name=name) + result = Series(result, index=series_index, name=name) if not retbins: - return fac + return result + + return result, bins + + +def _infer_precision(bins): + for precision in range(3, 20): + levels = np.around(bins, precision) + if algos.unique(levels).size == bins.size: + return precision + return 3 # default - return fac, bins +# these functions are only used with the deprecated argument include_lowest def _format_levels(bins, prec, right=True, include_lowest=False): diff --git a/pandas/util/testing.py b/pandas/util/testing.py index c01a7c1d2c240..f6fbeb217875a 100644 --- a/pandas/util/testing.py +++ b/pandas/util/testing.py @@ -36,7 +36,8 @@ from pandas.computation import expressions as expr -from pandas import (bdate_range, CategoricalIndex, DatetimeIndex, TimedeltaIndex, PeriodIndex, +from pandas import (bdate_range, CategoricalIndex, IntervalIndex, + DatetimeIndex, TimedeltaIndex, PeriodIndex, Index, MultiIndex, Series, DataFrame, Panel, Panel4D) from pandas.util.decorators import deprecate from pandas import _testing @@ -1121,6 +1122,11 @@ def makeCategoricalIndex(k=10, n=3, name=None): x = rands_array(nchars=4, size=n) return CategoricalIndex(np.random.choice(x,k), name=name) +def makeIntervalIndex(k=10, name=None): + """ make a length k IntervalIndex """ + x = np.linspace(0, 100, num=(k + 1)) + return IntervalIndex.from_breaks(x, name=name) + def makeBoolIndex(k=10, name=None): if k == 1: return Index([True], name=name) From 
e9045951b68a66bf2b000b7d087c72d18a6f05a7 Mon Sep 17 00:00:00 2001 From: Stephan Hoyer Date: Sun, 13 Dec 2015 15:28:29 -0800 Subject: [PATCH 2/7] revert some cut default changes --- pandas/tools/tests/test_tile.py | 48 ++++++---- pandas/tools/tile.py | 164 ++++++++------------------------ 2 files changed, 68 insertions(+), 144 deletions(-) diff --git a/pandas/tools/tests/test_tile.py b/pandas/tools/tests/test_tile.py index 00ea8c155f558..7f77aa710df2e 100644 --- a/pandas/tools/tests/test_tile.py +++ b/pandas/tools/tests/test_tile.py @@ -29,28 +29,28 @@ def test_bins(self): result, bins = cut(data, 3, retbins=True) intervals = IntervalIndex.from_breaks(bins.round(3)) tm.assert_numpy_array_equal(result, intervals.take([0, 0, 0, 1, 2, 0])) - tm.assert_almost_equal(bins, [0.199, 3.36666667, 6.53333333, 9.7]) + tm.assert_almost_equal(bins, [0.1905, 3.36666667, 6.53333333, 9.7]) def test_right(self): data = np.array([.2, 1.4, 2.5, 6.2, 9.7, 2.1, 2.575]) result, bins = cut(data, 4, right=True, retbins=True) intervals = IntervalIndex.from_breaks(bins.round(3)) tm.assert_numpy_array_equal(result, intervals.take([0, 0, 0, 2, 3, 0, 0])) - tm.assert_almost_equal(bins, [0.199, 2.575, 4.95, 7.325, 9.7]) + tm.assert_almost_equal(bins, [0.1905, 2.575, 4.95, 7.325, 9.7]) def test_noright(self): data = np.array([.2, 1.4, 2.5, 6.2, 9.7, 2.1, 2.575]) result, bins = cut(data, 4, right=False, retbins=True) intervals = IntervalIndex.from_breaks(bins.round(3), closed='left') tm.assert_numpy_array_equal(result, intervals.take([0, 0, 0, 2, 3, 0, 1])) - tm.assert_almost_equal(bins, [0.2, 2.575, 4.95, 7.325, 9.701]) + tm.assert_almost_equal(bins, [0.2, 2.575, 4.95, 7.325, 9.7095]) def test_arraylike(self): data = [.2, 1.4, 2.5, 6.2, 9.7, 2.1] result, bins = cut(data, 3, retbins=True) intervals = IntervalIndex.from_breaks(bins.round(3)) tm.assert_numpy_array_equal(result, intervals.take([0, 0, 0, 1, 2, 0])) - tm.assert_almost_equal(bins, [0.199, 3.36666667, 6.53333333, 9.7]) + 
tm.assert_almost_equal(bins, [0.1905, 3.36666667, 6.53333333, 9.7]) def test_bins_not_monotonic(self): data = [.2, 1.4, 2.5, 6.2, 9.7, 2.1] @@ -96,7 +96,7 @@ def test_label_precision(self): arr = np.arange(0, 0.73, 0.01) result = cut(arr, 4, precision=2) - ex_levels = IntervalIndex.from_breaks([-0.01, 0.18, 0.36, 0.54, 0.72]) + ex_levels = IntervalIndex.from_breaks([-0.00072, 0.18, 0.36, 0.54, 0.72]) self.assert_numpy_array_equal(unique(result), ex_levels) def test_na_handling(self): @@ -135,10 +135,9 @@ def test_qcut(self): labels, bins = qcut(arr, 4, retbins=True) ex_bins = quantile(arr, [0, .25, .5, .75, 1.]) - ex_bins[0] -= 0.001 tm.assert_almost_equal(bins, ex_bins) - ex_levels = cut(arr, ex_bins) + ex_levels = cut(arr, ex_bins, include_lowest=True) self.assert_numpy_array_equal(labels, ex_levels) def test_qcut_bounds(self): @@ -179,6 +178,15 @@ def test_cut_pass_labels(self): exp = Categorical.from_codes([1] + 4 * [0] + [1, 2], labels) self.assertTrue(result.equals(exp)) + def test_qcut_include_lowest(self): + values = np.arange(10) + + cats = qcut(values, 4) + + ex_levels = [Interval(0, 2.25, closed='both'), Interval(2.25, 4.5), + Interval(4.5, 6.75), Interval(6.75, 9)] + self.assert_numpy_array_equal(unique(cats), ex_levels) + def test_qcut_nas(self): arr = np.random.randn(100) arr[:20] = np.nan @@ -186,9 +194,7 @@ def test_qcut_nas(self): result = qcut(arr, 4) self.assertTrue(com.isnull(result[:20]).all()) - def test_label_formatting(self): - self.assertEqual(tmod._trim_zeros('1.000'), '1') - + def test_round_frac(self): # it works result = cut(np.arange(11.), 2) @@ -196,10 +202,15 @@ def test_label_formatting(self): # #1979, negative numbers - result = tmod._format_label(-117.9998, precision=3) - self.assertEqual(result, '-118') - result = tmod._format_label(117.9998, precision=3) - self.assertEqual(result, '118') + result = tmod._round_frac(-117.9998, precision=3) + self.assertEqual(result, -118) + result = tmod._round_frac(117.9998, precision=3) + 
self.assertEqual(result, 118) + + result = tmod._round_frac(117.9998, precision=2) + self.assertEqual(result, 118) + result = tmod._round_frac(0.000123456, precision=2) + self.assertEqual(result, 0.00012) def test_qcut_binning_issues(self): # #1978, 1979 @@ -229,14 +240,15 @@ def test_cut_return_intervals(self): s = Series([0,1,2,3,4,5,6,7,8]) res = cut(s,3) exp_bins = np.linspace(0, 8, num=4).round(3) - exp_bins[0] -= 1e-3 + exp_bins[0] -= 0.008 exp = Series(IntervalIndex.from_breaks(exp_bins).take([0,0,0,1,1,1,2,2,2])) tm.assert_series_equal(res, exp) def test_qcut_return_intervals(self): s = Series([0,1,2,3,4,5,6,7,8]) res = qcut(s,[0,0.333,0.666,1]) - exp_levels = IntervalIndex.from_breaks([-0.001, 2.664, 5.328, 8]) + exp_levels = np.array([Interval(0, 2.664, closed='both'), + Interval(2.664, 5.328), Interval(5.328, 8)]) exp = Series(exp_levels.take([0,0,0,1,1,1,2,2,2])) tm.assert_series_equal(res, exp) @@ -245,11 +257,11 @@ def test_series_retbins(self): s = Series(np.arange(4)) result, bins = cut(s, 2, retbins=True, labels=[0, 1]) tm.assert_numpy_array_equal(result, [0, 0, 1, 1]) - tm.assert_almost_equal(bins, [-0.001, 1.5, 3]) + tm.assert_almost_equal(bins, [-0.003, 1.5, 3]) result, bins = qcut(s, 2, retbins=True, labels=[0, 1]) tm.assert_numpy_array_equal(result, [0, 0, 1, 1]) - tm.assert_almost_equal(bins, [-0.001, 1.5, 3]) + tm.assert_almost_equal(bins, [0, 1.5, 3]) def curpath(): diff --git a/pandas/tools/tile.py b/pandas/tools/tile.py index 4aa1b84793f04..85ec6b898a8f2 100644 --- a/pandas/tools/tile.py +++ b/pandas/tools/tile.py @@ -5,7 +5,7 @@ from pandas.core.api import DataFrame, Series from pandas.core.categorical import Categorical from pandas.core.index import _ensure_index -from pandas.core.interval import IntervalIndex +from pandas.core.interval import IntervalIndex, Interval import pandas.core.algorithms as algos import pandas.core.common as com import pandas.core.nanops as nanops @@ -16,7 +16,7 @@ import warnings -def cut(x, bins, right=True, 
labels=None, retbins=False, precision=None, +def cut(x, bins, right=True, labels=None, retbins=False, precision=3, include_lowest=False): """ Return indices of half-open bins to which each value of `x` belongs. @@ -42,10 +42,10 @@ def cut(x, bins, right=True, labels=None, retbins=False, precision=None, retbins : bool, optional Whether to return the bins or not. Can be useful if bins is given as a scalar. - precision : int + precision : int, optional The precision at which to store and display the bins labels - include_lowest : bool - Deprecated. Whether the first interval should be left-inclusive or not. + include_lowest : bool, optional + Whether the first interval should be left-inclusive or not. Returns ------- @@ -78,10 +78,6 @@ def cut(x, bins, right=True, labels=None, retbins=False, precision=None, >>> pd.cut(np.ones(5), 4, labels=False) array([1, 1, 1, 1, 1], dtype=int64) """ - if include_lowest: - warnings.warn("include_lowest is deprecated and will be removed in a " - "future version", FutureWarning, stacklevel=2) - # NOTE: this binning code is changed a bit from histogram for var(x) == 0 if not np.iterable(bins): if np.isscalar(bins) and bins < 1: @@ -99,37 +95,27 @@ def cut(x, bins, right=True, labels=None, retbins=False, precision=None, mn, mx = [mi + 0.0 for mi in rng] if mn == mx: # adjust end points before binning - if precision is None: - precision = 3 - adj = 10 ** -precision - - mn -= adj - mx += adj + mn -= .001 * mn + mx += .001 * mx bins = np.linspace(mn, mx, bins + 1, endpoint=True) else: # adjust end points after binning bins = np.linspace(mn, mx, bins + 1, endpoint=True) - - if precision is None: - precision = _infer_precision(bins) - - adj = 10 ** -precision + adj = (mx - mn) * 0.001 # 0.1% of the range if right: bins[0] -= adj else: bins[-1] += adj else: - # if isinstance(bins, class_or_type_or_tuple) - bins = np.asarray(bins) if (np.diff(bins) < 0).any(): raise ValueError('bins must increase monotonically.') - return _bins_to_cuts(x, 
bins, right=right, labels=labels,retbins=retbins, + return _bins_to_cuts(x, bins, right=right, labels=labels, retbins=retbins, precision=precision, include_lowest=include_lowest) -def qcut(x, q, labels=None, retbins=False, precision=None): +def qcut(x, q, labels=None, retbins=False, precision=3): """ Quantile-based discretization function. Discretize variable into equal-sized buckets based on rank or based on sample quantiles. For example @@ -148,7 +134,7 @@ def qcut(x, q, labels=None, retbins=False, precision=None): retbins : bool, optional Whether to return the bins or not. Can be useful if bins is given as a scalar. - precision : int + precision : int, optional The precision at which to store and display the bins labels Returns @@ -178,16 +164,10 @@ def qcut(x, q, labels=None, retbins=False, precision=None): if com.is_integer(q): quantiles = np.linspace(0, 1, q + 1) else: - quantiles = np.asarray(q) + quantiles = q bins = algos.quantile(x, quantiles) - zero_q = (quantiles == 0) - if np.any(zero_q): - if precision is None: - precision = _infer_precision(bins) - adj = 10 ** -precision - bins = np.asarray(bins, dtype=np.float64) - bins[zero_q] -= adj - return cut(x, bins, labels=labels, retbins=retbins, precision=precision) + return _bins_to_cuts(x, bins, labels=labels, retbins=retbins, + precision=precision, include_lowest=True) def _bins_to_cuts(x, bins, right=True, labels=None, retbins=False, @@ -216,32 +196,14 @@ def _bins_to_cuts(x, bins, right=True, labels=None, retbins=False, if labels is not False: if labels is None: - if not include_lowest: - closed = 'right' if right else 'left' - - if precision is None: - precision = _infer_precision(bins) + closed = 'right' if right else 'left' + precision = _infer_precision(precision, bins) + breaks = [_round_frac(b, precision) for b in bins] + labels = IntervalIndex.from_breaks(breaks, closed=closed).values - breaks = np.around(bins, precision) - labels = IntervalIndex.from_breaks(breaks, closed=closed) - - else: - # 
this code path is deprecated - if precision is None: - precision = 3 - - increases = 0 - while True: - try: - levels = _format_levels(bins, precision, right=right, - include_lowest=include_lowest) - except ValueError: - increases += 1 - precision += 1 - if increases >= 20: - raise - else: - break + if right and include_lowest: + labels[0] = Interval(labels[0].left, labels[0].right, + closed='both') else: if len(labels) != len(bins) - 1: @@ -269,75 +231,25 @@ def _bins_to_cuts(x, bins, right=True, labels=None, retbins=False, return result, bins -def _infer_precision(bins): - for precision in range(3, 20): - levels = np.around(bins, precision) - if algos.unique(levels).size == bins.size: - return precision - return 3 # default - - -# these functions are only used with the deprecated argument include_lowest - -def _format_levels(bins, prec, right=True, - include_lowest=False): - fmt = lambda v: _format_label(v, precision=prec) - if right: - levels = [] - for a, b in zip(bins, bins[1:]): - fa, fb = fmt(a), fmt(b) - - if a != b and fa == fb: - raise ValueError('precision too low') - - formatted = '(%s, %s]' % (fa, fb) - - levels.append(formatted) - - if include_lowest: - levels[0] = '[' + levels[0][1:] +def _round_frac(x, precision): + """Round the fractional part of the given number + """ + if not np.isfinite(x) or x == 0: + return x else: - levels = ['[%s, %s)' % (fmt(a), fmt(b)) - for a, b in zip(bins, bins[1:])] - - return levels - - -def _format_label(x, precision=3): - fmt_str = '%%.%dg' % precision - if np.isinf(x): - return str(x) - elif com.is_float(x): frac, whole = np.modf(x) - sgn = '-' if x < 0 else '' - whole = abs(whole) - if frac != 0.0: - val = fmt_str % frac - - # rounded up or down - if '.' not in val: - if x < 0: - return '%d' % (-whole - 1) - else: - return '%d' % (whole + 1) - - if 'e' in val: - return _trim_zeros(fmt_str % x) - else: - val = _trim_zeros(val) - if '.' 
in val: - return sgn + '.'.join(('%d' % whole, val.split('.')[1])) - else: # pragma: no cover - return sgn + '.'.join(('%d' % whole, val)) + if whole == 0: + digits = -int(np.floor(np.log10(abs(frac)))) - 1 + precision else: - return sgn + '%0.f' % whole - else: - return str(x) + digits = precision + return np.around(x, digits) -def _trim_zeros(x): - while len(x) > 1 and x[-1] == '0': - x = x[:-1] - if len(x) > 1 and x[-1] == '.': - x = x[:-1] - return x +def _infer_precision(base_precision, bins): + """Infer an appropriate precision for _round_frac + """ + for precision in range(base_precision, 20): + levels = [_round_frac(b, precision) for b in bins] + if algos.unique(levels).size == bins.size: + return precision + return base_precision # default From b13e92fd1a0eacabe087b698a0da3d91351b08d1 Mon Sep 17 00:00:00 2001 From: Stephan Hoyer Date: Mon, 14 Dec 2015 19:14:51 -0800 Subject: [PATCH 3/7] Speedup IntervalTree --- pandas/src/generate_intervaltree.py | 167 ++-- pandas/src/intervaltree.pyx | 1360 +++++++++++++-------------- 2 files changed, 708 insertions(+), 819 deletions(-) diff --git a/pandas/src/generate_intervaltree.py b/pandas/src/generate_intervaltree.py index c2dfac86f0ad2..f3479ecf603f8 100644 --- a/pandas/src/generate_intervaltree.py +++ b/pandas/src/generate_intervaltree.py @@ -204,11 +204,12 @@ def __repr__(self): the right, and those that overlap with the pivot. 
""" cdef: - readonly left_node, right_node - readonly {dtype}_t[:] center_left_values, center_right_values - readonly int64_t[:] center_left_indices, center_right_indices + {dtype_title}Closed{closed_title}IntervalNode left_node, right_node + {dtype}_t[:] center_left_values, center_right_values, left, right + int64_t[:] center_left_indices, center_right_indices, indices readonly {dtype}_t pivot - readonly int64_t n_elements, leaf_size + readonly int64_t n_elements, n_center, leaf_size + readonly bint is_leaf_node def __init__(self, ndarray[{dtype}_t, ndim=1] left, @@ -216,19 +217,30 @@ def __init__(self, ndarray[int64_t, ndim=1] indices, int64_t leaf_size): - self.pivot = np.median(left + right) / 2 self.n_elements = len(left) self.leaf_size = leaf_size - left_set, right_set, center_set = self.classify_intervals(left, right) + if self.n_elements <= leaf_size: + # make this a terminal (leaf) node + self.is_leaf_node = True + self.left = left + self.right = right + self.indices = indices + self.n_center + else: + # calculate a pivot so we can create child nodes + self.is_leaf_node = False + self.pivot = np.median(left + right) / 2 + left_set, right_set, center_set = self.classify_intervals(left, right) - self.left_node = self.new_child_node(left, right, indices, left_set) - self.right_node = self.new_child_node(left, right, indices, right_set) + self.left_node = self.new_child_node(left, right, indices, left_set) + self.right_node = self.new_child_node(left, right, indices, right_set) - self.center_left_values, self.center_left_indices = \ - sort_values_and_indices(left, indices, center_set) - self.center_right_values, self.center_right_indices = \ - sort_values_and_indices(right, indices, center_set) + self.center_left_values, self.center_left_indices = \ + sort_values_and_indices(left, indices, center_set) + self.center_right_values, self.center_right_indices = \ + sort_values_and_indices(right, indices, center_set) + self.n_center = len(self.center_left_indices) 
@cython.wraparound(False) @cython.boundscheck(False) @@ -237,14 +249,14 @@ def __init__(self, left, right, or overlap with this node's pivot. """ cdef: - int i Int64Vector left_ind, right_ind, overlapping_ind + Py_ssize_t i left_ind = Int64Vector() right_ind = Int64Vector() overlapping_ind = Int64Vector() - for i in range(len(left)): + for i in range(self.n_elements): if right[i] {cmp_right_converse} self.pivot: left_ind.append(i) elif self.pivot {cmp_left_converse} left[i]: @@ -262,103 +274,76 @@ def __init__(self, ndarray[int64_t, ndim=1] indices, ndarray[int64_t, ndim=1] subset): """Create a new child node. - - This should be a terminal leaf node if the number of indices is smaller - than leaf_size. Otherwise it should be a non-terminal node. """ - left = take(left, subset) right = take(right, subset) indices = take(indices, subset) - - if len(indices) <= self.leaf_size: - return {dtype_title}Closed{closed_title}IntervalLeaf( - left, right, indices) - else: - return {dtype_title}Closed{closed_title}IntervalNode( - left, right, indices, self.leaf_size) + return {dtype_title}Closed{closed_title}IntervalNode( + left, right, indices, self.leaf_size) @cython.wraparound(False) @cython.boundscheck(False) - cpdef query(self, Int64Vector result, scalar64_t point): + @cython.initializedcheck(False) + cdef query(self, Int64Vector result, scalar64_t point): """Recursively query this node and its sub-nodes for intervals that overlap with the query point. 
""" cdef: int64_t[:] indices {dtype}_t[:] values - int i - - if point < self.pivot: - values = self.center_left_values - indices = self.center_left_indices - for i in range(len(values)): - if not values[i] {cmp_left} point: - break - result.append(indices[i]) - self.left_node.query(result, point) - elif point > self.pivot: - values = self.center_right_values - indices = self.center_right_indices - for i in range(len(values) - 1, -1, -1): - if not point {cmp_right} values[i]: - break - result.append(indices[i]) - self.right_node.query(result, point) + Py_ssize_t i + + if self.is_leaf_node: + # Once we get down to a certain size, it doesn't make sense to + # continue the binary tree structure. Instead, we use linear + # search. + for i in range(self.n_elements): + if self.left[i] {cmp_left} point {cmp_right} self.right[i]: + result.append(self.indices[i]) else: - result.extend(self.center_left_indices) + # There are child nodes. Based on comparing our query to the pivot, + # look at the center values, then go to the relevant child. 
+ if point < self.pivot: + values = self.center_left_values + indices = self.center_left_indices + for i in range(self.n_center): + if not values[i] {cmp_left} point: + break + result.append(indices[i]) + self.left_node.query(result, point) + elif point > self.pivot: + values = self.center_right_values + indices = self.center_right_indices + for i in range(self.n_center - 1, -1, -1): + if not point {cmp_right} values[i]: + break + result.append(indices[i]) + self.right_node.query(result, point) + else: + result.extend(self.center_left_indices) def __repr__(self): - return ('<{dtype_title}Closed{closed_title}IntervalNode: pivot %s, ' - '%s elements (%s left, %s right, %s overlapping)>' % - (self.pivot, self.n_elements, self.left_node.n_elements, - self.right_node.n_elements, len(self.center_left_indices))) + if self.is_leaf_node: + return ('<{dtype_title}Closed{closed_title}IntervalNode: ' + '%s elements (terminal)>' % self.n_elements) + else: + n_left = self.left_node.n_elements + n_right = self.right_node.n_elements + n_center = self.n_elements - n_left - n_right + return ('<{dtype_title}Closed{closed_title}IntervalNode: pivot %s, ' + '%s elements (%s left, %s right, %s overlapping)>' % + (self.pivot, self.n_elements, n_left, n_right, n_center)) def counts(self): - m = len(self.center_left_values) - l = self.left_node.counts() - r = self.right_node.counts() - return (m, (l, r)) + if self.is_leaf_node: + return self.n_elements + else: + m = len(self.center_left_values) + l = self.left_node.counts() + r = self.right_node.counts() + return (m, (l, r)) NODE_CLASSES['{dtype}', '{closed}'] = {dtype_title}Closed{closed_title}IntervalNode - - -cdef class {dtype_title}Closed{closed_title}IntervalLeaf: - """Terminal node for an IntervalTree - - Once we get down to a certain size, it doens't make sense to continue the - binary tree structure. Instead, we store interval bounds in 1d arrays use - linear search. 
- """ - cdef: - readonly {dtype}_t[:] left, right - readonly int64_t[:] indices - - def __init__(self, - {dtype}_t[:] left, - {dtype}_t[:] right, - int64_t[:] indices): - self.left = left - self.right = right - self.indices = indices - - @cython.wraparound(False) - @cython.boundscheck(False) - cpdef query(self, Int64Vector result, scalar64_t point): - for i in range(len(self.left)): - if self.left[i] {cmp_left} point {cmp_right} self.right[i]: - result.append(self.indices[i]) - - def __repr__(self): - return ('<{dtype_title}Closed{closed_title}IntervalLeaf: %s elements>' - % self.n_elements) - - @property - def n_elements(self): - return len(self.left) - - def counts(self): - return self.n_elements ''' diff --git a/pandas/src/intervaltree.pyx b/pandas/src/intervaltree.pyx index f3a7447bc09f8..6f86b6a190452 100644 --- a/pandas/src/intervaltree.pyx +++ b/pandas/src/intervaltree.pyx @@ -185,11 +185,12 @@ cdef class Float64ClosedLeftIntervalNode: the right, and those that overlap with the pivot. 
""" cdef: - readonly left_node, right_node - readonly float64_t[:] center_left_values, center_right_values - readonly int64_t[:] center_left_indices, center_right_indices + Float64ClosedLeftIntervalNode left_node, right_node + float64_t[:] center_left_values, center_right_values, left, right + int64_t[:] center_left_indices, center_right_indices, indices readonly float64_t pivot - readonly int64_t n_elements, leaf_size + readonly int64_t n_elements, n_center, leaf_size + readonly bint is_leaf_node def __init__(self, ndarray[float64_t, ndim=1] left, @@ -197,19 +198,30 @@ cdef class Float64ClosedLeftIntervalNode: ndarray[int64_t, ndim=1] indices, int64_t leaf_size): - self.pivot = np.median(left + right) / 2 self.n_elements = len(left) self.leaf_size = leaf_size - left_set, right_set, center_set = self.classify_intervals(left, right) + if self.n_elements <= leaf_size: + # make this a terminal (leaf) node + self.is_leaf_node = True + self.left = left + self.right = right + self.indices = indices + self.n_center + else: + # calculate a pivot so we can create child nodes + self.is_leaf_node = False + self.pivot = np.median(left + right) / 2 + left_set, right_set, center_set = self.classify_intervals(left, right) - self.left_node = self.new_child_node(left, right, indices, left_set) - self.right_node = self.new_child_node(left, right, indices, right_set) + self.left_node = self.new_child_node(left, right, indices, left_set) + self.right_node = self.new_child_node(left, right, indices, right_set) - self.center_left_values, self.center_left_indices = \ - sort_values_and_indices(left, indices, center_set) - self.center_right_values, self.center_right_indices = \ - sort_values_and_indices(right, indices, center_set) + self.center_left_values, self.center_left_indices = \ + sort_values_and_indices(left, indices, center_set) + self.center_right_values, self.center_right_indices = \ + sort_values_and_indices(right, indices, center_set) + self.n_center = 
len(self.center_left_indices) @cython.wraparound(False) @cython.boundscheck(False) @@ -218,14 +230,14 @@ cdef class Float64ClosedLeftIntervalNode: left, right, or overlap with this node's pivot. """ cdef: - int i Int64Vector left_ind, right_ind, overlapping_ind + Py_ssize_t i left_ind = Int64Vector() right_ind = Int64Vector() overlapping_ind = Int64Vector() - for i in range(len(left)): + for i in range(self.n_elements): if right[i] <= self.pivot: left_ind.append(i) elif self.pivot < left[i]: @@ -243,103 +255,79 @@ cdef class Float64ClosedLeftIntervalNode: ndarray[int64_t, ndim=1] indices, ndarray[int64_t, ndim=1] subset): """Create a new child node. - - This should be a terminal leaf node if the number of indices is smaller - than leaf_size. Otherwise it should be a non-terminal node. """ - left = take(left, subset) right = take(right, subset) indices = take(indices, subset) - - if len(indices) <= self.leaf_size: - return Float64ClosedLeftIntervalLeaf( - left, right, indices) - else: - return Float64ClosedLeftIntervalNode( - left, right, indices, self.leaf_size) + return Float64ClosedLeftIntervalNode( + left, right, indices, self.leaf_size) @cython.wraparound(False) @cython.boundscheck(False) - cpdef query(self, Int64Vector result, scalar64_t point): + @cython.initializedcheck(False) + cdef _query(self, Int64Vector result, scalar64_t point): """Recursively query this node and its sub-nodes for intervals that overlap with the query point. 
""" cdef: int64_t[:] indices float64_t[:] values - int i - - if point < self.pivot: - values = self.center_left_values - indices = self.center_left_indices - for i in range(len(values)): - if not values[i] <= point: - break - result.append(indices[i]) - self.left_node.query(result, point) - elif point > self.pivot: - values = self.center_right_values - indices = self.center_right_indices - for i in range(len(values) - 1, -1, -1): - if not point < values[i]: - break - result.append(indices[i]) - self.right_node.query(result, point) + Py_ssize_t i + + if self.is_leaf_node: + # Once we get down to a certain size, it doesn't make sense to + # continue the binary tree structure. Instead, we use linear + # search. + for i in range(self.n_elements): + if self.left[i] <= point < self.right[i]: + result.append(self.indices[i]) else: - result.extend(self.center_left_indices) - - def __repr__(self): - return ('' % - (self.pivot, self.n_elements, self.left_node.n_elements, - self.right_node.n_elements, len(self.center_left_indices))) - - def counts(self): - m = len(self.center_left_values) - l = self.left_node.counts() - r = self.right_node.counts() - return (m, (l, r)) - -NODE_CLASSES['float64', 'left'] = Float64ClosedLeftIntervalNode - - -cdef class Float64ClosedLeftIntervalLeaf: - """Terminal node for an IntervalTree - - Once we get down to a certain size, it doens't make sense to continue the - binary tree structure. Instead, we store interval bounds in 1d arrays use - linear search. - """ - cdef: - readonly float64_t[:] left, right - readonly int64_t[:] indices - - def __init__(self, - float64_t[:] left, - float64_t[:] right, - int64_t[:] indices): - self.left = left - self.right = right - self.indices = indices + # There are child nodes. Based on comparing our query to the pivot, + # look at the center values, then go to the relevant child. 
+ if point < self.pivot: + values = self.center_left_values + indices = self.center_left_indices + for i in range(self.n_center): + if not values[i] <= point: + break + result.append(indices[i]) + self.left_node._query(result, point) + elif point > self.pivot: + values = self.center_right_values + indices = self.center_right_indices + for i in range(self.n_center - 1, -1, -1): + if not point < values[i]: + break + result.append(indices[i]) + self.right_node._query(result, point) + else: + result.extend(self.center_left_indices) - @cython.wraparound(False) - @cython.boundscheck(False) cpdef query(self, Int64Vector result, scalar64_t point): - for i in range(len(self.left)): - if self.left[i] <= point < self.right[i]: - result.append(self.indices[i]) + return self._query(result, point) def __repr__(self): - return ('' - % self.n_elements) - - @property - def n_elements(self): - return len(self.left) + if self.is_leaf_node: + return ('' % self.n_elements) + else: + n_left = self.left_node.n_elements + n_right = self.right_node.n_elements + n_center = self.n_elements - n_left - n_right + return ('' % + (self.pivot, self.n_elements, n_left, n_right, n_center)) def counts(self): - return self.n_elements + if self.is_leaf_node: + return self.n_elements + else: + m = len(self.center_left_values) + l = self.left_node.counts() + r = self.right_node.counts() + return (m, (l, r)) + +NODE_CLASSES['float64', 'left'] = Float64ClosedLeftIntervalNode cdef class Float64ClosedRightIntervalNode: @@ -349,11 +337,12 @@ cdef class Float64ClosedRightIntervalNode: the right, and those that overlap with the pivot. 
""" cdef: - readonly left_node, right_node - readonly float64_t[:] center_left_values, center_right_values - readonly int64_t[:] center_left_indices, center_right_indices + Float64ClosedRightIntervalNode left_node, right_node + float64_t[:] center_left_values, center_right_values, left, right + int64_t[:] center_left_indices, center_right_indices, indices readonly float64_t pivot - readonly int64_t n_elements, leaf_size + readonly int64_t n_elements, n_center, leaf_size + readonly bint is_leaf_node def __init__(self, ndarray[float64_t, ndim=1] left, @@ -361,19 +350,30 @@ cdef class Float64ClosedRightIntervalNode: ndarray[int64_t, ndim=1] indices, int64_t leaf_size): - self.pivot = np.median(left + right) / 2 self.n_elements = len(left) self.leaf_size = leaf_size - left_set, right_set, center_set = self.classify_intervals(left, right) + if self.n_elements <= leaf_size: + # make this a terminal (leaf) node + self.is_leaf_node = True + self.left = left + self.right = right + self.indices = indices + self.n_center + else: + # calculate a pivot so we can create child nodes + self.is_leaf_node = False + self.pivot = np.median(left + right) / 2 + left_set, right_set, center_set = self.classify_intervals(left, right) - self.left_node = self.new_child_node(left, right, indices, left_set) - self.right_node = self.new_child_node(left, right, indices, right_set) + self.left_node = self.new_child_node(left, right, indices, left_set) + self.right_node = self.new_child_node(left, right, indices, right_set) - self.center_left_values, self.center_left_indices = \ - sort_values_and_indices(left, indices, center_set) - self.center_right_values, self.center_right_indices = \ - sort_values_and_indices(right, indices, center_set) + self.center_left_values, self.center_left_indices = \ + sort_values_and_indices(left, indices, center_set) + self.center_right_values, self.center_right_indices = \ + sort_values_and_indices(right, indices, center_set) + self.n_center = 
len(self.center_left_indices) @cython.wraparound(False) @cython.boundscheck(False) @@ -382,14 +382,14 @@ cdef class Float64ClosedRightIntervalNode: left, right, or overlap with this node's pivot. """ cdef: - int i Int64Vector left_ind, right_ind, overlapping_ind + Py_ssize_t i left_ind = Int64Vector() right_ind = Int64Vector() overlapping_ind = Int64Vector() - for i in range(len(left)): + for i in range(self.n_elements): if right[i] < self.pivot: left_ind.append(i) elif self.pivot <= left[i]: @@ -407,103 +407,79 @@ cdef class Float64ClosedRightIntervalNode: ndarray[int64_t, ndim=1] indices, ndarray[int64_t, ndim=1] subset): """Create a new child node. - - This should be a terminal leaf node if the number of indices is smaller - than leaf_size. Otherwise it should be a non-terminal node. """ - left = take(left, subset) right = take(right, subset) indices = take(indices, subset) - - if len(indices) <= self.leaf_size: - return Float64ClosedRightIntervalLeaf( - left, right, indices) - else: - return Float64ClosedRightIntervalNode( - left, right, indices, self.leaf_size) + return Float64ClosedRightIntervalNode( + left, right, indices, self.leaf_size) @cython.wraparound(False) @cython.boundscheck(False) - cpdef query(self, Int64Vector result, scalar64_t point): + @cython.initializedcheck(False) + cdef _query(self, Int64Vector result, scalar64_t point): """Recursively query this node and its sub-nodes for intervals that overlap with the query point. 
""" cdef: int64_t[:] indices float64_t[:] values - int i - - if point < self.pivot: - values = self.center_left_values - indices = self.center_left_indices - for i in range(len(values)): - if not values[i] < point: - break - result.append(indices[i]) - self.left_node.query(result, point) - elif point > self.pivot: - values = self.center_right_values - indices = self.center_right_indices - for i in range(len(values) - 1, -1, -1): - if not point <= values[i]: - break - result.append(indices[i]) - self.right_node.query(result, point) + Py_ssize_t i + + if self.is_leaf_node: + # Once we get down to a certain size, it doesn't make sense to + # continue the binary tree structure. Instead, we use linear + # search. + for i in range(self.n_elements): + if self.left[i] < point <= self.right[i]: + result.append(self.indices[i]) else: - result.extend(self.center_left_indices) - - def __repr__(self): - return ('' % - (self.pivot, self.n_elements, self.left_node.n_elements, - self.right_node.n_elements, len(self.center_left_indices))) - - def counts(self): - m = len(self.center_left_values) - l = self.left_node.counts() - r = self.right_node.counts() - return (m, (l, r)) - -NODE_CLASSES['float64', 'right'] = Float64ClosedRightIntervalNode - - -cdef class Float64ClosedRightIntervalLeaf: - """Terminal node for an IntervalTree - - Once we get down to a certain size, it doens't make sense to continue the - binary tree structure. Instead, we store interval bounds in 1d arrays use - linear search. - """ - cdef: - readonly float64_t[:] left, right - readonly int64_t[:] indices - - def __init__(self, - float64_t[:] left, - float64_t[:] right, - int64_t[:] indices): - self.left = left - self.right = right - self.indices = indices + # There are child nodes. Based on comparing our query to the pivot, + # look at the center values, then go to the relevant child. 
+ if point < self.pivot: + values = self.center_left_values + indices = self.center_left_indices + for i in range(self.n_center): + if not values[i] < point: + break + result.append(indices[i]) + self.left_node._query(result, point) + elif point > self.pivot: + values = self.center_right_values + indices = self.center_right_indices + for i in range(self.n_center - 1, -1, -1): + if not point <= values[i]: + break + result.append(indices[i]) + self.right_node._query(result, point) + else: + result.extend(self.center_left_indices) - @cython.wraparound(False) - @cython.boundscheck(False) cpdef query(self, Int64Vector result, scalar64_t point): - for i in range(len(self.left)): - if self.left[i] < point <= self.right[i]: - result.append(self.indices[i]) + return self._query(result, point) def __repr__(self): - return ('' - % self.n_elements) - - @property - def n_elements(self): - return len(self.left) + if self.is_leaf_node: + return ('' % self.n_elements) + else: + n_left = self.left_node.n_elements + n_right = self.right_node.n_elements + n_center = self.n_elements - n_left - n_right + return ('' % + (self.pivot, self.n_elements, n_left, n_right, n_center)) def counts(self): - return self.n_elements + if self.is_leaf_node: + return self.n_elements + else: + m = len(self.center_left_values) + l = self.left_node.counts() + r = self.right_node.counts() + return (m, (l, r)) + +NODE_CLASSES['float64', 'right'] = Float64ClosedRightIntervalNode cdef class Float64ClosedBothIntervalNode: @@ -513,11 +489,12 @@ cdef class Float64ClosedBothIntervalNode: the right, and those that overlap with the pivot. 
""" cdef: - readonly left_node, right_node - readonly float64_t[:] center_left_values, center_right_values - readonly int64_t[:] center_left_indices, center_right_indices + Float64ClosedBothIntervalNode left_node, right_node + float64_t[:] center_left_values, center_right_values, left, right + int64_t[:] center_left_indices, center_right_indices, indices readonly float64_t pivot - readonly int64_t n_elements, leaf_size + readonly int64_t n_elements, n_center, leaf_size + readonly bint is_leaf_node def __init__(self, ndarray[float64_t, ndim=1] left, @@ -525,19 +502,30 @@ cdef class Float64ClosedBothIntervalNode: ndarray[int64_t, ndim=1] indices, int64_t leaf_size): - self.pivot = np.median(left + right) / 2 self.n_elements = len(left) self.leaf_size = leaf_size - left_set, right_set, center_set = self.classify_intervals(left, right) + if self.n_elements <= leaf_size: + # make this a terminal (leaf) node + self.is_leaf_node = True + self.left = left + self.right = right + self.indices = indices + self.n_center + else: + # calculate a pivot so we can create child nodes + self.is_leaf_node = False + self.pivot = np.median(left + right) / 2 + left_set, right_set, center_set = self.classify_intervals(left, right) - self.left_node = self.new_child_node(left, right, indices, left_set) - self.right_node = self.new_child_node(left, right, indices, right_set) + self.left_node = self.new_child_node(left, right, indices, left_set) + self.right_node = self.new_child_node(left, right, indices, right_set) - self.center_left_values, self.center_left_indices = \ - sort_values_and_indices(left, indices, center_set) - self.center_right_values, self.center_right_indices = \ - sort_values_and_indices(right, indices, center_set) + self.center_left_values, self.center_left_indices = \ + sort_values_and_indices(left, indices, center_set) + self.center_right_values, self.center_right_indices = \ + sort_values_and_indices(right, indices, center_set) + self.n_center = 
len(self.center_left_indices) @cython.wraparound(False) @cython.boundscheck(False) @@ -546,14 +534,14 @@ cdef class Float64ClosedBothIntervalNode: left, right, or overlap with this node's pivot. """ cdef: - int i Int64Vector left_ind, right_ind, overlapping_ind + Py_ssize_t i left_ind = Int64Vector() right_ind = Int64Vector() overlapping_ind = Int64Vector() - for i in range(len(left)): + for i in range(self.n_elements): if right[i] < self.pivot: left_ind.append(i) elif self.pivot < left[i]: @@ -571,103 +559,79 @@ cdef class Float64ClosedBothIntervalNode: ndarray[int64_t, ndim=1] indices, ndarray[int64_t, ndim=1] subset): """Create a new child node. - - This should be a terminal leaf node if the number of indices is smaller - than leaf_size. Otherwise it should be a non-terminal node. """ - left = take(left, subset) right = take(right, subset) indices = take(indices, subset) - - if len(indices) <= self.leaf_size: - return Float64ClosedBothIntervalLeaf( - left, right, indices) - else: - return Float64ClosedBothIntervalNode( - left, right, indices, self.leaf_size) + return Float64ClosedBothIntervalNode( + left, right, indices, self.leaf_size) @cython.wraparound(False) @cython.boundscheck(False) - cpdef query(self, Int64Vector result, scalar64_t point): + @cython.initializedcheck(False) + cdef _query(self, Int64Vector result, scalar64_t point): """Recursively query this node and its sub-nodes for intervals that overlap with the query point. 
""" cdef: int64_t[:] indices float64_t[:] values - int i - - if point < self.pivot: - values = self.center_left_values - indices = self.center_left_indices - for i in range(len(values)): - if not values[i] <= point: - break - result.append(indices[i]) - self.left_node.query(result, point) - elif point > self.pivot: - values = self.center_right_values - indices = self.center_right_indices - for i in range(len(values) - 1, -1, -1): - if not point <= values[i]: - break - result.append(indices[i]) - self.right_node.query(result, point) + Py_ssize_t i + + if self.is_leaf_node: + # Once we get down to a certain size, it doesn't make sense to + # continue the binary tree structure. Instead, we use linear + # search. + for i in range(self.n_elements): + if self.left[i] <= point <= self.right[i]: + result.append(self.indices[i]) else: - result.extend(self.center_left_indices) - - def __repr__(self): - return ('' % - (self.pivot, self.n_elements, self.left_node.n_elements, - self.right_node.n_elements, len(self.center_left_indices))) - - def counts(self): - m = len(self.center_left_values) - l = self.left_node.counts() - r = self.right_node.counts() - return (m, (l, r)) - -NODE_CLASSES['float64', 'both'] = Float64ClosedBothIntervalNode - - -cdef class Float64ClosedBothIntervalLeaf: - """Terminal node for an IntervalTree - - Once we get down to a certain size, it doens't make sense to continue the - binary tree structure. Instead, we store interval bounds in 1d arrays use - linear search. - """ - cdef: - readonly float64_t[:] left, right - readonly int64_t[:] indices - - def __init__(self, - float64_t[:] left, - float64_t[:] right, - int64_t[:] indices): - self.left = left - self.right = right - self.indices = indices + # There are child nodes. Based on comparing our query to the pivot, + # look at the center values, then go to the relevant child. 
+ if point < self.pivot: + values = self.center_left_values + indices = self.center_left_indices + for i in range(self.n_center): + if not values[i] <= point: + break + result.append(indices[i]) + self.left_node._query(result, point) + elif point > self.pivot: + values = self.center_right_values + indices = self.center_right_indices + for i in range(self.n_center - 1, -1, -1): + if not point <= values[i]: + break + result.append(indices[i]) + self.right_node._query(result, point) + else: + result.extend(self.center_left_indices) - @cython.wraparound(False) - @cython.boundscheck(False) cpdef query(self, Int64Vector result, scalar64_t point): - for i in range(len(self.left)): - if self.left[i] <= point <= self.right[i]: - result.append(self.indices[i]) + return self._query(result, point) def __repr__(self): - return ('' - % self.n_elements) - - @property - def n_elements(self): - return len(self.left) + if self.is_leaf_node: + return ('' % self.n_elements) + else: + n_left = self.left_node.n_elements + n_right = self.right_node.n_elements + n_center = self.n_elements - n_left - n_right + return ('' % + (self.pivot, self.n_elements, n_left, n_right, n_center)) def counts(self): - return self.n_elements + if self.is_leaf_node: + return self.n_elements + else: + m = len(self.center_left_values) + l = self.left_node.counts() + r = self.right_node.counts() + return (m, (l, r)) + +NODE_CLASSES['float64', 'both'] = Float64ClosedBothIntervalNode cdef class Float64ClosedNeitherIntervalNode: @@ -677,11 +641,12 @@ cdef class Float64ClosedNeitherIntervalNode: the right, and those that overlap with the pivot. 
""" cdef: - readonly left_node, right_node - readonly float64_t[:] center_left_values, center_right_values - readonly int64_t[:] center_left_indices, center_right_indices + Float64ClosedNeitherIntervalNode left_node, right_node + float64_t[:] center_left_values, center_right_values, left, right + int64_t[:] center_left_indices, center_right_indices, indices readonly float64_t pivot - readonly int64_t n_elements, leaf_size + readonly int64_t n_elements, n_center, leaf_size + readonly bint is_leaf_node def __init__(self, ndarray[float64_t, ndim=1] left, @@ -689,19 +654,30 @@ cdef class Float64ClosedNeitherIntervalNode: ndarray[int64_t, ndim=1] indices, int64_t leaf_size): - self.pivot = np.median(left + right) / 2 self.n_elements = len(left) self.leaf_size = leaf_size - left_set, right_set, center_set = self.classify_intervals(left, right) + if self.n_elements <= leaf_size: + # make this a terminal (leaf) node + self.is_leaf_node = True + self.left = left + self.right = right + self.indices = indices + self.n_center + else: + # calculate a pivot so we can create child nodes + self.is_leaf_node = False + self.pivot = np.median(left + right) / 2 + left_set, right_set, center_set = self.classify_intervals(left, right) - self.left_node = self.new_child_node(left, right, indices, left_set) - self.right_node = self.new_child_node(left, right, indices, right_set) + self.left_node = self.new_child_node(left, right, indices, left_set) + self.right_node = self.new_child_node(left, right, indices, right_set) - self.center_left_values, self.center_left_indices = \ - sort_values_and_indices(left, indices, center_set) - self.center_right_values, self.center_right_indices = \ - sort_values_and_indices(right, indices, center_set) + self.center_left_values, self.center_left_indices = \ + sort_values_and_indices(left, indices, center_set) + self.center_right_values, self.center_right_indices = \ + sort_values_and_indices(right, indices, center_set) + self.n_center = 
len(self.center_left_indices) @cython.wraparound(False) @cython.boundscheck(False) @@ -710,14 +686,14 @@ cdef class Float64ClosedNeitherIntervalNode: left, right, or overlap with this node's pivot. """ cdef: - int i Int64Vector left_ind, right_ind, overlapping_ind + Py_ssize_t i left_ind = Int64Vector() right_ind = Int64Vector() overlapping_ind = Int64Vector() - for i in range(len(left)): + for i in range(self.n_elements): if right[i] <= self.pivot: left_ind.append(i) elif self.pivot <= left[i]: @@ -735,103 +711,79 @@ cdef class Float64ClosedNeitherIntervalNode: ndarray[int64_t, ndim=1] indices, ndarray[int64_t, ndim=1] subset): """Create a new child node. - - This should be a terminal leaf node if the number of indices is smaller - than leaf_size. Otherwise it should be a non-terminal node. """ - left = take(left, subset) right = take(right, subset) indices = take(indices, subset) - - if len(indices) <= self.leaf_size: - return Float64ClosedNeitherIntervalLeaf( - left, right, indices) - else: - return Float64ClosedNeitherIntervalNode( - left, right, indices, self.leaf_size) + return Float64ClosedNeitherIntervalNode( + left, right, indices, self.leaf_size) @cython.wraparound(False) @cython.boundscheck(False) - cpdef query(self, Int64Vector result, scalar64_t point): + @cython.initializedcheck(False) + cdef _query(self, Int64Vector result, scalar64_t point): """Recursively query this node and its sub-nodes for intervals that overlap with the query point. 
""" cdef: int64_t[:] indices float64_t[:] values - int i - - if point < self.pivot: - values = self.center_left_values - indices = self.center_left_indices - for i in range(len(values)): - if not values[i] < point: - break - result.append(indices[i]) - self.left_node.query(result, point) - elif point > self.pivot: - values = self.center_right_values - indices = self.center_right_indices - for i in range(len(values) - 1, -1, -1): - if not point < values[i]: - break - result.append(indices[i]) - self.right_node.query(result, point) + Py_ssize_t i + + if self.is_leaf_node: + # Once we get down to a certain size, it doesn't make sense to + # continue the binary tree structure. Instead, we use linear + # search. + for i in range(self.n_elements): + if self.left[i] < point < self.right[i]: + result.append(self.indices[i]) else: - result.extend(self.center_left_indices) - - def __repr__(self): - return ('' % - (self.pivot, self.n_elements, self.left_node.n_elements, - self.right_node.n_elements, len(self.center_left_indices))) - - def counts(self): - m = len(self.center_left_values) - l = self.left_node.counts() - r = self.right_node.counts() - return (m, (l, r)) - -NODE_CLASSES['float64', 'neither'] = Float64ClosedNeitherIntervalNode - - -cdef class Float64ClosedNeitherIntervalLeaf: - """Terminal node for an IntervalTree - - Once we get down to a certain size, it doens't make sense to continue the - binary tree structure. Instead, we store interval bounds in 1d arrays use - linear search. - """ - cdef: - readonly float64_t[:] left, right - readonly int64_t[:] indices - - def __init__(self, - float64_t[:] left, - float64_t[:] right, - int64_t[:] indices): - self.left = left - self.right = right - self.indices = indices + # There are child nodes. Based on comparing our query to the pivot, + # look at the center values, then go to the relevant child. 
+ if point < self.pivot: + values = self.center_left_values + indices = self.center_left_indices + for i in range(self.n_center): + if not values[i] < point: + break + result.append(indices[i]) + self.left_node._query(result, point) + elif point > self.pivot: + values = self.center_right_values + indices = self.center_right_indices + for i in range(self.n_center - 1, -1, -1): + if not point < values[i]: + break + result.append(indices[i]) + self.right_node._query(result, point) + else: + result.extend(self.center_left_indices) - @cython.wraparound(False) - @cython.boundscheck(False) cpdef query(self, Int64Vector result, scalar64_t point): - for i in range(len(self.left)): - if self.left[i] < point < self.right[i]: - result.append(self.indices[i]) + return self._query(result, point) def __repr__(self): - return ('' - % self.n_elements) - - @property - def n_elements(self): - return len(self.left) + if self.is_leaf_node: + return ('' % self.n_elements) + else: + n_left = self.left_node.n_elements + n_right = self.right_node.n_elements + n_center = self.n_elements - n_left - n_right + return ('' % + (self.pivot, self.n_elements, n_left, n_right, n_center)) def counts(self): - return self.n_elements + if self.is_leaf_node: + return self.n_elements + else: + m = len(self.center_left_values) + l = self.left_node.counts() + r = self.right_node.counts() + return (m, (l, r)) + +NODE_CLASSES['float64', 'neither'] = Float64ClosedNeitherIntervalNode cdef class Int64ClosedLeftIntervalNode: @@ -841,11 +793,12 @@ cdef class Int64ClosedLeftIntervalNode: the right, and those that overlap with the pivot. 
""" cdef: - readonly left_node, right_node - readonly int64_t[:] center_left_values, center_right_values - readonly int64_t[:] center_left_indices, center_right_indices + Int64ClosedLeftIntervalNode left_node, right_node + int64_t[:] center_left_values, center_right_values, left, right + int64_t[:] center_left_indices, center_right_indices, indices readonly int64_t pivot - readonly int64_t n_elements, leaf_size + readonly int64_t n_elements, n_center, leaf_size + readonly bint is_leaf_node def __init__(self, ndarray[int64_t, ndim=1] left, @@ -853,19 +806,30 @@ cdef class Int64ClosedLeftIntervalNode: ndarray[int64_t, ndim=1] indices, int64_t leaf_size): - self.pivot = np.median(left + right) / 2 self.n_elements = len(left) self.leaf_size = leaf_size - left_set, right_set, center_set = self.classify_intervals(left, right) + if self.n_elements <= leaf_size: + # make this a terminal (leaf) node + self.is_leaf_node = True + self.left = left + self.right = right + self.indices = indices + self.n_center + else: + # calculate a pivot so we can create child nodes + self.is_leaf_node = False + self.pivot = np.median(left + right) / 2 + left_set, right_set, center_set = self.classify_intervals(left, right) - self.left_node = self.new_child_node(left, right, indices, left_set) - self.right_node = self.new_child_node(left, right, indices, right_set) + self.left_node = self.new_child_node(left, right, indices, left_set) + self.right_node = self.new_child_node(left, right, indices, right_set) - self.center_left_values, self.center_left_indices = \ - sort_values_and_indices(left, indices, center_set) - self.center_right_values, self.center_right_indices = \ - sort_values_and_indices(right, indices, center_set) + self.center_left_values, self.center_left_indices = \ + sort_values_and_indices(left, indices, center_set) + self.center_right_values, self.center_right_indices = \ + sort_values_and_indices(right, indices, center_set) + self.n_center = len(self.center_left_indices) 
@cython.wraparound(False) @cython.boundscheck(False) @@ -874,14 +838,14 @@ cdef class Int64ClosedLeftIntervalNode: left, right, or overlap with this node's pivot. """ cdef: - int i Int64Vector left_ind, right_ind, overlapping_ind + Py_ssize_t i left_ind = Int64Vector() right_ind = Int64Vector() overlapping_ind = Int64Vector() - for i in range(len(left)): + for i in range(self.n_elements): if right[i] <= self.pivot: left_ind.append(i) elif self.pivot < left[i]: @@ -899,103 +863,79 @@ cdef class Int64ClosedLeftIntervalNode: ndarray[int64_t, ndim=1] indices, ndarray[int64_t, ndim=1] subset): """Create a new child node. - - This should be a terminal leaf node if the number of indices is smaller - than leaf_size. Otherwise it should be a non-terminal node. """ - left = take(left, subset) right = take(right, subset) indices = take(indices, subset) - - if len(indices) <= self.leaf_size: - return Int64ClosedLeftIntervalLeaf( - left, right, indices) - else: - return Int64ClosedLeftIntervalNode( - left, right, indices, self.leaf_size) + return Int64ClosedLeftIntervalNode( + left, right, indices, self.leaf_size) @cython.wraparound(False) @cython.boundscheck(False) - cpdef query(self, Int64Vector result, scalar64_t point): + @cython.initializedcheck(False) + cdef _query(self, Int64Vector result, scalar64_t point): """Recursively query this node and its sub-nodes for intervals that overlap with the query point. 
""" cdef: int64_t[:] indices int64_t[:] values - int i - - if point < self.pivot: - values = self.center_left_values - indices = self.center_left_indices - for i in range(len(values)): - if not values[i] <= point: - break - result.append(indices[i]) - self.left_node.query(result, point) - elif point > self.pivot: - values = self.center_right_values - indices = self.center_right_indices - for i in range(len(values) - 1, -1, -1): - if not point < values[i]: - break - result.append(indices[i]) - self.right_node.query(result, point) + Py_ssize_t i + + if self.is_leaf_node: + # Once we get down to a certain size, it doesn't make sense to + # continue the binary tree structure. Instead, we use linear + # search. + for i in range(self.n_elements): + if self.left[i] <= point < self.right[i]: + result.append(self.indices[i]) else: - result.extend(self.center_left_indices) - - def __repr__(self): - return ('' % - (self.pivot, self.n_elements, self.left_node.n_elements, - self.right_node.n_elements, len(self.center_left_indices))) - - def counts(self): - m = len(self.center_left_values) - l = self.left_node.counts() - r = self.right_node.counts() - return (m, (l, r)) - -NODE_CLASSES['int64', 'left'] = Int64ClosedLeftIntervalNode - - -cdef class Int64ClosedLeftIntervalLeaf: - """Terminal node for an IntervalTree - - Once we get down to a certain size, it doens't make sense to continue the - binary tree structure. Instead, we store interval bounds in 1d arrays use - linear search. - """ - cdef: - readonly int64_t[:] left, right - readonly int64_t[:] indices - - def __init__(self, - int64_t[:] left, - int64_t[:] right, - int64_t[:] indices): - self.left = left - self.right = right - self.indices = indices + # There are child nodes. Based on comparing our query to the pivot, + # look at the center values, then go to the relevant child. 
+ if point < self.pivot: + values = self.center_left_values + indices = self.center_left_indices + for i in range(self.n_center): + if not values[i] <= point: + break + result.append(indices[i]) + self.left_node._query(result, point) + elif point > self.pivot: + values = self.center_right_values + indices = self.center_right_indices + for i in range(self.n_center - 1, -1, -1): + if not point < values[i]: + break + result.append(indices[i]) + self.right_node._query(result, point) + else: + result.extend(self.center_left_indices) - @cython.wraparound(False) - @cython.boundscheck(False) cpdef query(self, Int64Vector result, scalar64_t point): - for i in range(len(self.left)): - if self.left[i] <= point < self.right[i]: - result.append(self.indices[i]) + return self._query(result, point) def __repr__(self): - return ('' - % self.n_elements) - - @property - def n_elements(self): - return len(self.left) + if self.is_leaf_node: + return ('' % self.n_elements) + else: + n_left = self.left_node.n_elements + n_right = self.right_node.n_elements + n_center = self.n_elements - n_left - n_right + return ('' % + (self.pivot, self.n_elements, n_left, n_right, n_center)) def counts(self): - return self.n_elements + if self.is_leaf_node: + return self.n_elements + else: + m = len(self.center_left_values) + l = self.left_node.counts() + r = self.right_node.counts() + return (m, (l, r)) + +NODE_CLASSES['int64', 'left'] = Int64ClosedLeftIntervalNode cdef class Int64ClosedRightIntervalNode: @@ -1005,11 +945,12 @@ cdef class Int64ClosedRightIntervalNode: the right, and those that overlap with the pivot. 
""" cdef: - readonly left_node, right_node - readonly int64_t[:] center_left_values, center_right_values - readonly int64_t[:] center_left_indices, center_right_indices + Int64ClosedRightIntervalNode left_node, right_node + int64_t[:] center_left_values, center_right_values, left, right + int64_t[:] center_left_indices, center_right_indices, indices readonly int64_t pivot - readonly int64_t n_elements, leaf_size + readonly int64_t n_elements, n_center, leaf_size + readonly bint is_leaf_node def __init__(self, ndarray[int64_t, ndim=1] left, @@ -1017,19 +958,30 @@ cdef class Int64ClosedRightIntervalNode: ndarray[int64_t, ndim=1] indices, int64_t leaf_size): - self.pivot = np.median(left + right) / 2 self.n_elements = len(left) self.leaf_size = leaf_size - left_set, right_set, center_set = self.classify_intervals(left, right) + if self.n_elements <= leaf_size: + # make this a terminal (leaf) node + self.is_leaf_node = True + self.left = left + self.right = right + self.indices = indices + self.n_center + else: + # calculate a pivot so we can create child nodes + self.is_leaf_node = False + self.pivot = np.median(left + right) / 2 + left_set, right_set, center_set = self.classify_intervals(left, right) - self.left_node = self.new_child_node(left, right, indices, left_set) - self.right_node = self.new_child_node(left, right, indices, right_set) + self.left_node = self.new_child_node(left, right, indices, left_set) + self.right_node = self.new_child_node(left, right, indices, right_set) - self.center_left_values, self.center_left_indices = \ - sort_values_and_indices(left, indices, center_set) - self.center_right_values, self.center_right_indices = \ - sort_values_and_indices(right, indices, center_set) + self.center_left_values, self.center_left_indices = \ + sort_values_and_indices(left, indices, center_set) + self.center_right_values, self.center_right_indices = \ + sort_values_and_indices(right, indices, center_set) + self.n_center = len(self.center_left_indices) 
@cython.wraparound(False) @cython.boundscheck(False) @@ -1038,14 +990,14 @@ cdef class Int64ClosedRightIntervalNode: left, right, or overlap with this node's pivot. """ cdef: - int i Int64Vector left_ind, right_ind, overlapping_ind + Py_ssize_t i left_ind = Int64Vector() right_ind = Int64Vector() overlapping_ind = Int64Vector() - for i in range(len(left)): + for i in range(self.n_elements): if right[i] < self.pivot: left_ind.append(i) elif self.pivot <= left[i]: @@ -1063,103 +1015,79 @@ cdef class Int64ClosedRightIntervalNode: ndarray[int64_t, ndim=1] indices, ndarray[int64_t, ndim=1] subset): """Create a new child node. - - This should be a terminal leaf node if the number of indices is smaller - than leaf_size. Otherwise it should be a non-terminal node. """ - left = take(left, subset) right = take(right, subset) indices = take(indices, subset) - - if len(indices) <= self.leaf_size: - return Int64ClosedRightIntervalLeaf( - left, right, indices) - else: - return Int64ClosedRightIntervalNode( - left, right, indices, self.leaf_size) + return Int64ClosedRightIntervalNode( + left, right, indices, self.leaf_size) @cython.wraparound(False) @cython.boundscheck(False) - cpdef query(self, Int64Vector result, scalar64_t point): + @cython.initializedcheck(False) + cdef _query(self, Int64Vector result, scalar64_t point): """Recursively query this node and its sub-nodes for intervals that overlap with the query point. 
""" cdef: int64_t[:] indices int64_t[:] values - int i - - if point < self.pivot: - values = self.center_left_values - indices = self.center_left_indices - for i in range(len(values)): - if not values[i] < point: - break - result.append(indices[i]) - self.left_node.query(result, point) - elif point > self.pivot: - values = self.center_right_values - indices = self.center_right_indices - for i in range(len(values) - 1, -1, -1): - if not point <= values[i]: - break - result.append(indices[i]) - self.right_node.query(result, point) + Py_ssize_t i + + if self.is_leaf_node: + # Once we get down to a certain size, it doesn't make sense to + # continue the binary tree structure. Instead, we use linear + # search. + for i in range(self.n_elements): + if self.left[i] < point <= self.right[i]: + result.append(self.indices[i]) else: - result.extend(self.center_left_indices) - - def __repr__(self): - return ('' % - (self.pivot, self.n_elements, self.left_node.n_elements, - self.right_node.n_elements, len(self.center_left_indices))) - - def counts(self): - m = len(self.center_left_values) - l = self.left_node.counts() - r = self.right_node.counts() - return (m, (l, r)) - -NODE_CLASSES['int64', 'right'] = Int64ClosedRightIntervalNode - - -cdef class Int64ClosedRightIntervalLeaf: - """Terminal node for an IntervalTree - - Once we get down to a certain size, it doens't make sense to continue the - binary tree structure. Instead, we store interval bounds in 1d arrays use - linear search. - """ - cdef: - readonly int64_t[:] left, right - readonly int64_t[:] indices - - def __init__(self, - int64_t[:] left, - int64_t[:] right, - int64_t[:] indices): - self.left = left - self.right = right - self.indices = indices + # There are child nodes. Based on comparing our query to the pivot, + # look at the center values, then go to the relevant child. 
+ if point < self.pivot: + values = self.center_left_values + indices = self.center_left_indices + for i in range(self.n_center): + if not values[i] < point: + break + result.append(indices[i]) + self.left_node._query(result, point) + elif point > self.pivot: + values = self.center_right_values + indices = self.center_right_indices + for i in range(self.n_center - 1, -1, -1): + if not point <= values[i]: + break + result.append(indices[i]) + self.right_node._query(result, point) + else: + result.extend(self.center_left_indices) - @cython.wraparound(False) - @cython.boundscheck(False) cpdef query(self, Int64Vector result, scalar64_t point): - for i in range(len(self.left)): - if self.left[i] < point <= self.right[i]: - result.append(self.indices[i]) + return self._query(result, point) def __repr__(self): - return ('' - % self.n_elements) - - @property - def n_elements(self): - return len(self.left) + if self.is_leaf_node: + return ('' % self.n_elements) + else: + n_left = self.left_node.n_elements + n_right = self.right_node.n_elements + n_center = self.n_elements - n_left - n_right + return ('' % + (self.pivot, self.n_elements, n_left, n_right, n_center)) def counts(self): - return self.n_elements + if self.is_leaf_node: + return self.n_elements + else: + m = len(self.center_left_values) + l = self.left_node.counts() + r = self.right_node.counts() + return (m, (l, r)) + +NODE_CLASSES['int64', 'right'] = Int64ClosedRightIntervalNode cdef class Int64ClosedBothIntervalNode: @@ -1169,11 +1097,12 @@ cdef class Int64ClosedBothIntervalNode: the right, and those that overlap with the pivot. 
""" cdef: - readonly left_node, right_node - readonly int64_t[:] center_left_values, center_right_values - readonly int64_t[:] center_left_indices, center_right_indices + Int64ClosedBothIntervalNode left_node, right_node + int64_t[:] center_left_values, center_right_values, left, right + int64_t[:] center_left_indices, center_right_indices, indices readonly int64_t pivot - readonly int64_t n_elements, leaf_size + readonly int64_t n_elements, n_center, leaf_size + readonly bint is_leaf_node def __init__(self, ndarray[int64_t, ndim=1] left, @@ -1181,19 +1110,30 @@ cdef class Int64ClosedBothIntervalNode: ndarray[int64_t, ndim=1] indices, int64_t leaf_size): - self.pivot = np.median(left + right) / 2 self.n_elements = len(left) self.leaf_size = leaf_size - left_set, right_set, center_set = self.classify_intervals(left, right) + if self.n_elements <= leaf_size: + # make this a terminal (leaf) node + self.is_leaf_node = True + self.left = left + self.right = right + self.indices = indices + self.n_center + else: + # calculate a pivot so we can create child nodes + self.is_leaf_node = False + self.pivot = np.median(left + right) / 2 + left_set, right_set, center_set = self.classify_intervals(left, right) - self.left_node = self.new_child_node(left, right, indices, left_set) - self.right_node = self.new_child_node(left, right, indices, right_set) + self.left_node = self.new_child_node(left, right, indices, left_set) + self.right_node = self.new_child_node(left, right, indices, right_set) - self.center_left_values, self.center_left_indices = \ - sort_values_and_indices(left, indices, center_set) - self.center_right_values, self.center_right_indices = \ - sort_values_and_indices(right, indices, center_set) + self.center_left_values, self.center_left_indices = \ + sort_values_and_indices(left, indices, center_set) + self.center_right_values, self.center_right_indices = \ + sort_values_and_indices(right, indices, center_set) + self.n_center = len(self.center_left_indices) 
@cython.wraparound(False) @cython.boundscheck(False) @@ -1202,14 +1142,14 @@ cdef class Int64ClosedBothIntervalNode: left, right, or overlap with this node's pivot. """ cdef: - int i Int64Vector left_ind, right_ind, overlapping_ind + Py_ssize_t i left_ind = Int64Vector() right_ind = Int64Vector() overlapping_ind = Int64Vector() - for i in range(len(left)): + for i in range(self.n_elements): if right[i] < self.pivot: left_ind.append(i) elif self.pivot < left[i]: @@ -1227,103 +1167,79 @@ cdef class Int64ClosedBothIntervalNode: ndarray[int64_t, ndim=1] indices, ndarray[int64_t, ndim=1] subset): """Create a new child node. - - This should be a terminal leaf node if the number of indices is smaller - than leaf_size. Otherwise it should be a non-terminal node. """ - left = take(left, subset) right = take(right, subset) indices = take(indices, subset) - - if len(indices) <= self.leaf_size: - return Int64ClosedBothIntervalLeaf( - left, right, indices) - else: - return Int64ClosedBothIntervalNode( - left, right, indices, self.leaf_size) + return Int64ClosedBothIntervalNode( + left, right, indices, self.leaf_size) @cython.wraparound(False) @cython.boundscheck(False) - cpdef query(self, Int64Vector result, scalar64_t point): + @cython.initializedcheck(False) + cdef _query(self, Int64Vector result, scalar64_t point): """Recursively query this node and its sub-nodes for intervals that overlap with the query point. 
""" cdef: int64_t[:] indices int64_t[:] values - int i - - if point < self.pivot: - values = self.center_left_values - indices = self.center_left_indices - for i in range(len(values)): - if not values[i] <= point: - break - result.append(indices[i]) - self.left_node.query(result, point) - elif point > self.pivot: - values = self.center_right_values - indices = self.center_right_indices - for i in range(len(values) - 1, -1, -1): - if not point <= values[i]: - break - result.append(indices[i]) - self.right_node.query(result, point) + Py_ssize_t i + + if self.is_leaf_node: + # Once we get down to a certain size, it doesn't make sense to + # continue the binary tree structure. Instead, we use linear + # search. + for i in range(self.n_elements): + if self.left[i] <= point <= self.right[i]: + result.append(self.indices[i]) else: - result.extend(self.center_left_indices) - - def __repr__(self): - return ('' % - (self.pivot, self.n_elements, self.left_node.n_elements, - self.right_node.n_elements, len(self.center_left_indices))) - - def counts(self): - m = len(self.center_left_values) - l = self.left_node.counts() - r = self.right_node.counts() - return (m, (l, r)) - -NODE_CLASSES['int64', 'both'] = Int64ClosedBothIntervalNode - - -cdef class Int64ClosedBothIntervalLeaf: - """Terminal node for an IntervalTree - - Once we get down to a certain size, it doens't make sense to continue the - binary tree structure. Instead, we store interval bounds in 1d arrays use - linear search. - """ - cdef: - readonly int64_t[:] left, right - readonly int64_t[:] indices - - def __init__(self, - int64_t[:] left, - int64_t[:] right, - int64_t[:] indices): - self.left = left - self.right = right - self.indices = indices + # There are child nodes. Based on comparing our query to the pivot, + # look at the center values, then go to the relevant child. 
+ if point < self.pivot: + values = self.center_left_values + indices = self.center_left_indices + for i in range(self.n_center): + if not values[i] <= point: + break + result.append(indices[i]) + self.left_node._query(result, point) + elif point > self.pivot: + values = self.center_right_values + indices = self.center_right_indices + for i in range(self.n_center - 1, -1, -1): + if not point <= values[i]: + break + result.append(indices[i]) + self.right_node._query(result, point) + else: + result.extend(self.center_left_indices) - @cython.wraparound(False) - @cython.boundscheck(False) cpdef query(self, Int64Vector result, scalar64_t point): - for i in range(len(self.left)): - if self.left[i] <= point <= self.right[i]: - result.append(self.indices[i]) + return self._query(result, point) def __repr__(self): - return ('' - % self.n_elements) - - @property - def n_elements(self): - return len(self.left) + if self.is_leaf_node: + return ('' % self.n_elements) + else: + n_left = self.left_node.n_elements + n_right = self.right_node.n_elements + n_center = self.n_elements - n_left - n_right + return ('' % + (self.pivot, self.n_elements, n_left, n_right, n_center)) def counts(self): - return self.n_elements + if self.is_leaf_node: + return self.n_elements + else: + m = len(self.center_left_values) + l = self.left_node.counts() + r = self.right_node.counts() + return (m, (l, r)) + +NODE_CLASSES['int64', 'both'] = Int64ClosedBothIntervalNode cdef class Int64ClosedNeitherIntervalNode: @@ -1333,11 +1249,12 @@ cdef class Int64ClosedNeitherIntervalNode: the right, and those that overlap with the pivot. 
""" cdef: - readonly left_node, right_node - readonly int64_t[:] center_left_values, center_right_values - readonly int64_t[:] center_left_indices, center_right_indices + Int64ClosedNeitherIntervalNode left_node, right_node + int64_t[:] center_left_values, center_right_values, left, right + int64_t[:] center_left_indices, center_right_indices, indices readonly int64_t pivot - readonly int64_t n_elements, leaf_size + readonly int64_t n_elements, n_center, leaf_size + readonly bint is_leaf_node def __init__(self, ndarray[int64_t, ndim=1] left, @@ -1345,19 +1262,30 @@ cdef class Int64ClosedNeitherIntervalNode: ndarray[int64_t, ndim=1] indices, int64_t leaf_size): - self.pivot = np.median(left + right) / 2 self.n_elements = len(left) self.leaf_size = leaf_size - left_set, right_set, center_set = self.classify_intervals(left, right) + if self.n_elements <= leaf_size: + # make this a terminal (leaf) node + self.is_leaf_node = True + self.left = left + self.right = right + self.indices = indices + self.n_center + else: + # calculate a pivot so we can create child nodes + self.is_leaf_node = False + self.pivot = np.median(left + right) / 2 + left_set, right_set, center_set = self.classify_intervals(left, right) - self.left_node = self.new_child_node(left, right, indices, left_set) - self.right_node = self.new_child_node(left, right, indices, right_set) + self.left_node = self.new_child_node(left, right, indices, left_set) + self.right_node = self.new_child_node(left, right, indices, right_set) - self.center_left_values, self.center_left_indices = \ - sort_values_and_indices(left, indices, center_set) - self.center_right_values, self.center_right_indices = \ - sort_values_and_indices(right, indices, center_set) + self.center_left_values, self.center_left_indices = \ + sort_values_and_indices(left, indices, center_set) + self.center_right_values, self.center_right_indices = \ + sort_values_and_indices(right, indices, center_set) + self.n_center = 
len(self.center_left_indices) @cython.wraparound(False) @cython.boundscheck(False) @@ -1366,14 +1294,14 @@ cdef class Int64ClosedNeitherIntervalNode: left, right, or overlap with this node's pivot. """ cdef: - int i Int64Vector left_ind, right_ind, overlapping_ind + Py_ssize_t i left_ind = Int64Vector() right_ind = Int64Vector() overlapping_ind = Int64Vector() - for i in range(len(left)): + for i in range(self.n_elements): if right[i] <= self.pivot: left_ind.append(i) elif self.pivot <= left[i]: @@ -1391,102 +1319,78 @@ cdef class Int64ClosedNeitherIntervalNode: ndarray[int64_t, ndim=1] indices, ndarray[int64_t, ndim=1] subset): """Create a new child node. - - This should be a terminal leaf node if the number of indices is smaller - than leaf_size. Otherwise it should be a non-terminal node. """ - left = take(left, subset) right = take(right, subset) indices = take(indices, subset) - - if len(indices) <= self.leaf_size: - return Int64ClosedNeitherIntervalLeaf( - left, right, indices) - else: - return Int64ClosedNeitherIntervalNode( - left, right, indices, self.leaf_size) + return Int64ClosedNeitherIntervalNode( + left, right, indices, self.leaf_size) @cython.wraparound(False) @cython.boundscheck(False) - cpdef query(self, Int64Vector result, scalar64_t point): + @cython.initializedcheck(False) + cdef _query(self, Int64Vector result, scalar64_t point): """Recursively query this node and its sub-nodes for intervals that overlap with the query point. 
""" cdef: int64_t[:] indices int64_t[:] values - int i - - if point < self.pivot: - values = self.center_left_values - indices = self.center_left_indices - for i in range(len(values)): - if not values[i] < point: - break - result.append(indices[i]) - self.left_node.query(result, point) - elif point > self.pivot: - values = self.center_right_values - indices = self.center_right_indices - for i in range(len(values) - 1, -1, -1): - if not point < values[i]: - break - result.append(indices[i]) - self.right_node.query(result, point) + Py_ssize_t i + + if self.is_leaf_node: + # Once we get down to a certain size, it doesn't make sense to + # continue the binary tree structure. Instead, we use linear + # search. + for i in range(self.n_elements): + if self.left[i] < point < self.right[i]: + result.append(self.indices[i]) else: - result.extend(self.center_left_indices) - - def __repr__(self): - return ('' % - (self.pivot, self.n_elements, self.left_node.n_elements, - self.right_node.n_elements, len(self.center_left_indices))) - - def counts(self): - m = len(self.center_left_values) - l = self.left_node.counts() - r = self.right_node.counts() - return (m, (l, r)) - -NODE_CLASSES['int64', 'neither'] = Int64ClosedNeitherIntervalNode - - -cdef class Int64ClosedNeitherIntervalLeaf: - """Terminal node for an IntervalTree - - Once we get down to a certain size, it doens't make sense to continue the - binary tree structure. Instead, we store interval bounds in 1d arrays use - linear search. - """ - cdef: - readonly int64_t[:] left, right - readonly int64_t[:] indices - - def __init__(self, - int64_t[:] left, - int64_t[:] right, - int64_t[:] indices): - self.left = left - self.right = right - self.indices = indices + # There are child nodes. Based on comparing our query to the pivot, + # look at the center values, then go to the relevant child. 
+ if point < self.pivot: + values = self.center_left_values + indices = self.center_left_indices + for i in range(self.n_center): + if not values[i] < point: + break + result.append(indices[i]) + self.left_node._query(result, point) + elif point > self.pivot: + values = self.center_right_values + indices = self.center_right_indices + for i in range(self.n_center - 1, -1, -1): + if not point < values[i]: + break + result.append(indices[i]) + self.right_node._query(result, point) + else: + result.extend(self.center_left_indices) - @cython.wraparound(False) - @cython.boundscheck(False) cpdef query(self, Int64Vector result, scalar64_t point): - for i in range(len(self.left)): - if self.left[i] < point < self.right[i]: - result.append(self.indices[i]) + return self._query(result, point) def __repr__(self): - return ('' - % self.n_elements) - - @property - def n_elements(self): - return len(self.left) + if self.is_leaf_node: + return ('' % self.n_elements) + else: + n_left = self.left_node.n_elements + n_right = self.right_node.n_elements + n_center = self.n_elements - n_left - n_right + return ('' % + (self.pivot, self.n_elements, n_left, n_right, n_center)) def counts(self): - return self.n_elements + if self.is_leaf_node: + return self.n_elements + else: + m = len(self.center_left_values) + l = self.left_node.counts() + r = self.right_node.counts() + return (m, (l, r)) + +NODE_CLASSES['int64', 'neither'] = Int64ClosedNeitherIntervalNode From e0211f99171338ef524df03b834440aeb42ab013 Mon Sep 17 00:00:00 2001 From: Stephan Hoyer Date: Tue, 15 Dec 2015 14:32:22 -0800 Subject: [PATCH 4/7] some fixes --- pandas/core/groupby.py | 22 ++++++++++++++++------ pandas/lib.pyx | 4 +++- pandas/src/inference.pyx | 10 +++++++--- pandas/tests/test_groupby.py | 7 ++++--- pandas/tools/tests/test_tile.py | 10 +++++++++- 5 files changed, 39 insertions(+), 14 deletions(-) diff --git a/pandas/core/groupby.py b/pandas/core/groupby.py index 584b946d47618..0aa490275c0c3 100644 --- 
a/pandas/core/groupby.py +++ b/pandas/core/groupby.py @@ -17,6 +17,7 @@ from pandas.core.frame import DataFrame from pandas.core.generic import NDFrame from pandas.core.index import Index, MultiIndex, CategoricalIndex, _ensure_index +from pandas.core.interval import IntervalIndex from pandas.core.internals import BlockManager, make_block from pandas.core.series import Series from pandas.core.panel import Panel @@ -2735,12 +2736,20 @@ def value_counts(self, normalize=False, sort=True, ascending=False, if bins is None: lab, lev = algos.factorize(val, sort=True) else: - cat, bins = cut(val, bins, retbins=True) + raise NotImplementedError('this is broken') + lab, bins = cut(val, bins, retbins=True) # bins[:-1] for backward compat; # o.w. cat.categories could be better - lab, lev, dropna = cat.codes, bins[:-1], False - - sorter = np.lexsort((lab, ids)) + # cat = Categorical(cat) + # lab, lev, dropna = cat.codes, bins[:-1], False + + if (lab.dtype == object + and lib.is_interval_array_fixed_closed(lab[notnull(lab)])): + lab_index = Index(lab) + assert isinstance(lab, IntervalIndex) + sorter = np.lexsort((lab_index.left, lab_index.right, ids)) + else: + sorter = np.lexsort((lab, ids)) ids, lab = ids[sorter], lab[sorter] # group boundaries are where group ids change @@ -2771,12 +2780,13 @@ def value_counts(self, normalize=False, sort=True, ascending=False, acc = rep(np.diff(np.r_[idx, len(ids)])) out /= acc[mask] if dropna else acc - if sort and bins is None: + if sort: # and bins is None: cat = ids[inc][mask] if dropna else ids[inc] sorter = np.lexsort((out if ascending else -out, cat)) out, labels[-1] = out[sorter], labels[-1][sorter] - if bins is None: + # if bins is None: + if True: mi = MultiIndex(levels=levels, labels=labels, names=names, verify_integrity=False) diff --git a/pandas/lib.pyx b/pandas/lib.pyx index e0390eeb4e1f7..34eb954615803 100644 --- a/pandas/lib.pyx +++ b/pandas/lib.pyx @@ -306,6 +306,7 @@ def isscalar(object val): - instances of datetime.datetime 
- instances of datetime.timedelta - Period + - Interval """ @@ -316,7 +317,8 @@ def isscalar(object val): or PyDate_Check(val) or PyDelta_Check(val) or PyTime_Check(val) - or util.is_period_object(val)) + or util.is_period_object(val) + or is_interval(val)) def item_from_zerodim(object val): diff --git a/pandas/src/inference.pyx b/pandas/src/inference.pyx index 4b3716cbcb5ae..e0cba8d73538c 100644 --- a/pandas/src/inference.pyx +++ b/pandas/src/inference.pyx @@ -195,7 +195,7 @@ def infer_dtype(object _values): return 'period' elif is_interval(val): - if is_interval_array(values): + if is_interval_array_fixed_closed(values): return 'interval' for i in range(n): @@ -546,14 +546,18 @@ def is_period_array(ndarray[object] values): cdef inline bint is_interval(object o): return isinstance(o, Interval) -def is_interval_array(ndarray[object] values): +def is_interval_array_fixed_closed(ndarray[object] values): cdef Py_ssize_t i, n = len(values) - + cdef str closed if n == 0: return False for i in range(n): if not is_interval(values[i]): return False + if i == 0: + closed = values[0].closed + elif closed != values[i].closed: + return False return True diff --git a/pandas/tests/test_groupby.py b/pandas/tests/test_groupby.py index bd21053f37568..c1960f915e981 100644 --- a/pandas/tests/test_groupby.py +++ b/pandas/tests/test_groupby.py @@ -9,6 +9,7 @@ from pandas import date_range,bdate_range, Timestamp from pandas.core.index import Index, MultiIndex, Int64Index, CategoricalIndex +from pandas.core.interval import IntervalIndex from pandas.core.api import Categorical, DataFrame from pandas.core.groupby import (SpecificationError, DataError, _nargsort, _lexsort_indexer) @@ -4036,7 +4037,7 @@ def test_groupby_categorical_unequal_len(self): #GH3011 series = Series([np.nan, np.nan, 1, 1, 2, 2, 3, 3, 4, 4]) # The raises only happens with categorical, not with series of types category - bins = pd.cut(series.dropna().values, 4) + bins = pd.cut(series.dropna().values, 4, 
labels=pd.Categorical(list('abcd'))) # len(bins) != len(series) here self.assertRaises(ValueError,lambda : series.groupby(bins).mean()) @@ -5677,13 +5678,13 @@ def test_groupby_categorical_two_columns(self): d = {'C1': [3, 3, 4, 5], 'C2': [1, 2, 3, 4], 'C3': [10, 100, 200, 34]} test = pd.DataFrame(d) - values = pd.cut(test['C1'], [1, 2, 3, 6]) + values = pd.cut(test['C1'], [1, 2, 3, 6], labels=pd.Categorical(['a', 'b', 'c'])) values.name = "cat" groups_double_key = test.groupby([values,'C2']) res = groups_double_key.agg('mean') nan = np.nan - idx = MultiIndex.from_product([["(1, 2]", "(2, 3]", "(3, 6]"],[1,2,3,4]], + idx = MultiIndex.from_product([['a', 'b', 'c'], [1, 2, 3, 4]], names=["cat", "C2"]) exp = DataFrame({"C1":[nan,nan,nan,nan, 3, 3,nan,nan, nan,nan, 4, 5], "C3":[nan,nan,nan,nan, 10,100,nan,nan, nan,nan,200,34]}, index=idx) diff --git a/pandas/tools/tests/test_tile.py b/pandas/tools/tests/test_tile.py index 7f77aa710df2e..68eab2df0a516 100644 --- a/pandas/tools/tests/test_tile.py +++ b/pandas/tools/tests/test_tile.py @@ -4,7 +4,7 @@ import numpy as np from pandas.compat import zip -from pandas import DataFrame, Series, unique, isnull +from pandas import DataFrame, Series, Index, unique, isnull import pandas.util.testing as tm from pandas.util.testing import assertRaisesRegexp import pandas.core.common as com @@ -194,6 +194,14 @@ def test_qcut_nas(self): result = qcut(arr, 4) self.assertTrue(com.isnull(result[:20]).all()) + def test_qcut_index(self): + # the result is closed on a different side for the first interval, but + # we should still be able to make an index + result = qcut([0, 2], 2) + index = Index(result) + expected = Index([Interval(0, 1, closed='both'), Interval(1, 2)]) + self.assert_numpy_array_equal(index, expected) + def test_round_frac(self): # it works result = cut(np.arange(11.), 2) From 25fe51d5b72bed4ecd9a53941fad94a811e6cb42 Mon Sep 17 00:00:00 2001 From: Stephan Hoyer Date: Thu, 17 Dec 2015 10:34:35 -0800 Subject: [PATCH 5/7] Faster 
IntervalTreeNode.query --- pandas/src/generate_intervaltree.py | 15 ++- pandas/src/intervaltree.pyx | 144 ++++++++++++++++++---------- pandas/tests/test_interval.py | 16 ++-- 3 files changed, 117 insertions(+), 58 deletions(-) diff --git a/pandas/src/generate_intervaltree.py b/pandas/src/generate_intervaltree.py index f3479ecf603f8..275a0d40e2433 100644 --- a/pandas/src/generate_intervaltree.py +++ b/pandas/src/generate_intervaltree.py @@ -207,6 +207,7 @@ def __repr__(self): {dtype_title}Closed{closed_title}IntervalNode left_node, right_node {dtype}_t[:] center_left_values, center_right_values, left, right int64_t[:] center_left_indices, center_right_indices, indices + {dtype}_t min_left, max_right readonly {dtype}_t pivot readonly int64_t n_elements, n_center, leaf_size readonly bint is_leaf_node @@ -219,6 +220,12 @@ def __init__(self, self.n_elements = len(left) self.leaf_size = leaf_size + if left.size > 0: + self.min_left = left.min() + self.max_right = right.max() + else: + self.min_left = 0 + self.max_right = 0 if self.n_elements <= leaf_size: # make this a terminal (leaf) node @@ -284,7 +291,7 @@ def __init__(self, @cython.wraparound(False) @cython.boundscheck(False) @cython.initializedcheck(False) - cdef query(self, Int64Vector result, scalar64_t point): + cpdef query(self, Int64Vector result, scalar64_t point): """Recursively query this node and its sub-nodes for intervals that overlap with the query point. 
""" @@ -310,7 +317,8 @@ def __init__(self, if not values[i] {cmp_left} point: break result.append(indices[i]) - self.left_node.query(result, point) + if point {cmp_right} self.left_node.max_right: + self.left_node.query(result, point) elif point > self.pivot: values = self.center_right_values indices = self.center_right_indices @@ -318,7 +326,8 @@ def __init__(self, if not point {cmp_right} values[i]: break result.append(indices[i]) - self.right_node.query(result, point) + if self.right_node.min_left {cmp_left} point: + self.right_node.query(result, point) else: result.extend(self.center_left_indices) diff --git a/pandas/src/intervaltree.pyx b/pandas/src/intervaltree.pyx index 6f86b6a190452..55782c930d4f8 100644 --- a/pandas/src/intervaltree.pyx +++ b/pandas/src/intervaltree.pyx @@ -188,6 +188,7 @@ cdef class Float64ClosedLeftIntervalNode: Float64ClosedLeftIntervalNode left_node, right_node float64_t[:] center_left_values, center_right_values, left, right int64_t[:] center_left_indices, center_right_indices, indices + float64_t min_left, max_right readonly float64_t pivot readonly int64_t n_elements, n_center, leaf_size readonly bint is_leaf_node @@ -200,6 +201,12 @@ cdef class Float64ClosedLeftIntervalNode: self.n_elements = len(left) self.leaf_size = leaf_size + if left.size > 0: + self.min_left = left.min() + self.max_right = right.max() + else: + self.min_left = 0 + self.max_right = 0 if self.n_elements <= leaf_size: # make this a terminal (leaf) node @@ -265,7 +272,7 @@ cdef class Float64ClosedLeftIntervalNode: @cython.wraparound(False) @cython.boundscheck(False) @cython.initializedcheck(False) - cdef _query(self, Int64Vector result, scalar64_t point): + cpdef query(self, Int64Vector result, scalar64_t point): """Recursively query this node and its sub-nodes for intervals that overlap with the query point. 
""" @@ -291,7 +298,8 @@ cdef class Float64ClosedLeftIntervalNode: if not values[i] <= point: break result.append(indices[i]) - self.left_node._query(result, point) + if point < self.left_node.max_right: + self.left_node.query(result, point) elif point > self.pivot: values = self.center_right_values indices = self.center_right_indices @@ -299,13 +307,11 @@ cdef class Float64ClosedLeftIntervalNode: if not point < values[i]: break result.append(indices[i]) - self.right_node._query(result, point) + if self.right_node.min_left <= point: + self.right_node.query(result, point) else: result.extend(self.center_left_indices) - cpdef query(self, Int64Vector result, scalar64_t point): - return self._query(result, point) - def __repr__(self): if self.is_leaf_node: return (' 0: + self.min_left = left.min() + self.max_right = right.max() + else: + self.min_left = 0 + self.max_right = 0 if self.n_elements <= leaf_size: # make this a terminal (leaf) node @@ -417,7 +430,7 @@ cdef class Float64ClosedRightIntervalNode: @cython.wraparound(False) @cython.boundscheck(False) @cython.initializedcheck(False) - cdef _query(self, Int64Vector result, scalar64_t point): + cpdef query(self, Int64Vector result, scalar64_t point): """Recursively query this node and its sub-nodes for intervals that overlap with the query point. 
""" @@ -443,7 +456,8 @@ cdef class Float64ClosedRightIntervalNode: if not values[i] < point: break result.append(indices[i]) - self.left_node._query(result, point) + if point <= self.left_node.max_right: + self.left_node.query(result, point) elif point > self.pivot: values = self.center_right_values indices = self.center_right_indices @@ -451,13 +465,11 @@ cdef class Float64ClosedRightIntervalNode: if not point <= values[i]: break result.append(indices[i]) - self.right_node._query(result, point) + if self.right_node.min_left < point: + self.right_node.query(result, point) else: result.extend(self.center_left_indices) - cpdef query(self, Int64Vector result, scalar64_t point): - return self._query(result, point) - def __repr__(self): if self.is_leaf_node: return (' 0: + self.min_left = left.min() + self.max_right = right.max() + else: + self.min_left = 0 + self.max_right = 0 if self.n_elements <= leaf_size: # make this a terminal (leaf) node @@ -569,7 +588,7 @@ cdef class Float64ClosedBothIntervalNode: @cython.wraparound(False) @cython.boundscheck(False) @cython.initializedcheck(False) - cdef _query(self, Int64Vector result, scalar64_t point): + cpdef query(self, Int64Vector result, scalar64_t point): """Recursively query this node and its sub-nodes for intervals that overlap with the query point. 
""" @@ -595,7 +614,8 @@ cdef class Float64ClosedBothIntervalNode: if not values[i] <= point: break result.append(indices[i]) - self.left_node._query(result, point) + if point <= self.left_node.max_right: + self.left_node.query(result, point) elif point > self.pivot: values = self.center_right_values indices = self.center_right_indices @@ -603,13 +623,11 @@ cdef class Float64ClosedBothIntervalNode: if not point <= values[i]: break result.append(indices[i]) - self.right_node._query(result, point) + if self.right_node.min_left <= point: + self.right_node.query(result, point) else: result.extend(self.center_left_indices) - cpdef query(self, Int64Vector result, scalar64_t point): - return self._query(result, point) - def __repr__(self): if self.is_leaf_node: return (' 0: + self.min_left = left.min() + self.max_right = right.max() + else: + self.min_left = 0 + self.max_right = 0 if self.n_elements <= leaf_size: # make this a terminal (leaf) node @@ -721,7 +746,7 @@ cdef class Float64ClosedNeitherIntervalNode: @cython.wraparound(False) @cython.boundscheck(False) @cython.initializedcheck(False) - cdef _query(self, Int64Vector result, scalar64_t point): + cpdef query(self, Int64Vector result, scalar64_t point): """Recursively query this node and its sub-nodes for intervals that overlap with the query point. 
""" @@ -747,7 +772,8 @@ cdef class Float64ClosedNeitherIntervalNode: if not values[i] < point: break result.append(indices[i]) - self.left_node._query(result, point) + if point < self.left_node.max_right: + self.left_node.query(result, point) elif point > self.pivot: values = self.center_right_values indices = self.center_right_indices @@ -755,13 +781,11 @@ cdef class Float64ClosedNeitherIntervalNode: if not point < values[i]: break result.append(indices[i]) - self.right_node._query(result, point) + if self.right_node.min_left < point: + self.right_node.query(result, point) else: result.extend(self.center_left_indices) - cpdef query(self, Int64Vector result, scalar64_t point): - return self._query(result, point) - def __repr__(self): if self.is_leaf_node: return (' 0: + self.min_left = left.min() + self.max_right = right.max() + else: + self.min_left = 0 + self.max_right = 0 if self.n_elements <= leaf_size: # make this a terminal (leaf) node @@ -873,7 +904,7 @@ cdef class Int64ClosedLeftIntervalNode: @cython.wraparound(False) @cython.boundscheck(False) @cython.initializedcheck(False) - cdef _query(self, Int64Vector result, scalar64_t point): + cpdef query(self, Int64Vector result, scalar64_t point): """Recursively query this node and its sub-nodes for intervals that overlap with the query point. 
""" @@ -899,7 +930,8 @@ cdef class Int64ClosedLeftIntervalNode: if not values[i] <= point: break result.append(indices[i]) - self.left_node._query(result, point) + if point < self.left_node.max_right: + self.left_node.query(result, point) elif point > self.pivot: values = self.center_right_values indices = self.center_right_indices @@ -907,13 +939,11 @@ cdef class Int64ClosedLeftIntervalNode: if not point < values[i]: break result.append(indices[i]) - self.right_node._query(result, point) + if self.right_node.min_left <= point: + self.right_node.query(result, point) else: result.extend(self.center_left_indices) - cpdef query(self, Int64Vector result, scalar64_t point): - return self._query(result, point) - def __repr__(self): if self.is_leaf_node: return (' 0: + self.min_left = left.min() + self.max_right = right.max() + else: + self.min_left = 0 + self.max_right = 0 if self.n_elements <= leaf_size: # make this a terminal (leaf) node @@ -1025,7 +1062,7 @@ cdef class Int64ClosedRightIntervalNode: @cython.wraparound(False) @cython.boundscheck(False) @cython.initializedcheck(False) - cdef _query(self, Int64Vector result, scalar64_t point): + cpdef query(self, Int64Vector result, scalar64_t point): """Recursively query this node and its sub-nodes for intervals that overlap with the query point. 
""" @@ -1051,7 +1088,8 @@ cdef class Int64ClosedRightIntervalNode: if not values[i] < point: break result.append(indices[i]) - self.left_node._query(result, point) + if point <= self.left_node.max_right: + self.left_node.query(result, point) elif point > self.pivot: values = self.center_right_values indices = self.center_right_indices @@ -1059,13 +1097,11 @@ cdef class Int64ClosedRightIntervalNode: if not point <= values[i]: break result.append(indices[i]) - self.right_node._query(result, point) + if self.right_node.min_left < point: + self.right_node.query(result, point) else: result.extend(self.center_left_indices) - cpdef query(self, Int64Vector result, scalar64_t point): - return self._query(result, point) - def __repr__(self): if self.is_leaf_node: return (' 0: + self.min_left = left.min() + self.max_right = right.max() + else: + self.min_left = 0 + self.max_right = 0 if self.n_elements <= leaf_size: # make this a terminal (leaf) node @@ -1177,7 +1220,7 @@ cdef class Int64ClosedBothIntervalNode: @cython.wraparound(False) @cython.boundscheck(False) @cython.initializedcheck(False) - cdef _query(self, Int64Vector result, scalar64_t point): + cpdef query(self, Int64Vector result, scalar64_t point): """Recursively query this node and its sub-nodes for intervals that overlap with the query point. 
""" @@ -1203,7 +1246,8 @@ cdef class Int64ClosedBothIntervalNode: if not values[i] <= point: break result.append(indices[i]) - self.left_node._query(result, point) + if point <= self.left_node.max_right: + self.left_node.query(result, point) elif point > self.pivot: values = self.center_right_values indices = self.center_right_indices @@ -1211,13 +1255,11 @@ cdef class Int64ClosedBothIntervalNode: if not point <= values[i]: break result.append(indices[i]) - self.right_node._query(result, point) + if self.right_node.min_left <= point: + self.right_node.query(result, point) else: result.extend(self.center_left_indices) - cpdef query(self, Int64Vector result, scalar64_t point): - return self._query(result, point) - def __repr__(self): if self.is_leaf_node: return (' 0: + self.min_left = left.min() + self.max_right = right.max() + else: + self.min_left = 0 + self.max_right = 0 if self.n_elements <= leaf_size: # make this a terminal (leaf) node @@ -1329,7 +1378,7 @@ cdef class Int64ClosedNeitherIntervalNode: @cython.wraparound(False) @cython.boundscheck(False) @cython.initializedcheck(False) - cdef _query(self, Int64Vector result, scalar64_t point): + cpdef query(self, Int64Vector result, scalar64_t point): """Recursively query this node and its sub-nodes for intervals that overlap with the query point. 
""" @@ -1355,7 +1404,8 @@ cdef class Int64ClosedNeitherIntervalNode: if not values[i] < point: break result.append(indices[i]) - self.left_node._query(result, point) + if point < self.left_node.max_right: + self.left_node.query(result, point) elif point > self.pivot: values = self.center_right_values indices = self.center_right_indices @@ -1363,13 +1413,11 @@ cdef class Int64ClosedNeitherIntervalNode: if not point < values[i]: break result.append(indices[i]) - self.right_node._query(result, point) + if self.right_node.min_left < point: + self.right_node.query(result, point) else: result.extend(self.center_left_indices) - cpdef query(self, Int64Vector result, scalar64_t point): - return self._query(result, point) - def __repr__(self): if self.is_leaf_node: return (' Date: Tue, 5 Jan 2016 10:59:22 -0800 Subject: [PATCH 6/7] ENH: Adding arthimetic operations --- pandas/src/interval.pyx | 59 +++++++++++++++++++++++++++++++++-- pandas/tests/test_interval.py | 15 ++++++--- 2 files changed, 67 insertions(+), 7 deletions(-) diff --git a/pandas/src/interval.pyx b/pandas/src/interval.pyx index 5e51e34a9b5d4..65e0fd65679b5 100644 --- a/pandas/src/interval.pyx +++ b/pandas/src/interval.pyx @@ -8,10 +8,9 @@ import cython from cpython.object cimport (Py_EQ, Py_NE, Py_GT, Py_LT, Py_GE, Py_LE, PyObject_RichCompare) - +import numbers _VALID_CLOSED = frozenset(['left', 'right', 'both', 'neither']) - cdef class IntervalMixin: property closed_left: def __get__(self): @@ -101,6 +100,62 @@ cdef class Interval(IntervalMixin): end_symbol = ']' if self.closed_right else ')' return '%s%s, %s%s' % (start_symbol, self.left, self.right, end_symbol) + def __add__(self, y): + if isinstance(y, numbers.Number): + return Interval(self.left + y, self.right + y) + elif isinstance(y, Interval) and isinstance(self, numbers.Number): + return Interval(y.left + self, y.right + self) + else: + raise TypeError("unsupported operand type(s) for +: '%s' and '%s'" % + (type(self).__name__, type(y).__name__)) + 
+ def __sub__(self, y): + if isinstance(y, numbers.Number): + return Interval(self.left - y, self.right - y) + elif isinstance(y, Interval) and isinstance(self, numbers.Number): + return Interval(y.left - self, y.right - self) + else: + raise TypeError("unsupported operand type(s) for -: '%s' and '%s'" % + (type(self).__name__, type(y).__name__)) + + def __mult__(self, y): + if isinstance(y, numbers.Number): + return Interval(self.left * y, self.right * y) + elif isinstance(y, Interval) and isinstance(self, numbers.Number): + return Interval(y.left * self, y.right * self) + else: + raise TypeError("unsupported operand type(s) for *: '%s' and '%s'" % + (type(self).__name__, type(y).__name__)) + + def __div__(self, y): + if isinstance(y, numbers.Number): + return Interval(self.left / y, self.right / y) + elif isinstance(y, Interval) and isinstance(self, numbers.Number): + return Interval(y.left / self, y.right / self) + else: + raise TypeError("unsupported operand type(s) for /: '%s' and '%s'" % + (type(self).__name__, type(y).__name__)) + + def overlap(self, y): + if not isinstance(y, Interval): + raise TypeError("unsupported operand type(s) for &: '%s' and '%s'" % + (type(self).__name__, type(y).__name__)) + return self.left <= y.right and y.left <= self.right + + def intersect(self, y): + if not isinstance(y, Interval): + raise TypeError("unsupported operand type(s) for &: '%s' and '%s'" % + (type(self).__name__, type(y).__name__)) + if not self.overlap(y): + # ideally, I would like to return an empty interval, and + # not raise a ValueError + raise ValueError("%s does not overlap with %s" % (self, y)) + return Interval(max(self.left, y.left), + min(self.right, y.right)) + + def __and__(self, y): + return self.intersect(y) + @cython.wraparound(False) @cython.boundscheck(False) diff --git a/pandas/tests/test_interval.py b/pandas/tests/test_interval.py index 56012acebaa53..c7b258b56f1b9 100644 --- a/pandas/tests/test_interval.py +++ b/pandas/tests/test_interval.py 
@@ -63,10 +63,15 @@ def test_hash(self): # should not raise hash(self.interval) - # def test_math(self): - # expected = Interval(1, 2) - # actual = self.interval + 1 - # self.assertEqual(expected, actual) + def test_math_add(self): + expected = Interval(1, 2) + actual = self.interval + 1 + self.assertEqual(expected, actual) + + def test_math_mult(self): + expected = Interval(0, 2) + actual = self.interval * 2 + self.assertEqual(expected, actual) class TestIntervalTree(tm.TestCase): @@ -515,7 +520,7 @@ def test_datetime(self): self.assert_numpy_array_equal(actual, expected) # def test_math(self): - # # add, subtract, multiply, divide with scalers should be OK + # # add, subtract, multiply, divide with scalars should be OK # actual = 2 * self.index + 1 # expected = IntervalIndex.from_breaks((2 * np.arange(3) + 1)) # self.assertTrue(expected.equals(actual)) From 09847390bc50ae8e452ae9429e3d0b2c6cbe81be Mon Sep 17 00:00:00 2001 From: Jamie Morton Date: Wed, 6 Jan 2016 10:38:21 -0800 Subject: [PATCH 7/7] ENH: Basic arthimetic operations added for interval class STY: Adding @shoyer's comments BUG: Fixing multiplication and division ENH: Adding in PY2 compatible division TST: Adding PY2 specific test. May want to configure in travis ... 
TST: Modifying division test to handle both PY2 and PY3 --- pandas/src/interval.pyx | 51 +++++++++++++---------------------- pandas/tests/test_interval.py | 47 ++++++++++++++++++++++++++++++++ 2 files changed, 65 insertions(+), 33 deletions(-) diff --git a/pandas/src/interval.pyx b/pandas/src/interval.pyx index 65e0fd65679b5..495730e0fd6a1 100644 --- a/pandas/src/interval.pyx +++ b/pandas/src/interval.pyx @@ -106,55 +106,39 @@ cdef class Interval(IntervalMixin): elif isinstance(y, Interval) and isinstance(self, numbers.Number): return Interval(y.left + self, y.right + self) else: - raise TypeError("unsupported operand type(s) for +: '%s' and '%s'" % - (type(self).__name__, type(y).__name__)) + raise NotImplemented def __sub__(self, y): if isinstance(y, numbers.Number): return Interval(self.left - y, self.right - y) - elif isinstance(y, Interval) and isinstance(self, numbers.Number): - return Interval(y.left - self, y.right - self) else: - raise TypeError("unsupported operand type(s) for -: '%s' and '%s'" % - (type(self).__name__, type(y).__name__)) + raise NotImplemented - def __mult__(self, y): + def __mul__(self, y): if isinstance(y, numbers.Number): return Interval(self.left * y, self.right * y) elif isinstance(y, Interval) and isinstance(self, numbers.Number): return Interval(y.left * self, y.right * self) else: - raise TypeError("unsupported operand type(s) for *: '%s' and '%s'" % - (type(self).__name__, type(y).__name__)) + return NotImplemented def __div__(self, y): if isinstance(y, numbers.Number): return Interval(self.left / y, self.right / y) - elif isinstance(y, Interval) and isinstance(self, numbers.Number): - return Interval(y.left / self, y.right / self) else: - raise TypeError("unsupported operand type(s) for /: '%s' and '%s'" % - (type(self).__name__, type(y).__name__)) - - def overlap(self, y): - if not isinstance(y, Interval): - raise TypeError("unsupported operand type(s) for &: '%s' and '%s'" % - (type(self).__name__, type(y).__name__)) - return 
self.left <= y.right and y.left <= self.right - - def intersect(self, y): - if not isinstance(y, Interval): - raise TypeError("unsupported operand type(s) for &: '%s' and '%s'" % - (type(self).__name__, type(y).__name__)) - if not self.overlap(y): - # ideally, I would like to return an empty interval, and - # not raise a ValueError - raise ValueError("%s does not overlap with %s" % (self, y)) - return Interval(max(self.left, y.left), - min(self.right, y.right)) - - def __and__(self, y): - return self.intersect(y) + return NotImplemented + + def __truediv__(self, y): + if isinstance(y, numbers.Number): + return Interval(self.left / y, self.right / y) + else: + return NotImplemented + + def __floordiv__(self, y): + if isinstance(y, numbers.Number): + return Interval(self.left // y, self.right // y) + else: + return NotImplemented @cython.wraparound(False) @@ -184,3 +168,4 @@ cpdef intervals_to_interval_bounds(np.ndarray intervals): elif closed != interval.closed: raise ValueError('intervals must all be closed on the same side') return left, right, closed + diff --git a/pandas/tests/test_interval.py b/pandas/tests/test_interval.py index c7b258b56f1b9..1b52e2629b38c 100644 --- a/pandas/tests/test_interval.py +++ b/pandas/tests/test_interval.py @@ -1,3 +1,4 @@ +from __future__ import division import numpy as np from pandas.core.interval import Interval, IntervalIndex @@ -68,11 +69,57 @@ def test_math_add(self): actual = self.interval + 1 self.assertEqual(expected, actual) + expected = Interval(1, 2) + actual = 1 + self.interval + self.assertEqual(expected, actual) + + actual = self.interval + actual += 1 + self.assertEqual(expected, actual) + + with self.assertRaises(TypeError): + self.interval + Interval(1, 2) + + def test_math_sub(self): + expected = Interval(-1, 0) + actual = self.interval - 1 + self.assertEqual(expected, actual) + + actual = self.interval + actual -= 1 + self.assertEqual(expected, actual) + + with self.assertRaises(TypeError): + self.interval - 
Interval(1, 2) + def test_math_mult(self): expected = Interval(0, 2) actual = self.interval * 2 self.assertEqual(expected, actual) + expected = Interval(0, 2) + actual = 2 * self.interval + self.assertEqual(expected, actual) + + actual = self.interval + actual *= 2 + self.assertEqual(expected, actual) + + with self.assertRaises(TypeError): + self.interval * Interval(1, 2) + + def test_math_div(self): + expected = Interval(0, 0.5) + actual = self.interval / 2.0 + self.assertEqual(expected, actual) + + actual = self.interval + actual /= 2.0 + self.assertEqual(expected, actual) + + with self.assertRaises(TypeError): + self.interval / Interval(1, 2) + class TestIntervalTree(tm.TestCase): def setUp(self):