Skip to content

Commit

Permalink
[TIR][REFACTOR][API-Change] Migrate the tvm/tir/expr.h to construct s…
Browse files Browse the repository at this point in the history
…tyle. (#5773)

This PR migrate tvm/tir/expr.h to the new constructor style that is
consistent with the rest of the codebase and changes the affected files accordingly.
  • Loading branch information
tqchen authored Jun 11, 2020
1 parent 0abcad1 commit eafb2aa
Show file tree
Hide file tree
Showing 98 changed files with 1,808 additions and 1,530 deletions.
2 changes: 1 addition & 1 deletion docs/dev/relay_add_op.rst
Original file line number Diff line number Diff line change
Expand Up @@ -99,7 +99,7 @@ the arguments to the call node, as below.
TVM_REGISTER_GLOBAL("relay.op._make.add")
.set_body_typed<Expr(Expr, Expr)>([](Expr lhs, Expr rhs) {
static const Op& op = Op::Get("add");
return CallNode::make(op, {lhs, rhs}, Attrs(), {});
return Call(op, {lhs, rhs}, Attrs(), {});
});
Including a Python API Hook
Expand Down
2 changes: 1 addition & 1 deletion docs/dev/relay_add_pass.rst
Original file line number Diff line number Diff line change
Expand Up @@ -261,7 +261,7 @@ the pass.
body.same_as(op->body)) {
return GetRef<Expr>(op);
} else {
return LetNode::make(var, value, body);
return Let(var, value, body);
}
}
}
Expand Down
2 changes: 1 addition & 1 deletion docs/dev/relay_pass_infra.rst
Original file line number Diff line number Diff line change
Expand Up @@ -349,7 +349,7 @@ registration.
auto f = relay::FunctionNode::make(tvm::Array<relay::Var>{ x }, x, relay::Type(), {});

auto y = relay::VarNode::make("y", tensor_type);
auto call = relay::CallNode::make(f, tvm::Array<relay::Expr>{ y });
auto call = relay::Call(f, tvm::Array<relay::Expr>{ y });
auto fx = relay::FunctionNode::make(tvm::Array<relay::Var>{ y }, call, relay::Type(), {});

// Create a module for optimization.
Expand Down
3 changes: 2 additions & 1 deletion include/tvm/relay/type.h
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,8 @@ namespace relay {

// namespace update for backward compact
// will be removed later.
using Any = tvm::tir::AnyNode;
using AnyNode = tvm::tir::AnyNode;
using Any = tvm::tir::Any;
using Kind = TypeKind;
using Type = tvm::Type;
using TypeNode = tvm::TypeNode;
Expand Down
4 changes: 1 addition & 3 deletions include/tvm/runtime/container.h
Original file line number Diff line number Diff line change
Expand Up @@ -1337,9 +1337,6 @@ class String : public ObjectRef {
#endif
}

/*! \return the internal StringObj pointer */
const StringObj* get() const { return operator->(); }

TVM_DEFINE_NOTNULLABLE_OBJECT_REF_METHODS(String, ObjectRef, StringObj);

private:
Expand Down Expand Up @@ -1502,6 +1499,7 @@ class Optional : public ObjectRef {
* otherwise return the default_value.
*/
T value_or(T default_value) const { return data_ != nullptr ? T(data_) : default_value; }

/*! \return Whether the container is not nullptr.*/
explicit operator bool() const { return *this != nullptr; }
// operator overloadings
Expand Down
2 changes: 2 additions & 0 deletions include/tvm/runtime/object.h
Original file line number Diff line number Diff line change
Expand Up @@ -700,6 +700,7 @@ struct ObjectPtrEqual {
explicit TypeName(::tvm::runtime::ObjectPtr<::tvm::runtime::Object> n) : ParentType(n) {} \
TVM_DEFINE_DEFAULT_COPY_MOVE_AND_ASSIGN(TypeName); \
const ObjectName* operator->() const { return static_cast<const ObjectName*>(data_.get()); } \
const ObjectName* get() const { return operator->(); } \
using ContainerType = ObjectName;

/*
Expand All @@ -713,6 +714,7 @@ struct ObjectPtrEqual {
explicit TypeName(::tvm::runtime::ObjectPtr<::tvm::runtime::Object> n) : ParentType(n) {} \
TVM_DEFINE_DEFAULT_COPY_MOVE_AND_ASSIGN(TypeName); \
const ObjectName* operator->() const { return static_cast<const ObjectName*>(data_.get()); } \
const ObjectName* get() const { return operator->(); } \
static constexpr bool _type_is_nullable = false; \
using ContainerType = ObjectName;

Expand Down
34 changes: 11 additions & 23 deletions include/tvm/te/tensor_intrin.h
Original file line number Diff line number Diff line change
Expand Up @@ -112,24 +112,6 @@ inline const TensorIntrinNode* TensorIntrin::operator->() const {
return static_cast<const TensorIntrinNode*>(get());
}

// Internal node container of tensor intrinsic calling.
class TensorIntrinCallNode;

/*! \brief Tensor intrinsic calling node. */
class TensorIntrinCall : public ObjectRef {
public:
TensorIntrinCall() {}
explicit TensorIntrinCall(ObjectPtr<Object> n) : ObjectRef(n) {}
/*!
* \brief access the internal node container
* \return the pointer to the internal node container
*/
inline const TensorIntrinCallNode* operator->() const;

/*! \brief specify container node */
using ContainerType = TensorIntrinCallNode;
};

class TensorIntrinCallNode : public Object {
public:
/*! \brief the tensor intrinsic */
Expand All @@ -155,16 +137,22 @@ class TensorIntrinCallNode : public Object {
v->Visit("reduce_axis", &reduce_axis);
v->Visit("scalar_inputs", &scalar_inputs);
}
static TensorIntrinCall make(TensorIntrin intrin, Array<Tensor> tensors, Array<Region> regions,
Array<IterVar> reduce_axis, Array<PrimExpr> scalar_inputs);

static constexpr const char* _type_key = "TensorIntrinCall";
TVM_DECLARE_FINAL_OBJECT_INFO(TensorIntrinCallNode, Object);
};

inline const TensorIntrinCallNode* TensorIntrinCall::operator->() const {
return static_cast<const TensorIntrinCallNode*>(get());
}
/*!
* \brief Managed reference to TensorIntrinCallNode
* \sa TensorIntrinCallNode
*/
class TensorIntrinCall : public ObjectRef {
public:
TVM_DLL TensorIntrinCall(TensorIntrin intrin, Array<Tensor> tensors, Array<Region> regions,
Array<IterVar> reduce_axis, Array<PrimExpr> scalar_inputs);

TVM_DEFINE_OBJECT_REF_METHODS(TensorIntrinCall, ObjectRef, TensorIntrinCallNode);
};

} // namespace te
} // namespace tvm
Expand Down
Loading

0 comments on commit eafb2aa

Please sign in to comment.