Skip to content

Commit

Permalink
[IR][SIBuilder]
Browse files Browse the repository at this point in the history
- Make null implementation as base class
- Add comments and change naming based on reviewing
  • Loading branch information
Joey Tsai committed May 26, 2023
1 parent a2325ec commit 13a3bc7
Show file tree
Hide file tree
Showing 4 changed files with 71 additions and 84 deletions.
10 changes: 5 additions & 5 deletions include/tvm/ir/si_builder.h
Original file line number Diff line number Diff line change
Expand Up @@ -34,8 +34,8 @@
namespace tvm {

/*!
* \brief SIBuilder provides helper APIs for filling spans,
* particularly useful for one-to-many, many-to-one and many-to-many pass transformations.
* \brief Source Information Builder, SIBuilder provides helper APIs for filling spans,
* particularly useful for one-to-many, many-to-one and many-to-many IR transformations.
*/
class SIBuilder {
public:
Expand Down Expand Up @@ -68,11 +68,11 @@ class SIBuilder {
SIBuilder& operator=(const SIBuilder&) = delete;

/*!
* \brief create new source info based on the given span or subgraph.
* \brief build a span of source information, which is based on the given span or subgraph.
*
* \return The given span, or reconstructed span from subgraph.
* \return the built span
*/
Span CreateSpan() const;
Span Build() const;

/*!
* \brief Recursively fill all span of exprs in subgraph from entry until inputs.
Expand Down
5 changes: 4 additions & 1 deletion python/tvm/ir/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -71,7 +71,10 @@ def __init__(self, source_name, line, end_line, column, end_column):

@register_object("SequentialSpan")
class SequentialSpan(Object):
"""Specifies a location in a source program.
"""A sequence of source spans
This span is specific for an expression, which is from multiple expressions
after an IR transform.
Parameters
----------
Expand Down
114 changes: 49 additions & 65 deletions src/ir/si_builder.cc
Original file line number Diff line number Diff line change
Expand Up @@ -33,11 +33,12 @@ using RelayExprSet = std::unordered_set<relay::Expr, ObjectPtrHash, ObjectPtrEqu
using PrimExprSet = std::unordered_set<PrimExpr, ObjectPtrHash, ObjectPtrEqual>;
using StmtSet = std::unordered_set<tir::Stmt, ObjectPtrHash, ObjectPtrEqual>;

class RelayCollapse : public relay::ExprVisitor {
class RelayCollectSpans : public relay::ExprVisitor {
public:
explicit RelayCollapse(const RelayExprSet& inputs = {}) : inputs_(inputs) {}
explicit RelayCollectSpans(const RelayExprSet& inputs = {}) : inputs_(inputs) {}

Span Collapse(const relay::Expr& entry);
// From entry to inputs, recursively collect spans. The spans of inputs are included.
Span CollectSpans(const relay::Expr& entry);

void VisitExpr(const relay::Expr& expr) final;

Expand All @@ -46,7 +47,7 @@ class RelayCollapse : public relay::ExprVisitor {
const RelayExprSet& inputs_;
};

void RelayCollapse::VisitExpr(const relay::Expr& expr) {
void RelayCollectSpans::VisitExpr(const relay::Expr& expr) {
if (visit_counter_.count(expr.get())) {
return;
}
Expand All @@ -61,7 +62,7 @@ void RelayCollapse::VisitExpr(const relay::Expr& expr) {
relay::ExprVisitor::VisitExpr(expr);
}

Span RelayCollapse::Collapse(const relay::Expr& entry) {
Span RelayCollectSpans::CollectSpans(const relay::Expr& entry) {
VisitExpr(entry);
return SequentialSpan(spans_);
}
Expand All @@ -71,6 +72,7 @@ class RelayRecursivelyFill : public relay::ExprMutator {
explicit RelayRecursivelyFill(const Span& span, const RelayExprSet& inputs = {})
: span_(span), inputs_(inputs) {}

// From entry until inputs, recursively fill spans into expressions. Inputs are not filled.
void Fill(const relay::Expr& entry);

relay::Expr VisitExpr(const relay::Expr& expr) final;
Expand All @@ -94,9 +96,9 @@ relay::Expr RelayRecursivelyFill::VisitExpr(const relay::Expr& expr) {

void RelayRecursivelyFill::Fill(const relay::Expr& entry) { Mutate(entry); }

class TirCollapse : public tir::StmtExprVisitor {
class TirCollectSpans : public tir::StmtExprVisitor {
public:
explicit TirCollapse(const PrimExprSet& expr_inputs = {}, const StmtSet& stmt_inputs = {})
explicit TirCollectSpans(const PrimExprSet& expr_inputs = {}, const StmtSet& stmt_inputs = {})
: expr_inputs_(expr_inputs), stmt_inputs_(stmt_inputs) {}

void VisitExpr(const PrimExpr& expr) final;
Expand All @@ -105,8 +107,10 @@ class TirCollapse : public tir::StmtExprVisitor {
bool IsInput(const PrimExpr& expr);
bool IsInput(const tir::Stmt& stmt);

Span Collapse(const PrimExpr& expr);
Span Collapse(const tir::Stmt& stmt);
// From entry to inputs, recursively collect spans. The spans of inputs are included.
Span CollectSpans(const PrimExpr& expr);
// From entry to inputs, recursively collect spans. The spans of inputs are included.
Span CollectSpans(const tir::Stmt& stmt);

private:
Array<Span> spans_;
Expand All @@ -115,25 +119,25 @@ class TirCollapse : public tir::StmtExprVisitor {
const StmtSet& stmt_inputs_;
};

Span TirCollapse::Collapse(const PrimExpr& expr) {
Span TirCollectSpans::CollectSpans(const PrimExpr& expr) {
operator()(expr);
return SequentialSpan(spans_);
}

Span TirCollapse::Collapse(const tir::Stmt& stmt) {
Span TirCollectSpans::CollectSpans(const tir::Stmt& stmt) {
operator()(stmt);
return SequentialSpan(spans_);
}

bool TirCollapse::IsInput(const PrimExpr& expr) {
bool TirCollectSpans::IsInput(const PrimExpr& expr) {
return expr_inputs_.find(expr) != expr_inputs_.end();
}

bool TirCollapse::IsInput(const tir::Stmt& stmt) {
bool TirCollectSpans::IsInput(const tir::Stmt& stmt) {
return stmt_inputs_.find(stmt) != stmt_inputs_.end();
}

void TirCollapse::VisitExpr(const PrimExpr& expr) {
void TirCollectSpans::VisitExpr(const PrimExpr& expr) {
if (visit_counter_.count(expr.get())) {
return;
}
Expand All @@ -148,7 +152,7 @@ void TirCollapse::VisitExpr(const PrimExpr& expr) {
StmtExprVisitor::VisitExpr(expr);
}

void TirCollapse::VisitStmt(const tir::Stmt& stmt) {
void TirCollectSpans::VisitStmt(const tir::Stmt& stmt) {
if (visit_counter_.count(stmt.get())) {
return;
}
Expand All @@ -169,7 +173,9 @@ class TirRecursivelyFill : public tir::StmtExprMutator {
const StmtSet& stmt_inputs = {})
: span_(span), expr_inputs_(expr_inputs), stmt_inputs_(stmt_inputs) {}

// From entry until inputs, recursively fill spans into expressions. Inputs are not filled.
tir::Stmt Fill(const tir::Stmt& s) { return operator()(s); }
// From entry until inputs, recursively fill spans into expressions. Inputs are not filled.
PrimExpr Fill(const PrimExpr& e) { return operator()(e); }

bool IsInput(const PrimExpr& expr);
Expand Down Expand Up @@ -209,20 +215,20 @@ PrimExpr TirRecursivelyFill::VisitExpr(const PrimExpr& expr) {
}

struct SIBuilder::Impl {
virtual Span CreateSpan() const = 0;
virtual void RecursivelyFillSpan(const relay::Expr& entry, const RelayExprSet& inputs) const = 0;
virtual void RecursivelyFillSpan(const PrimExpr& entry, const PrimExprSet& inputs) const = 0;
virtual void RecursivelyFillSpan(const tir::Stmt& entry, const PrimExprSet& inputs) const = 0;
virtual void RecursivelyFillSpan(const tir::Stmt& entry, const StmtSet& inputs) const = 0;
virtual void CollapseSpan(const relay::Expr& entry, const RelayExprSet& inputs) = 0;
virtual void CollapseSpan(const PrimExpr& entry, const PrimExprSet& inputs) = 0;
virtual void CollapseSpan(const tir::Stmt& entry, const PrimExprSet& inputs) = 0;
virtual void CollapseSpan(const tir::Stmt& entry, const StmtSet& inputs) = 0;
virtual Span Build() const { return Span(); }
virtual void RecursivelyFillSpan(const relay::Expr& entry, const RelayExprSet& inputs) const {}
virtual void RecursivelyFillSpan(const PrimExpr& entry, const PrimExprSet& inputs) const {}
virtual void RecursivelyFillSpan(const tir::Stmt& entry, const PrimExprSet& inputs) const {}
virtual void RecursivelyFillSpan(const tir::Stmt& entry, const StmtSet& inputs) const {}
virtual void CollectSpansSpan(const relay::Expr& entry, const RelayExprSet& inputs) {}
virtual void CollectSpansSpan(const PrimExpr& entry, const PrimExprSet& inputs) {}
virtual void CollectSpansSpan(const tir::Stmt& entry, const PrimExprSet& inputs) {}
virtual void CollectSpansSpan(const tir::Stmt& entry, const StmtSet& inputs) {}
};

SIBuilder::~SIBuilder() = default;

Span SIBuilder::CreateSpan() const { return impl_->CreateSpan(); }
Span SIBuilder::Build() const { return impl_->Build(); }

template <>
void SIBuilder::RecursivelyFillSpan(const relay::Expr& entry, const RelayExprSet& inputs) const {
Expand All @@ -243,54 +249,32 @@ void SIBuilder::RecursivelyFillSpan(const tir::Stmt& entry, const StmtSet& input
}

std::unique_ptr<SIBuilder::Impl> SIBuilder::CreateImpl(const Span& span) {
struct NullImpl : public SIBuilder::Impl {
Span CreateSpan() const final { return Span(); }

void RecursivelyFillSpan(const relay::Expr& entry, const RelayExprSet& inputs) const final{};
void RecursivelyFillSpan(const PrimExpr& entry, const PrimExprSet& inputs) const final{};
void RecursivelyFillSpan(const tir::Stmt& entry, const PrimExprSet& inputs) const final{};
void RecursivelyFillSpan(const tir::Stmt& entry, const StmtSet& inputs) const final{};
void CollapseSpan(const relay::Expr& entry, const RelayExprSet& inputs) final{};
void CollapseSpan(const PrimExpr& entry, const PrimExprSet& inputs) final{};
void CollapseSpan(const tir::Stmt& entry, const PrimExprSet& inputs) final{};
void CollapseSpan(const tir::Stmt& entry, const StmtSet& inputs) final{};
};

struct Impl : public SIBuilder::Impl {
explicit Impl(const Span& span) : span_(span) {}

Span CreateSpan() const final { return span_; }

Span Build() const final { return span_; }
void RecursivelyFillSpan(const relay::Expr& entry, const RelayExprSet& inputs) const final {
RelayRecursivelyFill(CreateSpan(), inputs).Fill(entry);
RelayRecursivelyFill(Build(), inputs).Fill(entry);
}

void RecursivelyFillSpan(const PrimExpr& entry, const PrimExprSet& inputs) const final {
TirRecursivelyFill(CreateSpan(), inputs).Fill(entry);
TirRecursivelyFill(Build(), inputs).Fill(entry);
}

void RecursivelyFillSpan(const tir::Stmt& entry, const PrimExprSet& inputs) const final {
TirRecursivelyFill(CreateSpan(), inputs).Fill(entry);
TirRecursivelyFill(Build(), inputs).Fill(entry);
}

void RecursivelyFillSpan(const tir::Stmt& entry, const StmtSet& inputs) const final {
TirRecursivelyFill(CreateSpan(), {}, inputs).Fill(entry);
TirRecursivelyFill(Build(), {}, inputs).Fill(entry);
}

void CollapseSpan(const relay::Expr& entry, const RelayExprSet& inputs) final {
span_ = RelayCollapse(inputs).Collapse(entry);
void CollectSpansSpan(const relay::Expr& entry, const RelayExprSet& inputs) final {
span_ = RelayCollectSpans(inputs).CollectSpans(entry);
}

void CollapseSpan(const PrimExpr& entry, const PrimExprSet& inputs) final {
span_ = TirCollapse(inputs).Collapse(entry);
void CollectSpansSpan(const PrimExpr& entry, const PrimExprSet& inputs) final {
span_ = TirCollectSpans(inputs).CollectSpans(entry);
}

void CollapseSpan(const tir::Stmt& entry, const PrimExprSet& inputs) final {
span_ = TirCollapse(inputs).Collapse(entry);
void CollectSpansSpan(const tir::Stmt& entry, const PrimExprSet& inputs) final {
span_ = TirCollectSpans(inputs).CollectSpans(entry);
}

void CollapseSpan(const tir::Stmt& entry, const StmtSet& inputs) final {
span_ = TirCollapse({}, inputs).Collapse(entry);
void CollectSpansSpan(const tir::Stmt& entry, const StmtSet& inputs) final {
span_ = TirCollectSpans({}, inputs).CollectSpans(entry);
}

private:
Expand All @@ -305,7 +289,7 @@ std::unique_ptr<SIBuilder::Impl> SIBuilder::CreateImpl(const Span& span) {
return std::make_unique<Impl>(span);
}

return std::make_unique<NullImpl>();
return std::make_unique<SIBuilder::Impl>();
}

SIBuilder::SIBuilder(const Span& span) : impl_(CreateImpl(span)) {}
Expand All @@ -316,23 +300,23 @@ SIBuilder::SIBuilder(const std::initializer_list<Span>& init)
template <>
SIBuilder::SIBuilder(const relay::Expr& expr, const Array<relay::Expr>& inputs)
: impl_(CreateImpl(Span())) {
impl_->CollapseSpan(expr, RelayExprSet(inputs.begin(), inputs.end()));
impl_->CollectSpansSpan(expr, RelayExprSet(inputs.begin(), inputs.end()));
}

template <>
SIBuilder::SIBuilder(const PrimExpr& expr, const Array<PrimExpr>& inputs)
: impl_(CreateImpl(Span())) {
impl_->CollapseSpan(expr, PrimExprSet(inputs.begin(), inputs.end()));
impl_->CollectSpansSpan(expr, PrimExprSet(inputs.begin(), inputs.end()));
}

SIBuilder::SIBuilder(const tir::Stmt& s, const Array<PrimExpr>& inputs)
: impl_(CreateImpl(Span())) {
impl_->CollapseSpan(s, PrimExprSet(inputs.begin(), inputs.end()));
impl_->CollectSpansSpan(s, PrimExprSet(inputs.begin(), inputs.end()));
}

SIBuilder::SIBuilder(const tir::Stmt& s, const Array<tir::Stmt>& inputs)
: impl_(CreateImpl(Span())) {
impl_->CollapseSpan(s, StmtSet(inputs.begin(), inputs.end()));
impl_->CollectSpansSpan(s, StmtSet(inputs.begin(), inputs.end()));
}

// Register build pipeline related options
Expand Down
26 changes: 13 additions & 13 deletions tests/cpp/si_builder_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -103,7 +103,7 @@ TEST(SIBuilder, CreateSapn) {
Span span_1 = _CreateSpan("first");
{
SIBuilder si_builder(span_1);
EXPECT_EQ(span_1, si_builder.CreateSpan());
EXPECT_EQ(span_1, si_builder.Build());
}

Span span_2 = _CreateSpan("second");
Expand All @@ -114,9 +114,9 @@ TEST(SIBuilder, CreateSapn) {
SIBuilder si_builder_2({span_1, span_2});
SIBuilder si_builder_3{span_1, span_2};

Span created_span_1 = si_builder_1.CreateSpan();
Span created_span_2 = si_builder_2.CreateSpan();
Span created_span_3 = si_builder_3.CreateSpan();
Span created_span_1 = si_builder_1.Build();
Span created_span_2 = si_builder_2.Build();
Span created_span_3 = si_builder_3.Build();

auto created_seq_span_1 = created_span_1.as<SequentialSpanNode>();
auto created_seq_span_2 = created_span_2.as<SequentialSpanNode>();
Expand All @@ -140,7 +140,7 @@ TEST(SIBuilder, DisableSIBuilder) {
Span span_1 = _CreateSpan("first");
{
SIBuilder si_builder(span_1);
EXPECT_NE(span_1, si_builder.CreateSpan());
EXPECT_NE(span_1, si_builder.Build());
}
}

Expand Down Expand Up @@ -179,7 +179,7 @@ TEST(SIBuilder, RelayRecursivelyFill) {
checker.Check(z, expected_z);
}

TEST(SIBuilder, RelayCollapse) {
TEST(SIBuilder, RelayCollectSpans) {
using namespace tvm;
auto pass_ctx = transform::PassContext::Create();
pass_ctx->config.Set("ir.enable_si_builder", Bool(true));
Expand All @@ -206,15 +206,15 @@ TEST(SIBuilder, RelayCollapse) {
relay::Expr z = relay::Call(add_op, {y, x}, tvm::Attrs(), {}, z_node_span);

SIBuilder si_builder(z, {a});
Span created_span = si_builder.CreateSpan();
Span created_span = si_builder.Build();
auto created_seq_span = created_span.as<SequentialSpanNode>();
EXPECT_EQ(created_seq_span->spans.size(), 4);
for (std::size_t i = 0; i != created_seq_span->spans.size(); i++) {
EXPECT_TRUE(StructuralEqual()(created_seq_span->spans[i], target[i]));
}
}

TEST(SIBuilder, TirCollapsePrimExpr) {
TEST(SIBuilder, TirCollectSpansPrimExpr) {
using namespace tvm;
auto pass_ctx = transform::PassContext::Create();
pass_ctx->config.Set("ir.enable_si_builder", Bool(true));
Expand All @@ -241,7 +241,7 @@ TEST(SIBuilder, TirCollapsePrimExpr) {
z->span = z_node_span;

SIBuilder si_builder(z, {x});
Span created_span = si_builder.CreateSpan();
Span created_span = si_builder.Build();
auto created_seq_span = created_span.as<SequentialSpanNode>();

EXPECT_EQ(created_seq_span->spans.size(), 4);
Expand All @@ -250,7 +250,7 @@ TEST(SIBuilder, TirCollapsePrimExpr) {
}
}

TEST(SIBuilder, TirCollapseStmtWithPrimInput) {
TEST(SIBuilder, TirCollectSpansStmtWithPrimInput) {
using namespace tvm;
auto pass_ctx = transform::PassContext::Create();
pass_ctx->config.Set("ir.enable_si_builder", Bool(true));
Expand All @@ -274,7 +274,7 @@ TEST(SIBuilder, TirCollapseStmtWithPrimInput) {
auto stmt = fmaketest();
stmt->span = stmt_node_span;
SIBuilder si_builder(stmt, {x});
Span created_span = si_builder.CreateSpan();
Span created_span = si_builder.Build();
auto created_seq_span = created_span.as<SequentialSpanNode>();

EXPECT_EQ(created_seq_span->spans.size(), 3);
Expand All @@ -283,7 +283,7 @@ TEST(SIBuilder, TirCollapseStmtWithPrimInput) {
}
}

TEST(SIBuilder, TirCollapseStmtWithStmtInput) {
TEST(SIBuilder, TirCollectSpansStmtWithStmtInput) {
using namespace tvm;
auto pass_ctx = transform::PassContext::Create();
pass_ctx->config.Set("ir.enable_si_builder", Bool(true));
Expand All @@ -300,7 +300,7 @@ TEST(SIBuilder, TirCollapseStmtWithStmtInput) {
tir::Block block({}, {}, {}, "block", body, init, Array<tir::Buffer>(),
Array<tir::MatchBufferRegion>(), Map<String, ObjectRef>(), block_node_span);
SIBuilder si_builder(block, {init});
Span created_span = si_builder.CreateSpan();
Span created_span = si_builder.Build();
auto created_seq_span = created_span.as<SequentialSpanNode>();

EXPECT_EQ(created_seq_span->spans.size(), 3);
Expand Down

0 comments on commit 13a3bc7

Please sign in to comment.