Skip to content

Commit

Permalink
Merge branch 'main' into production-pilot
Browse files Browse the repository at this point in the history
  • Loading branch information
MTCam committed Oct 26, 2023
2 parents f552fac + e53fa90 commit c61cb14
Show file tree
Hide file tree
Showing 6 changed files with 128 additions and 6 deletions.
10 changes: 5 additions & 5 deletions arraycontext/container/traversal.py
Original file line number Diff line number Diff line change
Expand Up @@ -754,7 +754,7 @@ def _flatten_with_leaf_class(subary: ArrayOrContainer) -> Any:


def unflatten(
template: ArrayOrContainerT, ary: Any,
template: ArrayOrContainerT, ary: Array,
actx: ArrayContext, *,
strict: bool = True) -> ArrayOrContainerT:
"""Unflatten an array *ary* produced by :func:`flatten` back into an
Expand Down Expand Up @@ -822,17 +822,17 @@ def _unflatten(template_subary: ArrayOrContainer) -> ArrayOrContainer:

# {{{ check strides

if strict and hasattr(template_subary, "strides"):
if strict and hasattr(template_subary_c, "strides"):
# Checking strides for 0 sized arrays is ill-defined
# since they cannot be indexed
if (
# Mypy has a point: nobody promised a .strides attribute.
template_subary_c.strides != subary.strides # type: ignore[attr-defined] # noqa: E501
template_subary_c.strides != subary.strides
and template_subary_c.size != 0
):
raise ValueError(
# Mypy has a point: nobody promised a .strides attribute.
f"strides do not match template: got {subary.strides}, " # type: ignore[attr-defined] # noqa: E501
f"strides do not match template: got {subary.strides}, "
f"expected {template_subary_c.strides}")

# }}}
Expand All @@ -849,7 +849,7 @@ def _unflatten(template_subary: ArrayOrContainer) -> ArrayOrContainer:
f"array context: got '{type(ary).__name__}', expected one of "
f"{actx.array_types}")

if ary.ndim != 1:
if len(ary.shape) != 1:
raise ValueError(
"only one dimensional arrays can be unflattened: "
f"'ary' has shape {ary.shape}")
Expand Down
4 changes: 4 additions & 0 deletions arraycontext/context.py
Original file line number Diff line number Diff line change
Expand Up @@ -192,6 +192,7 @@ class Array(Protocol):
.. attribute:: shape
.. attribute:: size
.. attribute:: dtype
.. attribute:: __getitem__
"""

@property
Expand All @@ -206,6 +207,9 @@ def size(self) -> int:
def dtype(self) -> "np.dtype[Any]":
...

def __getitem__(self, index: Union[slice, int]) -> "Array":
...


# deprecated, use ScalarLike instead
Scalar = ScalarLike
Expand Down
87 changes: 87 additions & 0 deletions arraycontext/fake_numpy.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,9 @@
"""


import operator
from typing import Any

import numpy as np

from arraycontext.container import NotAnArrayContainerError, serialize_container
Expand Down Expand Up @@ -100,6 +103,89 @@ def conjugate(self, x):

conj = conjugate

# {{{ linspace

# based on
# https://github.com/numpy/numpy/blob/v1.25.0/numpy/core/function_base.py#L24-L182

def linspace(self, start, stop, num=50, endpoint=True, retstep=False, dtype=None,
axis=0):
num = operator.index(num)
if num < 0:
raise ValueError("Number of samples, %s, must be non-negative." % num)
div = (num - 1) if endpoint else num

# Convert float/complex array scalars to float, gh-3504
# and make sure one can use variables that have an __array_interface__,
# gh-6634

if isinstance(start, self._array_context.array_types):
raise NotImplementedError("start as an actx array")
if isinstance(stop, self._array_context.array_types):
raise NotImplementedError("stop as an actx array")

start = np.array(start) * 1.0
stop = np.array(stop) * 1.0

dt = np.result_type(start, stop, float(num))
if dtype is None:
dtype = dt
integer_dtype = False
else:
integer_dtype = np.issubdtype(dtype, np.integer)

delta = stop - start

y = self.arange(0, num, dtype=dt).reshape((-1,) + (1,) * delta.ndim)

if div > 0:
step = delta / div
#any_step_zero = _nx.asanyarray(step == 0).any()
any_step_zero = self._array_context.to_numpy((step == 0)).any()
if any_step_zero:
delta_actx = self._array_context.from_numpy(delta)

# Special handling for denormal numbers, gh-5437
y = y / div
y = y * delta_actx
else:
step_actx = self._array_context.from_numpy(step)
y = y * step_actx
else:
delta_actx = self._array_context.from_numpy(delta)
# sequences with 0 items or 1 item with endpoint=True (i.e. div <= 0)
# have an undefined step
step = np.NaN
# Multiply with delta to allow possible override of output class.
y = y * delta_actx

y += start

# FIXME reenable, without in-place ops
# if endpoint and num > 1:
# y[-1, ...] = stop

if axis != 0:
# y = _nx.moveaxis(y, 0, axis)
raise NotImplementedError("axis != 0")

if integer_dtype:
y = self.floor(y) # pylint: disable=no-member

# FIXME: Use astype
# https://github.com/inducer/pytato/issues/456
if retstep:
return y, step
#return y.astype(dtype), step
else:
return y
#return y.astype(dtype)

# }}}

def arange(self, *args: Any, **kwargs: Any):
raise NotImplementedError

# }}}


Expand Down Expand Up @@ -180,6 +266,7 @@ def norm(self, ary, ord=None):
return actx.np.sum(abs(ary)**ord)**(1/ord)
else:
raise NotImplementedError(f"unsupported value of 'ord': {ord}")

# }}}


Expand Down
4 changes: 3 additions & 1 deletion arraycontext/impl/pyopencl/fake_numpy.py
Original file line number Diff line number Diff line change
Expand Up @@ -107,6 +107,9 @@ def _copy(subary):

return self._array_context._rec_map_container(_copy, ary)

def arange(self, *args, **kwargs):
return cl_array.arange(self._array_context.queue, *args, **kwargs)

# }}}

# {{{ array manipulation routines
Expand Down Expand Up @@ -360,7 +363,6 @@ def where_inner(inner_crit, inner_then, inner_else):

# }}}


# }}}


Expand Down
7 changes: 7 additions & 0 deletions arraycontext/impl/pytato/fake_numpy.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@
THE SOFTWARE.
"""
from functools import partial, reduce
from typing import Any

import numpy as np

Expand Down Expand Up @@ -98,6 +99,12 @@ def _full_like(subary):
return self._array_context._rec_map_container(
_full_like, ary, default_scalar=fill_value)

def arange(self, *args: Any, **kwargs: Any):
return pt.arange(*args, **kwargs)

def full(self, shape, fill_value, dtype=None):
return pt.full(shape, fill_value, dtype)

# }}}

# {{{ array manipulation routines
Expand Down
22 changes: 22 additions & 0 deletions test/test_arraycontext.py
Original file line number Diff line number Diff line change
Expand Up @@ -1597,6 +1597,28 @@ def test_compile_anonymous_function(actx_factory):
42)


@pytest.mark.parametrize(
("args", "kwargs"), [
((1, 2, 10), {}),
((1, 2, 10), {"endpoint": False}),
((1, 2, 10), {"endpoint": True}),
((2, -3, 20), {}),
((1, 5j, 20), {"dtype": np.complex128}),
((1, 5, 20), {"dtype": np.complex128}),
((1, 5, 20), {"dtype": np.int32}),
])
def test_linspace(actx_factory, args, kwargs):
if "Jax" in actx_factory.__class__.__name__:
pytest.xfail("jax actx does not have arange")

actx = actx_factory()

actx_linspace = actx.to_numpy(actx.np.linspace(*args, **kwargs))
np_linspace = np.linspace(*args, **kwargs)

assert np.allclose(actx_linspace, np_linspace)


if __name__ == "__main__":
import sys
if len(sys.argv) > 1:
Expand Down

0 comments on commit c61cb14

Please sign in to comment.