Skip to content

Commit

Permalink
[BugFix] Fix offset caching in lowering (apache#38)
Browse files Browse the repository at this point in the history
* Hack compact dataflow check in a dirty way

* Add two-K square sum test

* Mark skipped tests

* Fix offset saving in lowering
  • Loading branch information
MasterJH5574 committed Dec 22, 2021
1 parent ae02971 commit 144bb60
Show file tree
Hide file tree
Showing 3 changed files with 120 additions and 12 deletions.
3 changes: 2 additions & 1 deletion src/tir/schedule/analysis/analysis.cc
Original file line number Diff line number Diff line change
Expand Up @@ -168,7 +168,8 @@ Definition of a scope that is a stage pipeline:
if (it_atomic != block->annotations.end()) {
is_atomic = ((*it_atomic).second).as<IntImmNode>()->value;
}
if (!is_atomic) {
// Todo(ruihang): Temporary hack. Deal with the "sparse" annotation later.
if (!is_atomic && block->annotations.find("sparse") == block->annotations.end()) {
throw NotCompactDataFlowError(self->mod, GetRef<Stmt>(scope_root_subtree->stmt),
GetRef<Block>(block));
}
Expand Down
27 changes: 17 additions & 10 deletions src/tir/transforms/lower_sparse_tir.cc
Original file line number Diff line number Diff line change
Expand Up @@ -322,10 +322,6 @@ class SparseBufferCtx {
matches_.emplace_back(axis->name == sp_iter_var->axis->name);
}
}

// update offset
PrimExpr new_offset = AggregateOffset(offsets_.back(), axis, std::move(coordinate), ana_);
offsets_.emplace_back(std::move(new_offset));
}

/*! \brief get the axis given dimension index of current buffer. */
Expand All @@ -341,7 +337,7 @@ class SparseBufferCtx {
AggregateOffset(add(offsets_[dim], 1), axis, Integer(0), ana_)};
}

private:
public:
String buf_name_;
Array<Axis> axes_;
std::vector<PrimExpr> offsets_;
Expand Down Expand Up @@ -375,7 +371,12 @@ class SparseBufferCtx {
top()->Register(dim, std::move(coordinate), std::move(orig_idx));
}

private:
void AddOffset(int dim, PrimExpr offset) {
ICHECK_EQ(dim + 1, static_cast<int>(top()->offsets_.size()));
top()->offsets_.push_back(offset);
}

public:
std::vector<Scope> stack_;
arith::Analyzer* ana_;

Expand Down Expand Up @@ -421,18 +422,22 @@ class IndexTransformer : public StmtExprMutator {
auto sf_axis = axis.as<SparseFixedAxisNode>();
PrimExpr l, r;
std::tie(l, r) = sp_buf_ctx_.GetIndicesRange(dim);
offset = lower_bound(sf_axis->indices->data, coordinate, l, r);
offset = lower_bound(sf_axis->indices->data, coordinate, l, r) - l;
break;
}
case AxisKind::kSparseVariable:
auto sv_axis = axis.as<SparseVariableAxisNode>();
PrimExpr l, r;
std::tie(l, r) = sp_buf_ctx_.GetIndicesRange(dim);
offset = lower_bound(sv_axis->indices->data, coordinate, l, r);
offset = lower_bound(sv_axis->indices->data, coordinate, l, r) - l;
break;
}
}

// update offset
PrimExpr new_offset = AggregateOffset(sp_buf_ctx_.top()->offsets_.back(), axis,
offset, sp_buf_ctx_.ana_);
sp_buf_ctx_.top()->offsets_.push_back(std::move(new_offset));
return offset;
}

Expand Down Expand Up @@ -562,7 +567,8 @@ class IndexTransformer : public StmtExprMutator {
Axis axis = sp_it_var->axis;
auto parent = axis->GetParentAxis();
bool create_new_blk = false;
bool is_fixed_axis = axis->kind() == AxisKind::kDenseFixed || axis->kind() == AxisKind::kSparseFixed;
bool is_fixed_axis =
axis->kind() == AxisKind::kDenseFixed || axis->kind() == AxisKind::kSparseFixed;
if (!is_fixed_axis && parent.defined()) {
const AxisNode* parent_node = parent.value().get();
if (in_block.find(parent_node) != in_block.end()) {
Expand All @@ -572,7 +578,8 @@ class IndexTransformer : public StmtExprMutator {
/* parent node is in the previous blocks in the stack, no need to create new block. */
create_new_blk = false;
} else {
CHECK(false) << "The parent axis of " << axis->GetName() << " should appear before " << axis->GetName() << " when defining a sparse block.";
CHECK(false) << "The parent axis of " << axis->GetName() << " should appear before "
<< axis->GetName() << " when defining a sparse block.";
}
}
if (create_new_blk) {
Expand Down
102 changes: 101 additions & 1 deletion tests/python/sparsetir/test_tir_sparse_lower.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
import tvm.tir as tir
import scipy.sparse as sp
import numpy as np
import pytest
from tvm.script import tir as T


Expand Down Expand Up @@ -367,7 +368,7 @@ def lowered_square_sum(a: T.handle, b: T.handle, indptr_j: T.handle, indices_j:
J_indices = T.match_buffer(indices_j, [nnz_j], dtype="int32")
K_indptr = T.match_buffer(indptr_k, [nnz_j + 1], dtype="int32")
K_indices = T.match_buffer(indices_k, [nnz_k], dtype="int32")

for v_vi in T.serial(0, M):
with T.block("square_sum_2"):
vi = T.axis.spatial(M, v_vi)
Expand All @@ -391,6 +392,58 @@ def lowered_square_sum(a: T.handle, b: T.handle, indptr_j: T.handle, indices_j:
B_data[vi] = B_data[vi] + A_data[K_indptr[J_indptr[vi] + vj] + vk]


@T.prim_func
def square_sum_two_K(a: T.handle, b: T.handle, indptr_j: T.handle, indices_j: T.handle, indptr_k0: T.handle, indices_k0: T.handle, indptr_k1: T.handle, indices_k1: T.handle, nnz_j: T.int32, nnz_k: T.int32, M: T.int32, N1: T.int32, N2: T.int32):
# Used only for testing `GetIndicesRange()`.
# Currently it is ensured that `indptr_k0` is the same as `indptr_k1`, and `indices_k0` is the
# same as `indices_k1`.
I = T.dense_fixed(M)
J = T.sparse_variable(I, (N1, nnz_j), (indptr_j, indices_j), "int32")
K0 = T.sparse_variable(J, (N2, nnz_k), (indptr_k0, indices_k0), "int32")
K1 = T.sparse_variable(J, (N2, nnz_k), (indptr_k1, indices_k1), "int32")
A = T.match_sparse_buffer(a, (I, J, K0), "float32")
B = T.match_sparse_buffer(b, (I,), "float32")

with T.iter([I, J, K1], "SRR", "square_sum") as [vi, vj, vk]:
with T.init():
B[vi] = 0.0
B[vi] = B[vi] + A[vi, vj, vk]


@T.prim_func
def lowered_square_sum_two_K(a: T.handle, b: T.handle, indptr_j: T.handle, indices_j: T.handle, indptr_k0: T.handle, indices_k0: T.handle, indptr_k1: T.handle, indices_k1: T.handle, nnz_j: T.int32, nnz_k: T.int32, M: T.int32, N1: T.int32, N2: T.int32) -> None:
A_data = T.match_buffer(a, [nnz_k], dtype="float32")
B_data = T.match_buffer(b, [M], dtype="float32")
J_indptr = T.match_buffer(indptr_j, [M + 1], dtype="int32")
J_indices = T.match_buffer(indices_j, [nnz_j], dtype="int32")
K0_indptr = T.match_buffer(indptr_k0, [nnz_j + 1], dtype="int32")
K0_indices = T.match_buffer(indices_k0, [nnz_k], dtype="int32")
K1_indptr = T.match_buffer(indptr_k1, [nnz_j + 1], dtype="int32")
K1_indices = T.match_buffer(indices_k1, [nnz_k], dtype="int32")

for v_vi in T.serial(0, M):
with T.block("square_sum_2"):
vi = T.axis.spatial(M, v_vi)
T.reads([J_indptr[0 : M + 1], J_indices[0 : nnz_j], K0_indptr[0 : nnz_j + 1], K0_indices[0 : nnz_k], K1_indptr[0 : nnz_j + 1], K1_indices[0 : nnz_k], A_data[0 : nnz_k], B_data[0 : M]])
T.writes([B_data[0 : M]])
T.block_attr({"sparse":True})
for v_vj in T.serial(0, J_indptr[v_vi + 1] - J_indptr[v_vi]):
with T.block("square_sum_1"):
vj = T.axis.reduce(J_indptr[v_vi + 1] - J_indptr[v_vi], v_vj)
T.reads([J_indptr[0 : M + 1], J_indices[0 : nnz_j], K0_indptr[0 : nnz_j + 1], K0_indices[0 : nnz_k], K1_indptr[0 : nnz_j + 1], K1_indices[0 : nnz_k], A_data[0 : nnz_k], B_data[0 : M]])
T.writes([B_data[0 : M]])
T.block_attr({"sparse":True})
with T.init():
B_data[vi] = T.float32(0)
for v_vk in T.serial(0, K1_indptr[J_indptr[v_vi] + v_vj + 1] - K1_indptr[J_indptr[v_vi] + v_vj]):
with T.block("square_sum"):
vk = T.axis.reduce(K1_indptr[J_indptr[v_vi] + v_vj + 1] - K1_indptr[J_indptr[v_vi] + v_vj], v_vk)
T.reads([J_indptr[0 : M + 1], J_indices[0 : nnz_j], K0_indptr[0 : nnz_j + 1], K0_indices[0 : nnz_k], K1_indptr[0 : nnz_j + 1], K1_indices[0 : nnz_k], A_data[0 : nnz_k], B_data[0 : M]])
T.writes([B_data[0 : M]])
T.block_attr({"sparse":True})
B_data[vi] = B_data[vi] + A_data[T.tvm_lower_bound(K0_indices.data, K1_indices[K1_indptr[J_indptr[vi] + vj] + vk], K0_indptr[J_indptr[vi] + vj], K0_indptr[J_indptr[vi] + vj + 1], dtype="int32")]


def test_csrmm():
mod = tvm.IRModule.from_expr(csrmm)
mod = tvm.tir.transform.LowerSparseTIR()(mod)
Expand All @@ -414,13 +467,15 @@ def test_csrmm():
tvm.testing.assert_allclose(y_ground_truth.reshape(-1), Y_nd.numpy(), rtol=1e-5, atol=1e-5)


@pytest.mark.skip(reason="Under implementation")
def test_csrmm_dense_iter():
mod = tvm.IRModule.from_expr(csrmm_dense_iter)
mod = tvm.tir.transform.LowerSparseTIR()(mod)
# tvm.ir.assert_structural_equal(mod["main"], lowered_csrmm, True)
# Todo


@pytest.mark.skip(reason="Under implementation")
def test_segment_reduce():
mod = tvm.IRModule.from_expr(segment_reduce)
mod = tvm.tir.transform.LowerSparseTIR()(mod)
Expand Down Expand Up @@ -557,6 +612,7 @@ def test_csr_element_wise():
tvm.testing.assert_allclose(b_ground_truth.data.reshape(-1), B_nd.numpy(), rtol=1e-5, atol=1e-5)


@pytest.mark.skip(reason="Under implementation")
def test_bmm():
mod = tvm.IRModule.from_expr(bmm)
mod = tvm.tir.transform.LowerSparseTIR()(mod)
Expand Down Expand Up @@ -600,6 +656,49 @@ def test_square_sum():
tvm.testing.assert_allclose(b_ground_truth, B_data.numpy(), rtol=1e-5, atol=1e-5)


def test_square_sum_two_K():
mod = tvm.IRModule.from_expr(square_sum_two_K)
mod = tvm.tir.transform.LowerSparseTIR()(mod)
tvm.ir.assert_structural_equal(mod["main"], lowered_square_sum_two_K, True)

sch = tir.Schedule(mod, debug_mask="all")
i, = sch.get_loops(sch.get_block("square_sum_2"))
sch.bind(i, "threadIdx.x")

density = 0.0125
M = N1 = N2 = 128
A_J = sp.random(M, N1, dtype="float32", density=1 - (1 - density) ** N2, format="csr")
indptr_j = A_J.indptr
indices_j = A_J.indices
nnz_j = A_J.nnz
A_K = sp.random(nnz_j, N2, dtype="float32", density=density, format="csr")
indptr_k = A_K.indptr
indices_k = A_K.indices
nnz_k = A_K.nnz
data = A_K.data

b_ij = np.asarray(A_K.sum(axis=1)).squeeze()
A_J = sp.csr_matrix((b_ij, indices_j, indptr_j), shape=(M, N1))
b_ground_truth = np.asarray(A_J.sum(axis=1)).squeeze()
b = np.zeros((M,)).astype("float32")

v_nnz_j, v_nnz_k, v_M, v_N1, v_N2 = square_sum_two_K.params[-5:]
f = tvm.build(sch.mod["main"].specialize({v_nnz_j: nnz_j, v_nnz_k: nnz_k, v_M: M, v_N1: N1, v_N2: N2}), target="cuda")

ctx = tvm.device("cuda")
A_data = tvm.nd.array(data.astype("float32"), device=ctx)
A_indptr_j = tvm.nd.array(indptr_j.astype("int32"), device=ctx)
A_indices_j = tvm.nd.array(indices_j.astype("int32"), device=ctx)
A_indptr_k0 = tvm.nd.array(indptr_k.astype("int32"), device=ctx)
A_indices_k0 = tvm.nd.array(indices_k.astype("int32"), device=ctx)
A_indptr_k1 = tvm.nd.array(indptr_k.astype("int32"), device=ctx)
A_indices_k1 = tvm.nd.array(indices_k.astype("int32"), device=ctx)
B_data = tvm.nd.array(b.astype("float32"), device=ctx)
f(A_data, B_data, A_indptr_j, A_indices_j, A_indptr_k0, A_indices_k0, A_indptr_k1, A_indices_k1)

tvm.testing.assert_allclose(b_ground_truth, B_data.numpy(), rtol=1e-5, atol=1e-5)


if __name__ == "__main__":
test_csrmm()
test_csrmm_dense_iter()
Expand All @@ -610,3 +709,4 @@ def test_square_sum():
test_csr_element_wise()
test_bmm()
test_square_sum()
test_square_sum_two_K()

0 comments on commit 144bb60

Please sign in to comment.