Skip to content

Commit

Permalink
[IR][SIBuilder]
Browse files Browse the repository at this point in the history
- Add SIBuilder to handle the span propagation between passes
- Add SequentialSpan for multiple source expressions conversion between
passes
- Add test cases for SIBuilder and SequentialSpan
  • Loading branch information
Joey Tsai committed May 26, 2023
1 parent 3a15eaf commit a2325ec
Show file tree
Hide file tree
Showing 9 changed files with 965 additions and 2 deletions.
103 changes: 103 additions & 0 deletions include/tvm/ir/si_builder.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,103 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/

/*!
* \file tvm/ir/si_builder.h
* \brief build a source info during rewriting expressions.
*/
#ifndef TVM_IR_SI_BUILDER_H_
#define TVM_IR_SI_BUILDER_H_

#include <tvm/relay/expr.h>
#include <tvm/relay/expr_functor.h>
#include <tvm/tir/stmt.h>

#include <memory>
#include <unordered_set>

namespace tvm {

/*!
* \brief SIBuilder provides helper APIs for filling spans,
* particularly useful for one-to-many, many-to-one and many-to-many pass transformations.
*/
class SIBuilder {
public:
/*!
* \brief Create SIBuilder from a given span
*/
explicit SIBuilder(const Span& span = Span());

/*!
* \brief Create SIBuilder from a given span sequence
*/
explicit SIBuilder(const Array<Span>& spans = Array<Span>());
explicit SIBuilder(const std::initializer_list<Span>& init);

/*!
* \brief Create SIBuilder via a subgraph,
* Will construct span based on the exprs in the subgraph. Including the inputs exprs.
*
* \param entry Entry expr for subgraph
* \param inputs End exprs for subgraph
*/
template <typename T, typename = std::enable_if_t<std::is_base_of<BaseExpr, T>::value>>
explicit SIBuilder(const T& entry, const tvm::Array<T>& inputs = {});
explicit SIBuilder(const tir::Stmt& entry, const tvm::Array<PrimExpr>& inputs = {});
explicit SIBuilder(const tir::Stmt& entry, const tvm::Array<tir::Stmt>& inputs = {});

~SIBuilder();

SIBuilder(const SIBuilder&) = delete;
SIBuilder& operator=(const SIBuilder&) = delete;

/*!
* \brief create new source info based on the given span or subgraph.
*
* \return The given span, or reconstructed span from subgraph.
*/
Span CreateSpan() const;

/*!
* \brief Recursively fill all span of exprs in subgraph from entry until inputs.
*
* \param entry Entry expr for subgraph.
* \param inputs End exprs for subgraph, will not be filled with new span.
*/
template <typename T, typename = std::enable_if_t<std::is_base_of<BaseExpr, T>::value>>
void RecursivelyFillSpan(
const T& entry, const std::unordered_set<T, ObjectPtrHash, ObjectPtrEqual>& inputs) const;

void RecursivelyFillSpan(
const tir::Stmt& entry,
const std::unordered_set<PrimExpr, ObjectPtrHash, ObjectPtrEqual>& inputs) const;
void RecursivelyFillSpan(
const tir::Stmt& entry,
const std::unordered_set<tir::Stmt, ObjectPtrHash, ObjectPtrEqual>& inputs) const;

private:
struct Impl;
std::unique_ptr<Impl> impl_;

std::unique_ptr<Impl> CreateImpl(const Span& span);
};

} // namespace tvm

#endif // TVM_IR_SI_BUILDER_H_
46 changes: 45 additions & 1 deletion include/tvm/ir/source_map.h
Original file line number Diff line number Diff line change
Expand Up @@ -114,7 +114,7 @@ class SpanNode : public Object {
}

static constexpr const char* _type_key = "Span";
TVM_DECLARE_FINAL_OBJECT_INFO(SpanNode, Object);
TVM_DECLARE_BASE_OBJECT_INFO(SpanNode, Object);
};

class Span : public ObjectRef {
Expand All @@ -127,6 +127,50 @@ class Span : public ObjectRef {
TVM_DEFINE_OBJECT_REF_METHODS(Span, ObjectRef, SpanNode);
};

/*!
* \brief Store a list of spans for an expr generated from mulitple source exprs
*/
class SequentialSpanNode : public SpanNode {
public:
/*! \brief The original source list of spans to construct a sequential span. */
Array<Span> spans;

// override attr visitor
void VisitAttrs(AttrVisitor* v) {
SpanNode::VisitAttrs(v);
v->Visit("spans", &spans);
}

static constexpr const char* _type_key = "SequentialSpan";
TVM_DECLARE_FINAL_OBJECT_INFO(SequentialSpanNode, SpanNode);

bool SEqualReduce(const SequentialSpanNode* other, SEqualReducer equal) const {
if (spans.size() != other->spans.size()) {
return false;
}

for (size_t i = 0, e = spans.size(); i != e; ++i) {
if (!StructuralEqual()(spans[i], other->spans[i])) {
return false;
}
}
return true;
}
};

/*!
* \brief Reference class of SequentialSpanNode.
* \sa SequentialSpanNode
*/
class SequentialSpan : public Span {
public:
TVM_DLL SequentialSpan(Array<Span> spans);

TVM_DLL SequentialSpan(std::initializer_list<Span> init);

TVM_DEFINE_OBJECT_REF_METHODS(SequentialSpan, Span, SequentialSpanNode);
};

/*! \brief A program source in any language.
*
* Could represent the source from an ML framework or a source
Expand Down
1 change: 1 addition & 0 deletions python/tvm/ir/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@
Node,
SourceName,
Span,
SequentialSpan,
assert_structural_equal,
load_json,
save_json,
Expand Down
14 changes: 14 additions & 0 deletions python/tvm/ir/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -69,6 +69,20 @@ def __init__(self, source_name, line, end_line, column, end_column):
)


@register_object("SequentialSpan")
class SequentialSpan(Object):
"""Specifies a location in a source program.
Parameters
----------
spans : Array
The array of spans.
"""

def __init__(self, spans):
self.__init_handle_by_constructor__(_ffi_api.SequentialSpan, spans)


@register_object
class EnvFunc(Object):
"""Environment function.
Expand Down
1 change: 1 addition & 0 deletions python/tvm/relay/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -73,6 +73,7 @@

# Span
Span = base.Span
SequentialSpan = base.SequentialSpan
SourceName = base.SourceName

# Type
Expand Down
2 changes: 1 addition & 1 deletion python/tvm/relay/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@

import tvm._ffi
from tvm.ir import Node as RelayNode
from tvm.ir import SourceName, Span
from tvm.ir import SourceName, Span, SequentialSpan
from tvm.runtime import Object

from . import _ffi_api
Expand Down
Loading

0 comments on commit a2325ec

Please sign in to comment.