Skip to content

Commit

Permalink
[CLN] More cython cleanups, with bonus type annotations (pandas-dev#2…
Browse files Browse the repository at this point in the history
  • Loading branch information
jbrockmendel authored and victor committed Sep 30, 2018
1 parent a87e834 commit 5b981fe
Show file tree
Hide file tree
Showing 9 changed files with 36 additions and 36 deletions.
8 changes: 4 additions & 4 deletions pandas/_libs/algos_common_helper.pxi.in
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,7 @@ def get_dispatch(dtypes):

@cython.wraparound(False)
@cython.boundscheck(False)
cpdef map_indices_{{name}}(ndarray[{{c_type}}] index):
def map_indices_{{name}}(ndarray[{{c_type}}] index):
"""
Produce a dict mapping the values of the input array to their respective
locations.
Expand Down Expand Up @@ -542,7 +542,7 @@ def put2d_{{name}}_{{dest_type}}(ndarray[{{c_type}}, ndim=2, cast=True] values,
cdef int PLATFORM_INT = (<ndarray> np.arange(0, dtype=np.intp)).descr.type_num


cpdef ensure_platform_int(object arr):
def ensure_platform_int(object arr):
# GH3033, GH1392
# platform int is the size of the int pointer, e.g. np.intp
if util.is_array(arr):
Expand All @@ -554,7 +554,7 @@ cpdef ensure_platform_int(object arr):
return np.array(arr, dtype=np.intp)


cpdef ensure_object(object arr):
def ensure_object(object arr):
if util.is_array(arr):
if (<ndarray> arr).descr.type_num == NPY_OBJECT:
return arr
Expand Down Expand Up @@ -587,7 +587,7 @@ def get_dispatch(dtypes):

{{for name, c_type, dtype in get_dispatch(dtypes)}}

cpdef ensure_{{name}}(object arr, copy=True):
def ensure_{{name}}(object arr, copy=True):
if util.is_array(arr):
if (<ndarray> arr).descr.type_num == NPY_{{c_type}}:
return arr
Expand Down
1 change: 1 addition & 0 deletions pandas/_libs/groupby.pyx
Original file line number Diff line number Diff line change
Expand Up @@ -67,6 +67,7 @@ cdef inline float64_t median_linear(float64_t* a, int n) nogil:
return result


# TODO: Is this redundant with algos.kth_smallest?
cdef inline float64_t kth_smallest_c(float64_t* a,
Py_ssize_t k,
Py_ssize_t n) nogil:
Expand Down
1 change: 1 addition & 0 deletions pandas/_libs/hashing.pyx
Original file line number Diff line number Diff line change
Expand Up @@ -132,6 +132,7 @@ cdef inline void _sipround(uint64_t* v0, uint64_t* v1,
v2[0] = _rotl(v2[0], 32)


# TODO: This appears unused; remove?
cpdef uint64_t siphash(bytes data, bytes key) except? 0:
if len(key) != 16:
raise ValueError("key should be a 16-byte bytestring, "
Expand Down
2 changes: 1 addition & 1 deletion pandas/_libs/index.pyx
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,7 @@ cpdef get_value_at(ndarray arr, object loc, object tz=None):
return util.get_value_at(arr, loc)


cpdef object get_value_box(ndarray arr, object loc):
def get_value_box(arr: ndarray, loc: object) -> object:
return get_value_at(arr, loc, tz=None)


Expand Down
4 changes: 2 additions & 2 deletions pandas/_libs/internals.pyx
Original file line number Diff line number Diff line change
Expand Up @@ -184,7 +184,7 @@ cdef class BlockPlacement:
return self._as_slice


cpdef slice_canonize(slice s):
cdef slice_canonize(slice s):
"""
Convert slice to canonical bounded form.
"""
Expand Down Expand Up @@ -255,7 +255,7 @@ cpdef Py_ssize_t slice_len(
return length


cpdef slice_get_indices_ex(slice slc, Py_ssize_t objlen=PY_SSIZE_T_MAX):
cdef slice_get_indices_ex(slice slc, Py_ssize_t objlen=PY_SSIZE_T_MAX):
"""
Get (start, stop, step, length) tuple for a slice.
Expand Down
5 changes: 3 additions & 2 deletions pandas/_libs/interval.pyx
Original file line number Diff line number Diff line change
Expand Up @@ -362,8 +362,8 @@ cdef class Interval(IntervalMixin):

@cython.wraparound(False)
@cython.boundscheck(False)
cpdef intervals_to_interval_bounds(ndarray intervals,
bint validate_closed=True):
def intervals_to_interval_bounds(ndarray intervals,
bint validate_closed=True):
"""
Parameters
----------
Expand Down Expand Up @@ -415,4 +415,5 @@ cpdef intervals_to_interval_bounds(ndarray intervals,

return left, right, closed


include "intervaltree.pxi"
38 changes: 18 additions & 20 deletions pandas/_libs/lib.pyx
Original file line number Diff line number Diff line change
Expand Up @@ -107,7 +107,7 @@ def memory_usage_of_objects(object[:] arr):
# ----------------------------------------------------------------------


cpdef bint is_scalar(object val):
def is_scalar(val: object) -> bint:
"""
Return True if given value is scalar.

Expand Down Expand Up @@ -137,7 +137,7 @@ cpdef bint is_scalar(object val):
or util.is_period_object(val)
or is_decimal(val)
or is_interval(val)
or is_offset(val))
or util.is_offset_object(val))


def item_from_zerodim(object val):
Expand Down Expand Up @@ -457,7 +457,7 @@ def maybe_booleans_to_slice(ndarray[uint8_t] mask):

@cython.wraparound(False)
@cython.boundscheck(False)
cpdef bint array_equivalent_object(object[:] left, object[:] right):
def array_equivalent_object(left: object[:], right: object[:]) -> bint:
""" perform an element by element comparion on 1-d object arrays
taking into account nan positions """
cdef:
Expand Down Expand Up @@ -497,7 +497,7 @@ def astype_intsafe(ndarray[object] arr, new_dtype):
return result


cpdef ndarray[object] astype_unicode(ndarray arr):
def astype_unicode(arr: ndarray) -> ndarray[object]:
cdef:
Py_ssize_t i, n = arr.size
ndarray[object] result = np.empty(n, dtype=object)
Expand All @@ -508,7 +508,7 @@ cpdef ndarray[object] astype_unicode(ndarray arr):
return result


cpdef ndarray[object] astype_str(ndarray arr):
def astype_str(arr: ndarray) -> ndarray[object]:
cdef:
Py_ssize_t i, n = arr.size
ndarray[object] result = np.empty(n, dtype=object)
Expand Down Expand Up @@ -791,19 +791,19 @@ def indices_fast(object index, ndarray[int64_t] labels, list keys,

# core.common import for fast inference checks

cpdef bint is_float(object obj):
def is_float(obj: object) -> bint:
return util.is_float_object(obj)


cpdef bint is_integer(object obj):
def is_integer(obj: object) -> bint:
return util.is_integer_object(obj)


cpdef bint is_bool(object obj):
def is_bool(obj: object) -> bint:
return util.is_bool_object(obj)


cpdef bint is_complex(object obj):
def is_complex(obj: object) -> bint:
return util.is_complex_object(obj)


Expand All @@ -815,15 +815,11 @@ cpdef bint is_interval(object obj):
return getattr(obj, '_typ', '_typ') == 'interval'


cpdef bint is_period(object val):
def is_period(val: object) -> bint:
""" Return a boolean if this is a Period object """
return util.is_period_object(val)


cdef inline bint is_offset(object val):
return getattr(val, '_typ', '_typ') == 'dateoffset'


_TYPE_MAP = {
'categorical': 'categorical',
'category': 'categorical',
Expand Down Expand Up @@ -1225,7 +1221,7 @@ def infer_dtype(object value, bint skipna=False):
if is_bytes_array(values, skipna=skipna):
return 'bytes'

elif is_period(val):
elif util.is_period_object(val):
if is_period_array(values):
return 'period'

Expand All @@ -1243,7 +1239,7 @@ def infer_dtype(object value, bint skipna=False):
return 'mixed'


cpdef object infer_datetimelike_array(object arr):
def infer_datetimelike_array(arr: object) -> object:
"""
infer if we have a datetime or timedelta array
- date: we have *only* date and maybe strings, nulls
Expand Down Expand Up @@ -1580,7 +1576,7 @@ cpdef bint is_datetime64_array(ndarray values):
return validator.validate(values)


cpdef bint is_datetime_with_singletz_array(ndarray values):
def is_datetime_with_singletz_array(values: ndarray) -> bint:
"""
Check values have the same tzinfo attribute.
Doesn't check values are datetime-like types.
Expand Down Expand Up @@ -1616,7 +1612,8 @@ cdef class TimedeltaValidator(TemporalValidator):
return is_null_timedelta64(value)


cpdef bint is_timedelta_array(ndarray values):
# TODO: Not used outside of tests; remove?
def is_timedelta_array(values: ndarray) -> bint:
cdef:
TimedeltaValidator validator = TimedeltaValidator(len(values),
skipna=True)
Expand All @@ -1628,7 +1625,8 @@ cdef class Timedelta64Validator(TimedeltaValidator):
return util.is_timedelta64_object(value)


cpdef bint is_timedelta64_array(ndarray values):
# TODO: Not used outside of tests; remove?
def is_timedelta64_array(values: ndarray) -> bint:
cdef:
Timedelta64Validator validator = Timedelta64Validator(len(values),
skipna=True)
Expand Down Expand Up @@ -1672,7 +1670,7 @@ cpdef bint is_time_array(ndarray values, bint skipna=False):

cdef class PeriodValidator(TemporalValidator):
cdef inline bint is_value_typed(self, object value) except -1:
return is_period(value)
return util.is_period_object(value)

cdef inline bint is_valid_null(self, object value) except -1:
return is_null_period(value)
Expand Down
2 changes: 1 addition & 1 deletion pandas/_libs/tslib.pyx
Original file line number Diff line number Diff line change
Expand Up @@ -300,7 +300,7 @@ def format_array_from_datetime(ndarray[int64_t] values, object tz=None,
return result


cpdef array_with_unit_to_datetime(ndarray values, unit, errors='coerce'):
def array_with_unit_to_datetime(ndarray values, unit, errors='coerce'):
"""
convert the ndarray according to the unit
if errors:
Expand Down
11 changes: 5 additions & 6 deletions pandas/_libs/writers.pyx
Original file line number Diff line number Diff line change
Expand Up @@ -3,8 +3,7 @@
cimport cython
from cython cimport Py_ssize_t

from cpython cimport (PyString_Check, PyBytes_Check, PyUnicode_Check,
PyBytes_GET_SIZE, PyUnicode_GET_SIZE)
from cpython cimport PyBytes_GET_SIZE, PyUnicode_GET_SIZE

try:
from cpython cimport PyString_GET_SIZE
Expand Down Expand Up @@ -124,19 +123,19 @@ def convert_json_to_lines(object arr):
# stata, pytables
@cython.boundscheck(False)
@cython.wraparound(False)
cpdef Py_ssize_t max_len_string_array(pandas_string[:] arr):
def max_len_string_array(pandas_string[:] arr) -> Py_ssize_t:
""" return the maximum size of elements in a 1-dim string array """
cdef:
Py_ssize_t i, m = 0, l = 0, length = arr.shape[0]
pandas_string v

for i in range(length):
v = arr[i]
if PyString_Check(v):
if isinstance(v, str):
l = PyString_GET_SIZE(v)
elif PyBytes_Check(v):
elif isinstance(v, bytes):
l = PyBytes_GET_SIZE(v)
elif PyUnicode_Check(v):
elif isinstance(v, unicode):
l = PyUnicode_GET_SIZE(v)

if l > m:
Expand Down

0 comments on commit 5b981fe

Please sign in to comment.