Skip to content

Commit

Permalink
Relax IR Parser (#6)
Browse files Browse the repository at this point in the history
* Copy jared's frontend

* Remove some extraneous code + add TODOs

* Skeleton AST

* Added more skeleton AST, worked on parsing shape annotations. Something is wrong with span_to_span

* Fix spans

* Type annotations parsing correctly

* some match_shape support

* More bug fixes! Some stuff parses. Importing into tests is messed up. We probably need to restructure this code as well.

* refactor parser and fill out more stubs

* some parser tests

* yolo dataflow

* checkpoint for rebase

* hook up AST

* add inline TIR parsing

* some cleanup

* support call_packed parsing to ExternFunc call

* remove stub ops

* improve docstrings

* address nits

* support coercing tuples to ShapeExpr when possible for call_dps

Co-authored-by: electriclilies <lilyorthsmith@gmail.com>
  • Loading branch information
2 people authored and junrushao committed Oct 14, 2022
1 parent 738ce46 commit fe23a23
Show file tree
Hide file tree
Showing 17 changed files with 1,535 additions and 228 deletions.
95 changes: 46 additions & 49 deletions include/tvm/relax/expr.h
Original file line number Diff line number Diff line change
Expand Up @@ -33,8 +33,8 @@ namespace relax {

using Expr = RelayExpr;
using ExprNode = RelayExprNode;
using relay::Id;
using relay::Call;
using relay::Id;
using relay::Tuple;
using relay::TupleGetItem;

Expand All @@ -53,8 +53,7 @@ class ShapeExprNode : public ExprNode {
}

bool SEqualReduce(const ShapeExprNode* other, SEqualReducer equal) const {
return equal(values, other->values) &&
equal(checked_type_, other->checked_type_) &&
return equal(values, other->values) && equal(checked_type_, other->checked_type_) &&
equal(shape_, other->shape_);
}

Expand All @@ -72,15 +71,15 @@ class ShapeExprNode : public ExprNode {

class ShapeExpr : public Expr {
public:
TVM_DLL ShapeExpr(Array<PrimExpr> values);
TVM_DLL explicit ShapeExpr(Array<PrimExpr> values, Span span = Span());
TVM_DEFINE_OBJECT_REF_METHODS(ShapeExpr, Expr, ShapeExprNode);
};


/*! \brief The variable class for all Relax bindings. */
class VarNode : public ExprNode {
public:
/*! \brief The identifier of the variable, is used for comparing stable equality across transformations. */
/*! \brief The identifier of the variable, which is used for comparing stable equality across
* transformations. */
Id vid;
/*! \brief The type annotation, used by binding sites and parameter declarations. */
runtime::Optional<Type> type_annotation;
Expand All @@ -97,11 +96,9 @@ class VarNode : public ExprNode {
}

bool SEqualReduce(const VarNode* other, SEqualReducer equal) const {
return equal(vid, other->vid) &&
equal(type_annotation, other->type_annotation) &&
return equal(vid, other->vid) && equal(type_annotation, other->type_annotation) &&
// Do we use the analysis information in equality?
equal(checked_type_, other->checked_type_) &&
equal(shape_, other->shape_);
equal(checked_type_, other->checked_type_) && equal(shape_, other->shape_);
}

void SHashReduce(SHashReducer hash_reduce) const {
Expand All @@ -120,16 +117,12 @@ class VarNode : public ExprNode {

class Var : public Expr {
public:
TVM_DLL Var(String name_hint,
runtime::Optional<Expr> shape_annotation,
runtime::Optional<Type> type_annotation,
Span span = Span())
: Var(Id(name_hint), shape_annotation, type_annotation, span) {}

TVM_DLL Var(Id vid,
runtime::Optional<Expr> shape_annotation,
runtime::Optional<Type> type_annotation,
Span span = Span());
TVM_DLL explicit Var(String name_hint, runtime::Optional<Expr> shape_annotation,
runtime::Optional<Type> type_annotation, Span span = Span())
: Var(Id(name_hint), shape_annotation, type_annotation, span) {}

TVM_DLL explicit Var(Id vid, runtime::Optional<Expr> shape_annotation,
runtime::Optional<Type> type_annotation, Span span = Span());
TVM_DEFINE_OBJECT_REF_METHODS(Var, Expr, VarNode);
};

Expand All @@ -147,10 +140,8 @@ class DataflowVarNode : public VarNode {
}

bool SEqualReduce(const DataflowVarNode* other, SEqualReducer equal) const {
return equal(vid, other->vid) &&
equal(type_annotation, other->type_annotation) &&
equal(shape_, other->shape_) &&
equal(checked_type_, other->checked_type_);
return equal(vid, other->vid) && equal(type_annotation, other->type_annotation) &&
equal(shape_, other->shape_) && equal(checked_type_, other->checked_type_);
}

void SHashReduce(SHashReducer hash_reduce) const {
Expand All @@ -168,15 +159,22 @@ class DataflowVarNode : public VarNode {

class DataflowVar : public Var {
public:
using Var::Var; // inherit constructors from Var
TVM_DLL explicit DataflowVar(String name_hint, runtime::Optional<Expr> shape_annotation,
runtime::Optional<Type> type_annotation, Span span = Span())
: DataflowVar(Id(name_hint), shape_annotation, type_annotation, span) {}

TVM_DLL explicit DataflowVar(Id vid, runtime::Optional<Expr> shape_annotation,
runtime::Optional<Type> type_annotation, Span span = Span());

TVM_DEFINE_OBJECT_REF_METHODS(DataflowVar, Var, DataflowVarNode);
};


/*! \brief The base class of a variable binding in Relax. */
class BindingNode : public Object {
public:
void VisitAttrs(AttrVisitor* v) {}
mutable Span span;

void VisitAttrs(AttrVisitor* v) { v->Visit("span", &span); }
bool SEqualReduce(const BindingNode* other, SEqualReducer equal) const { return true; }
void SHashReduce(SHashReducer hash_reduce) const {}

Expand All @@ -188,10 +186,10 @@ class BindingNode : public Object {

class Binding : public ObjectRef {
public:
TVM_DLL explicit Binding(Span span);
TVM_DEFINE_OBJECT_REF_METHODS(Binding, ObjectRef, BindingNode);
};


/*! \brief Symbolic shape match, binds the variables of the LHS with the rhs. */
class MatchShape;
class MatchShapeNode : public BindingNode {
Expand All @@ -202,6 +200,7 @@ class MatchShapeNode : public BindingNode {
void VisitAttrs(AttrVisitor* v) {
v->Visit("pattern", &pattern);
v->Visit("value", &value);
v->Visit("span", &span);
}

bool SEqualReduce(const MatchShapeNode* other, SEqualReducer equal) const {
Expand All @@ -221,7 +220,7 @@ class MatchShapeNode : public BindingNode {

class MatchShape : public Binding {
public:
TVM_DLL MatchShape(Array<PrimExpr> pattern, Expr value);
TVM_DLL explicit MatchShape(Array<PrimExpr> pattern, Expr value, Span span = Span());
TVM_DEFINE_OBJECT_REF_METHODS(MatchShape, Binding, MatchShapeNode);
};

Expand All @@ -234,6 +233,7 @@ class VarBindingNode : public BindingNode {
void VisitAttrs(AttrVisitor* v) {
v->Visit("var", &var);
v->Visit("value", &value);
v->Visit("span", &span);
}

bool SEqualReduce(const VarBindingNode* other, SEqualReducer equal) const {
Expand All @@ -251,23 +251,28 @@ class VarBindingNode : public BindingNode {

class VarBinding : public Binding {
public:
TVM_DLL VarBinding(Var var, Expr value);
TVM_DLL explicit VarBinding(Var var, Expr value, Span span = Span());
TVM_DEFINE_OBJECT_REF_METHODS(VarBinding, Binding, VarBindingNode);
};


class BindingBlock;

class BindingBlockNode : public Object {
public:
mutable Span span;
Array<Binding> bindings;

void VisitAttrs(AttrVisitor* v) {
v->Visit("span", &span);
v->Visit("bindings", &bindings);
}

bool SEqualReduce(const BindingBlockNode* other, SEqualReducer equal) const {
return equal(bindings, other->bindings);
}

void SHashReduce(SHashReducer hash_reduce) const { hash_reduce(bindings); }

static constexpr const char* _type_key = "relax.expr.BindingBlock";
static constexpr const bool _type_has_method_sequal_reduce = true;
static constexpr const bool _type_has_method_shash_reduce = true;
Expand All @@ -276,21 +281,17 @@ class BindingBlockNode : public Object {

class BindingBlock : public ObjectRef {
public:
TVM_DLL BindingBlock(Array<Binding> bindings);
TVM_DLL explicit BindingBlock(Array<Binding> bindings, Span span = Span());
TVM_DEFINE_OBJECT_REF_METHODS(BindingBlock, ObjectRef, BindingBlockNode);
};


class DataflowBlock;
class DataflowBlockNode : public BindingBlockNode {
public:
void VisitAttrs(AttrVisitor* v) {
v->Visit("bindings", &bindings);
}
bool SEqualReduce(const DataflowBlockNode* other, SEqualReducer equal) const {
return equal(bindings, other->bindings);
}
void SHashReduce(SHashReducer hash_reduce) const { hash_reduce(bindings); }

static constexpr const char* _type_key = "relax.expr.DataflowBlock";
static constexpr const bool _type_has_method_sequal_reduce = true;
static constexpr const bool _type_has_method_shash_reduce = true;
Expand All @@ -299,7 +300,7 @@ class DataflowBlockNode : public BindingBlockNode {

class DataflowBlock : public BindingBlock {
public:
TVM_DLL DataflowBlock(Array<Binding> bindings);
TVM_DLL explicit DataflowBlock(Array<Binding> bindings, Span span = Span());
TVM_DEFINE_OBJECT_REF_METHODS(DataflowBlock, BindingBlock, DataflowBlockNode);
};

Expand Down Expand Up @@ -340,11 +341,10 @@ class SeqExprNode : public ExprNode {

class SeqExpr : public Expr {
public:
TVM_DLL SeqExpr(Array<BindingBlock> blocks, Expr body);
TVM_DLL explicit SeqExpr(Array<BindingBlock> blocks, Expr body, Span span = Span());
TVM_DEFINE_OBJECT_REF_METHODS(SeqExpr, Expr, SeqExprNode);
};


/*! \brief A Relax function, eventually to replace the current Relay function definition. */
class FunctionNode : public BaseFuncNode {
public:
Expand Down Expand Up @@ -372,8 +372,7 @@ class FunctionNode : public BaseFuncNode {

bool SEqualReduce(const FunctionNode* other, SEqualReducer equal) const {
equal->MarkGraphNode();
return equal.DefEqual(params, other->params) &&
equal(body, other->body) &&
return equal.DefEqual(params, other->params) && equal(body, other->body) &&
equal(ret_type, other->ret_type) && equal(checked_type_, other->checked_type_) &&
equal(shape_, other->shape_);
}
Expand All @@ -396,12 +395,11 @@ class FunctionNode : public BaseFuncNode {

class Function : public Expr {
public:
TVM_DLL Function(runtime::Optional<GlobalVar> name, Array<Var> params,
Expr body, Type ret_type);
TVM_DLL explicit Function(runtime::Optional<GlobalVar> name, Array<Var> params, Expr body,
Type ret_type, Span span = Span());
TVM_DEFINE_OBJECT_REF_METHODS(Function, Expr, FunctionNode);
};


/*! \brief The extern function, which can represent packed function. */
class ExternFuncNode : public BaseFuncNode {
public:
Expand All @@ -410,15 +408,14 @@ class ExternFuncNode : public BaseFuncNode {

void VisitAttrs(AttrVisitor* v) {
v->Visit("global_symbol", &global_symbol);
v->Visit("span", &span);
}

bool SEqualReduce(const ExternFuncNode* other, SEqualReducer equal) const {
return equal(global_symbol, other->global_symbol);
}

void SHashReduce(SHashReducer hash_reduce) const {
hash_reduce(global_symbol);
}
void SHashReduce(SHashReducer hash_reduce) const { hash_reduce(global_symbol); }

static constexpr const char* _type_key = "relax.expr.ExternFunc";
static constexpr const bool _type_has_method_sequal_reduce = true;
Expand All @@ -428,7 +425,7 @@ class ExternFuncNode : public BaseFuncNode {

class ExternFunc : public Expr {
public:
TVM_DLL ExternFunc(String global_symbol);
TVM_DLL ExternFunc(String global_symbol, Span span = Span());
TVM_DEFINE_OBJECT_REF_METHODS(ExternFunc, Expr, ExternFuncNode);
};

Expand Down
43 changes: 31 additions & 12 deletions include/tvm/relax/type.h
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,10 @@ namespace relax {

class ShapeTypeNode : public TypeNode {
public:
void VisitAttrs(tvm::AttrVisitor* v) {}

void VisitAttrs(tvm::AttrVisitor* v) {
v->Visit("span", &span);
}

bool SEqualReduce(const ShapeTypeNode* other, SEqualReducer equal) const {
return true;
Expand All @@ -53,16 +56,9 @@ class ShapeTypeNode : public TypeNode {

class ShapeType : public Type {
public:
explicit ShapeType();
explicit ShapeType(runtime::ObjectPtr<runtime::Object> n) : Type(n) {}
TVM_DEFINE_DEFAULT_COPY_MOVE_AND_ASSIGN(ShapeType);
const ShapeTypeNode* operator->() const {
return static_cast<const ShapeTypeNode*>(data_.get());
}
const ShapeTypeNode* get() const {
return operator->();
}
using ContainerType = ShapeTypeNode;
TVM_DLL ShapeType(Span span = Span());

TVM_DEFINE_OBJECT_REF_METHODS(ShapeType, Type, ShapeTypeNode);
};

class DynTensorTypeNode : public BaseTensorTypeNode {
Expand Down Expand Up @@ -108,11 +104,34 @@ class DynTensorType : public Type {
* \param shape The shape of the tensor.
* \param dtype The runtime dtype of the tensor's elements.
*/
TVM_DLL DynTensorType(int rank, DataType dtype);
TVM_DLL DynTensorType(int rank, DataType dtype, Span span = Span());

TVM_DEFINE_OBJECT_REF_METHODS(DynTensorType, Type, DynTensorTypeNode);
};

class DimTypeNode : public TypeNode {
public:
void VisitAttrs(tvm::AttrVisitor* v) {
v->Visit("span", &span);
}

bool SEqualReduce(const DimTypeNode* other, SEqualReducer equal) const {
return true;
}

void SHashReduce(SHashReducer hash_reduce) const { hash_reduce(0); }

static constexpr const char* _type_key = "relax.DimType";
TVM_DECLARE_FINAL_OBJECT_INFO(DimTypeNode, TypeNode);
};

class DimType : public Type {
public:
TVM_DLL DimType(Span span = Span());

TVM_DEFINE_OBJECT_REF_METHODS(DimType, Type, DimTypeNode);
};

} // namespace relax
} // namespace tvm
#endif // TVM_RELAX_TYPE_H_
1 change: 1 addition & 0 deletions include/tvm/relay/expr.h
Original file line number Diff line number Diff line change
Expand Up @@ -131,6 +131,7 @@ class TupleNode : public ExprNode {
v->Visit("virtual_device_", &virtual_device_);
v->Visit("span", &span);
v->Visit("_checked_type_", &checked_type_);
v->Visit("shape_", &shape_);
}

bool SEqualReduce(const TupleNode* other, SEqualReducer equal) const {
Expand Down
5 changes: 3 additions & 2 deletions python/tvm/ir/type.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
import tvm
import tvm._ffi

from . import Span
from .base import Node
from . import _ffi_api

Expand Down Expand Up @@ -166,8 +167,8 @@ class TupleType(Type):
The fields in the tuple
"""

def __init__(self, fields):
self.__init_handle_by_constructor__(_ffi_api.TupleType, fields)
def __init__(self, fields, span: Span = None):
self.__init_handle_by_constructor__(_ffi_api.TupleType, fields, span)


@tvm._ffi.register_object("TypeConstraint")
Expand Down
Loading

0 comments on commit fe23a23

Please sign in to comment.