diff --git a/arraycontext/__init__.py b/arraycontext/__init__.py index c40117e8..2ea0c6d7 100644 --- a/arraycontext/__init__.py +++ b/arraycontext/__init__.py @@ -85,6 +85,7 @@ from .impl.numpy import NumpyArrayContext from .impl.pyopencl import PyOpenCLArrayContext from .impl.pytato import PytatoJAXArrayContext, PytatoPyOpenCLArrayContext +from .impl.numpy import NumpyArrayContext from .loopy import make_loopy_program from .pytest import ( PytestArrayContextFactory, diff --git a/arraycontext/impl/pyopencl/__init__.py b/arraycontext/impl/pyopencl/__init__.py index 83dc9c05..ba6b224e 100644 --- a/arraycontext/impl/pyopencl/__init__.py +++ b/arraycontext/impl/pyopencl/__init__.py @@ -304,10 +304,33 @@ def transform_loopy_program(self, t_unit: lp.TranslationUnit) -> lp.TranslationU "to create this kernel?") all_inames = default_entrypoint.all_inames() - + # FIXME: This could be much smarter. inner_iname = None - if "i0" in all_inames: + # import with underscore to avoid DeprecationWarning + # from arraycontext.metadata import _FirstAxisIsElementsTag + from meshmode.transform_metadata import FirstAxisIsElementsTag + + if (len(default_entrypoint.instructions) == 1 + and isinstance(default_entrypoint.instructions[0], lp.Assignment) + and any(isinstance(tag, FirstAxisIsElementsTag) + # FIXME: Firedrake branch lacks kernel tags + for tag in getattr(default_entrypoint, "tags", ()))): + stmt, = default_entrypoint.instructions + + out_inames = [v.name for v in stmt.assignee.index_tuple] + assert out_inames + outer_iname = out_inames[0] + if len(out_inames) >= 2: + inner_iname = out_inames[1] + + elif "iel" in all_inames: + outer_iname = "iel" + + if "idof" in all_inames: + inner_iname = "idof" + + elif "i0" in all_inames: outer_iname = "i0" if "i1" in all_inames: diff --git a/arraycontext/impl/pytato/fake_numpy.py b/arraycontext/impl/pytato/fake_numpy.py index c6508e3a..d3cf1257 100644 --- a/arraycontext/impl/pytato/fake_numpy.py +++ b/arraycontext/impl/pytato/fake_numpy.py @@ -90,8 +90,9 @@ def zeros(self, shape, dtype): def zeros_like(self, ary): def _zeros_like(array): - return self._array_context.zeros( - array.shape, array.dtype).copy(axes=array.axes, tags=array.tags) + # return self._array_context.zeros( + # array.shape, array.dtype).copy(axes=array.axes, tags=array.tags) + return 0*array return self._array_context._rec_map_container( _zeros_like, ary, default_scalar=0) @@ -101,8 +102,9 @@ def ones_like(self, ary): def full_like(self, ary, fill_value): def _full_like(subary): - return pt.full(subary.shape, fill_value, subary.dtype).copy( - axes=subary.axes, tags=subary.tags) + # return pt.full(subary.shape, fill_value, subary.dtype).copy( + # axes=subary.axes, tags=subary.tags) + return fill_value * (0*subary + 1) return self._array_context._rec_map_container( _full_like, ary, default_scalar=fill_value) diff --git a/arraycontext/pytest.py b/arraycontext/pytest.py index c778154d..eaf82879 100644 --- a/arraycontext/pytest.py +++ b/arraycontext/pytest.py @@ -36,9 +36,9 @@ from arraycontext import NumpyArrayContext from arraycontext.context import ArrayContext - # {{{ array context factories + class PytestArrayContextFactory: @classmethod def is_available(cls) -> bool: @@ -223,6 +223,27 @@ def __call__(self): def __str__(self): return "" +# {{{ _PytestArrayContextFactory + + +class _NumpyArrayContextForTests(NumpyArrayContext): + def transform_loopy_program(self, t_unit): + return t_unit + + +class _PytestNumpyArrayContextFactory(PytestArrayContextFactory): + def __init__(self, *args, **kwargs): + super().__init__() + + def __call__(self): + return _NumpyArrayContextForTests() + + def __str__(self): + return "" + +# }}} + + # {{{ _PytestArrayContextFactory diff --git a/test/test_arraycontext.py b/test/test_arraycontext.py index 7bea0dc4..3894eafc 100644 --- a/test/test_arraycontext.py +++ b/test/test_arraycontext.py @@ -949,8 +949,9 @@ def _check_allclose(f, arg1, arg2, atol=5.0e-14): with pytest.raises(TypeError): ary_of_dofs + dc_of_dofs - with pytest.raises(TypeError): - dc_of_dofs + ary_of_dofs + if not isinstance(actx, NumpyArrayContext): + with pytest.raises(TypeError): + dc_of_dofs + ary_of_dofs with pytest.raises(TypeError): ary_dof + dc_of_dofs @@ -1152,7 +1153,12 @@ def test_flatten_with_leaf_class(actx_factory): # {{{ test from_numpy and to_numpy def test_numpy_conversion(actx_factory): + from arraycontext import NumpyArrayContext + actx = actx_factory() + if isinstance(actx, NumpyArrayContext): + pytest.skip("Irrelevant tests for NumpyArrayContext") + rng = np.random.default_rng() nelements = 42