ENH: Implement reshape function #776

Merged: 2 commits, Sep 23, 2024

Changes from 1 commit
2 changes: 2 additions & 0 deletions sparse/mlir_backend/__init__.py
@@ -15,10 +15,12 @@
 )
 from ._ops import (
     add,
+    reshape,
 )
 
 __all__ = [
     "add",
     "asarray",
     "asdtype",
+    "reshape",
 ]
2 changes: 1 addition & 1 deletion sparse/mlir_backend/_constructors.py
@@ -239,7 +239,7 @@ def from_sps(cls, arr: np.ndarray) -> "Dense":

         return dense_instance
 
-    def to_sps(self, shape: tuple[int, ...]) -> sps.csr_array:
+    def to_sps(self, shape: tuple[int, ...]) -> np.ndarray:
         data = ranked_memref_to_numpy(self.data)
         return data.reshape(shape)

53 changes: 51 additions & 2 deletions sparse/mlir_backend/_ops.py
@@ -5,10 +5,12 @@
 from mlir import ir
 from mlir.dialects import arith, func, linalg, sparse_tensor, tensor
 
+import numpy as np
+
 from ._common import fn_cache
-from ._constructors import Tensor
+from ._constructors import Tensor, numpy_to_ranked_memref
 from ._core import CWD, DEBUG, MLIR_C_RUNNER_UTILS, ctx, pm
-from ._dtypes import DType, FloatingDType
+from ._dtypes import DType, FloatingDType, Index
 
 
 @fn_cache
@@ -68,6 +70,31 @@ def add(a, b):
         return mlir.execution_engine.ExecutionEngine(module, opt_level=2, shared_libs=[MLIR_C_RUNNER_UTILS])
 
 
+# @fn_cache
+def get_reshape_module(
+    a_tensor_type: ir.RankedTensorType,
+    shape_tensor_type: ir.RankedTensorType,
+    out_tensor_type: ir.RankedTensorType,
+) -> "mlir.execution_engine.ExecutionEngine":
+    with ir.Location.unknown(ctx):
+        module = ir.Module.create()
+
+        with ir.InsertionPoint(module.body):
+
+            @func.FuncOp.from_py_func(a_tensor_type, shape_tensor_type)
+            def reshape(a, shape):
+                return tensor.reshape(out_tensor_type, a, shape)
+
+        reshape.func_op.attributes["llvm.emit_c_interface"] = ir.UnitAttr.get()
+        if DEBUG:
+            (CWD / "reshape_module.mlir").write_text(str(module))
+        pm.run(module.operation)
+        if DEBUG:
+            (CWD / "reshape_module_opt.mlir").write_text(str(module))
+
+        return mlir.execution_engine.ExecutionEngine(module, opt_level=2, shared_libs=[MLIR_C_RUNNER_UTILS])
+
+
 def add(x1: Tensor, x2: Tensor) -> Tensor:
     ret_obj = x1._format_class()
     out_tensor_type = x1._obj.get_tensor_definition(x1.shape)
@@ -88,3 +115,25 @@ def add(x1: Tensor, x2: Tensor) -> Tensor:
         *x2._obj.to_module_arg(),
     )
     return Tensor(ret_obj, shape=out_tensor_type.shape)
+
+
+def reshape(x: Tensor, /, shape: tuple[int, ...]) -> Tensor:
+    ret_obj = x._format_class()
+    x_tensor_type = x._obj.get_tensor_definition(x.shape)
+    out_tensor_type = x._obj.get_tensor_definition(shape)
+
+    with ir.Location.unknown(ctx):
+        shape_tensor_type = ir.RankedTensorType.get([len(shape)], Index.get_mlir_type())
+
+    # TODO: Add proper caching
+    reshape_module = get_reshape_module(x_tensor_type, shape_tensor_type, out_tensor_type)
+
+    shape = np.array(shape)
+    reshape_module.invoke(
+        "reshape",
+        ctypes.pointer(ctypes.pointer(ret_obj)),
+        *x._obj.to_module_arg(),
+        ctypes.pointer(ctypes.pointer(numpy_to_ranked_memref(shape))),
+    )
+
+    return Tensor(ret_obj, shape=out_tensor_type.shape)
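
For orientation, a hand-driven sketch of the builder above, with hypothetical static shapes and element type (a 100x50 f64 tensor reshaped to 25x200); the reshape() entry point derives all three types from a Tensor's storage format via get_tensor_definition instead:

    with ir.Location.unknown(ctx):
        f64 = ir.F64Type.get(ctx)
        a_type = ir.RankedTensorType.get([100, 50], f64)                  # input tensor type
        shape_type = ir.RankedTensorType.get([2], ir.IndexType.get(ctx))  # target-shape operand
        out_type = ir.RankedTensorType.get([25, 200], f64)                # result tensor type

    engine = get_reshape_module(a_type, shape_type, out_type)  # compiled ExecutionEngine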
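And a minimal end-to-end usage sketch, mirroring the new test below; the backend import path is assumed from the test layout, and scipy.sparse.random_array requires SciPy >= 1.12:

    import numpy as np
    import scipy.sparse as sps

    import sparse.mlir_backend as sparse  # assumed import path

    arr = sps.random_array((100, 50), density=0.5, format="csr", dtype=np.float64)
    tensor = sparse.asarray(arr)
    result = sparse.reshape(tensor, shape=(25, 200)).to_scipy_sparse()

    np.testing.assert_array_equal(result.todense(), arr.todense().reshape(25, 200))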
82 changes: 77 additions & 5 deletions sparse/mlir_backend/tests/test_simple.py
@@ -75,6 +75,15 @@ def sampler_real_floating(size: tuple[int, ...]):
     raise NotImplementedError(f"{dtype=} not yet supported.")
 
 
+def get_example_csf_arrays(dtype: np.dtype) -> tuple:
+    pos_1 = np.array([0, 1, 3], dtype=np.int64)
+    crd_1 = np.array([1, 0, 1], dtype=np.int64)
+    pos_2 = np.array([0, 3, 5, 7], dtype=np.int64)
+    crd_2 = np.array([0, 1, 3, 0, 3, 0, 1], dtype=np.int64)
+    data = np.array([1, 2, 3, 4, 5, 6, 7], dtype=dtype)
+    return pos_1, crd_1, pos_2, crd_2, data
+
+
 @parametrize_dtypes
 @pytest.mark.parametrize("shape", [(100,), (10, 200), (5, 10, 20)])
 def test_dense_format(dtype, shape):
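
To make the CSF fixture above easier to read: each (pos_k, crd_k) pair compresses one level of the tensor, with pos_k giving, for every node of the previous level, the half-open slice of crd_k (and, at the last level, data) that belongs to it. A standalone decode sketch (not library code; it assumes the layout this fixture uses, with the first dimension dense and the remaining two compressed) that expands the arrays into the dense (2, 2, 4) array they encode:

    def csf_to_dense(pos_1, crd_1, pos_2, crd_2, data, shape):
        out = np.zeros(shape, dtype=data.dtype)
        for i in range(len(pos_1) - 1):  # first dimension: dense
            for p1 in range(pos_1[i], pos_1[i + 1]):  # second dimension: compressed
                for p2 in range(pos_2[p1], pos_2[p1 + 1]):  # third dimension: compressed
                    out[i, crd_1[p1], crd_2[p2]] = data[p2]
        return out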
@@ -176,11 +185,7 @@ def test_add(rng, dtype):
 @parametrize_dtypes
 def test_csf_format(dtype):
     SHAPE = (2, 2, 4)
-    pos_1 = np.array([0, 1, 3], dtype=np.int64)
-    crd_1 = np.array([1, 0, 1], dtype=np.int64)
-    pos_2 = np.array([0, 3, 5, 7], dtype=np.int64)
-    crd_2 = np.array([0, 1, 3, 0, 3, 0, 1], dtype=np.int64)
-    data = np.array([1, 2, 3, 4, 5, 6, 7], dtype=dtype)
+    pos_1, crd_1, pos_2, crd_2, data = get_example_csf_arrays(dtype)
     csf = [pos_1, crd_1, pos_2, crd_2, data]
 
     csf_tensor = sparse.asarray(csf, shape=SHAPE, dtype=sparse.asdtype(dtype), format="csf")
@@ -192,3 +197,70 @@ def test_csf_format(dtype):
     csf_2 = [pos_1, crd_1, pos_2, crd_2, data * 2]
     for actual, expected in zip(res_tensor, csf_2, strict=False):
         np.testing.assert_array_equal(actual, expected)


+@parametrize_dtypes
+def test_reshape(rng, dtype):
+    DENSITY = 0.5
+    sampler = generate_sampler(dtype, rng)
+
+    # CSR, CSC, COO
+    for shape, new_shape in [((100, 50), (25, 200)), ((80, 1), (8, 10))]:
+        for format in ["csr", "csc", "coo"]:
+            if format == "coo":
+                # NOTE: Blocked by https://github.com/llvm/llvm-project/pull/109135
+                continue
+            if format == "csc":
+                # NOTE: Blocked by https://github.com/llvm/llvm-project/issues/109641
+                continue
+
+            arr = sps.random_array(
+                shape, density=DENSITY, format=format, dtype=dtype, random_state=rng, data_sampler=sampler
+            )
+            if format == "coo":
+                arr.sum_duplicates()
+
+            tensor = sparse.asarray(arr)
+
+            actual = sparse.reshape(tensor, shape=new_shape).to_scipy_sparse()
+            expected = arr.todense().reshape(new_shape)
+
+            np.testing.assert_array_equal(actual.todense(), expected)
+
+    # CSF
+    csf_shape = (2, 2, 4)
+    for shape, new_shape, expected_arrs in [
+        (
+            csf_shape,
+            (4, 4, 1),
+            [
+                np.array([0, 0, 3, 5, 7]),
+                np.array([0, 1, 3, 0, 3, 0, 1]),
+                np.array([0, 1, 2, 3, 4, 5, 6, 7]),
+                np.array([0, 0, 0, 0, 0, 0, 0]),
+                np.array([1, 2, 3, 4, 5, 6, 7]),
+            ],
+        ),
+        (
+            csf_shape,
+            (2, 1, 8),
+            [
+                np.array([0, 1, 2]),
+                np.array([0, 0]),
+                np.array([0, 3, 7]),
+                np.array([4, 5, 7, 0, 3, 4, 5]),
+                np.array([1, 2, 3, 4, 5, 6, 7]),
+            ],
+        ),
+    ]:
+        csf = get_example_csf_arrays(dtype)
+        csf_tensor = sparse.asarray(csf, shape=shape, dtype=sparse.asdtype(dtype), format="csf")
+
+        result = sparse.reshape(csf_tensor, shape=new_shape).to_scipy_sparse()
+
+        for actual, expected in zip(result, expected_arrs, strict=False):
+            np.testing.assert_array_equal(actual, expected)
+
+    # DENSE
+    # NOTE: dense reshape is probably broken in MLIR
+    # dense = np.arange(math.prod(SHAPE), dtype=dtype).reshape(SHAPE)