Skip to content

Commit

Permalink
refactor elemwise ops
Browse files Browse the repository at this point in the history
  • Loading branch information
daletovar committed Sep 24, 2020
1 parent e9f2c87 commit 6288e7b
Show file tree
Hide file tree
Showing 10 changed files with 1,068 additions and 1,034 deletions.
2 changes: 1 addition & 1 deletion sparse/_common.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
from ._sparse_array import SparseArray
from ._utils import check_compressed_axes, normalize_axis, check_zero_fill_value

from ._coo.umath import elemwise
from ._umath import elemwise
from ._coo.common import (
clip,
triu,
Expand Down
47 changes: 20 additions & 27 deletions sparse/_compressed/compressed.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import copy as _copy
import numpy as np
import operator
from numpy.lib.mixins import NDArrayOperatorsMixin
Expand Down Expand Up @@ -141,6 +142,24 @@ def __init__(
if prune:
self._prune()

def copy(self, deep=True):
"""Return a copy of the array.
Parameters
----------
deep : boolean, optional
If True (default), the internal coords and data arrays are also
copied. Set to ``False`` to only make a shallow copy.
"""
return _copy.deepcopy(self) if deep else _copy.copy(self)

def _make_shallow_copy_of(self, other):
self.data = other.data
self.indices = other.indices
self.indptr = other.indptr
self.compressed_axes = other.compressed_axes
super().__init__(other.shape, fill_value=other.fill_value)

@classmethod
def from_numpy(cls, x, compressed_axes=None, fill_value=0):
coo = COO(x, fill_value=fill_value)
Expand Down Expand Up @@ -262,8 +281,7 @@ def __str__(self):
__getitem__ = getitem

def _reduce_calc(self, method, axis, keepdims=False, **kwargs):

if axis[0] is None:
if axis[0] is None or np.array_equal(axis, np.arange(self.ndim, dtype=np.intp)):
x = self.flatten().tocoo()
out = x.reduce(method, axis=None, keepdims=keepdims, **kwargs)
if keepdims:
Expand Down Expand Up @@ -744,31 +762,6 @@ def __rmatmul__(self, other):
except NotImplementedError:
return NotImplemented

def astype(self, dtype, casting="unsafe", copy=True):
"""
Copy of the array, cast to a specified type.
See also
--------
scipy.sparse.coo_matrix.astype : SciPy sparse equivalent function
numpy.ndarray.astype : NumPy equivalent ufunc.
:obj:`COO.elemwise`: Apply an arbitrary element-wise function to one or two
arguments.
"""
if self.dtype == dtype and not copy:
return self
# temporary solution
return GCXS(
(
np.array(self.data, copy=copy).astype(dtype),
np.array(self.indices, copy=copy),
np.array(self.indptr, copy=copy),
),
shape=self.shape,
compressed_axes=self.compressed_axes,
fill_value=self.fill_value,
)

def _prune(self):
"""
Prunes data so that if any fill-values are present, they are removed
Expand Down
2 changes: 0 additions & 2 deletions sparse/_coo/__init__.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,4 @@
from .core import COO, as_coo
from .umath import elemwise
from .common import (
concatenate,
clip,
Expand All @@ -26,7 +25,6 @@
__all__ = [
"COO",
"as_coo",
"elemwise",
"concatenate",
"clip",
"stack",
Expand Down
4 changes: 2 additions & 2 deletions sparse/_coo/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -93,7 +93,7 @@ def kron(a, b):
[0, 0, 0, 0, 0, 0, 1, 2, 3]], dtype=int64)
"""
from .core import COO
from .umath import _cartesian_product
from .._umath import _cartesian_product

check_zero_fill_value(a, b)

Expand Down Expand Up @@ -556,7 +556,7 @@ def where(condition, x=None, y=None):
--------
numpy.where : Equivalent Numpy function.
"""
from .umath import elemwise
from .._umath import elemwise

x_given = x is not None
y_given = y is not None
Expand Down
Loading

0 comments on commit 6288e7b

Please sign in to comment.