From d5df43067925a34f7f1b19f00664c29b0b3d7557 Mon Sep 17 00:00:00 2001 From: Tianqi Chen Date: Sat, 13 Jun 2020 09:09:00 -0700 Subject: [PATCH] [TIR][REFACTIR] Update TIR nodes std::string->String. (#5793) This PR updates the remaining TIR node's member to use String instead of std::string. --- include/tvm/ir/adt.h | 2 +- include/tvm/ir/attrs.h | 8 +++--- include/tvm/ir/env_func.h | 2 +- include/tvm/ir/op.h | 6 ++--- include/tvm/ir/transform.h | 6 ++--- include/tvm/node/reflection.h | 2 +- include/tvm/relay/op_strategy.h | 2 +- include/tvm/relay/transform.h | 2 +- include/tvm/runtime/container.h | 30 ++++++++++++++++++++--- include/tvm/tir/buffer.h | 8 +++--- include/tvm/tir/data_layout.h | 2 +- include/tvm/tir/expr.h | 8 +++--- include/tvm/tir/stmt.h | 4 +-- include/tvm/tir/transform.h | 10 ++++---- include/tvm/tir/var.h | 4 +-- python/tvm/ir/json_compact.py | 6 +++++ src/ir/op.cc | 5 ++-- src/ir/transform.cc | 5 ++-- src/node/reflection.cc | 6 ++--- src/relay/ir/transform.cc | 2 +- src/support/ffi_testing.cc | 2 +- src/target/llvm/codegen_cpu.cc | 6 +++-- src/target/llvm/codegen_llvm.cc | 3 ++- src/target/llvm/llvm_common.h | 7 ++++++ src/tir/ir/buffer.cc | 6 ++--- src/tir/ir/expr.cc | 14 +++++------ src/tir/ir/stmt.cc | 4 +-- src/tir/ir/transform.cc | 2 +- src/tir/transforms/inject_copy_intrin.cc | 2 +- src/tir/transforms/lift_attr_scope.cc | 2 +- src/tir/transforms/thread_storage_sync.cc | 2 +- 31 files changed, 102 insertions(+), 68 deletions(-) diff --git a/include/tvm/ir/adt.h b/include/tvm/ir/adt.h index 9b45c66dc76b..466a4f00fd5f 100644 --- a/include/tvm/ir/adt.h +++ b/include/tvm/ir/adt.h @@ -45,7 +45,7 @@ namespace tvm { class ConstructorNode : public RelayExprNode { public: /*! \brief The name (only a hint) */ - std::string name_hint; + String name_hint; /*! \brief Input to the constructor. */ Array inputs; /*! \brief The datatype the constructor will construct. */ diff --git a/include/tvm/ir/attrs.h b/include/tvm/ir/attrs.h index eda3537bc374..4cdf8c5cbe94 100644 --- a/include/tvm/ir/attrs.h +++ b/include/tvm/ir/attrs.h @@ -27,7 +27,7 @@ * struct MyAttrs : public tvm::AttrsNode { * float learning_rate; * int num_hidden; - * std::string name; + * String name; * // declare attribute fields in header file * TVM_DECLARE_ATTRS(MyAttrs, "attrs.MyAttrs") { * TVM_ATTR_FIELD(num_hidden).set_lower_bound(1); @@ -106,11 +106,11 @@ struct AttrError : public dmlc::Error { class AttrFieldInfoNode : public Object { public: /*! \brief name of the field */ - std::string name; + String name; /*! \brief type docstring information in str. */ - std::string type_info; + String type_info; /*! \brief detailed description of the type */ - std::string description; + String description; void VisitAttrs(AttrVisitor* v) { v->Visit("name", &name); diff --git a/include/tvm/ir/env_func.h b/include/tvm/ir/env_func.h index 2f803672a20b..65653b75562d 100644 --- a/include/tvm/ir/env_func.h +++ b/include/tvm/ir/env_func.h @@ -41,7 +41,7 @@ namespace tvm { class EnvFuncNode : public Object { public: /*! \brief Unique name of the global function */ - std::string name; + String name; /*! \brief The internal packed function */ runtime::PackedFunc func; /*! \brief constructor */ diff --git a/include/tvm/ir/op.h b/include/tvm/ir/op.h index 8fc96a43eae3..2bc2c90c7854 100644 --- a/include/tvm/ir/op.h +++ b/include/tvm/ir/op.h @@ -58,21 +58,21 @@ class OpAttrMap; class OpNode : public RelayExprNode { public: /*! \brief name of the operator */ - std::string name; + String name; /*! \brief the type of the operator */ mutable FuncType op_type; /*! * \brief detailed description of the operator * This can be used to generate docstring automatically for the operator. */ - std::string description; + String description; /* \brief Information of input arguments to the operator */ Array arguments; /*! * \brief The type key of the attribute field * This can be empty, in which case it defaults to anything. */ - std::string attrs_type_key; + String attrs_type_key; /*! * \brief attribute type index, * this field varies in each run and is not exposed to frontend. diff --git a/include/tvm/ir/transform.h b/include/tvm/ir/transform.h index ffe37074ecae..5bfb51adb0ac 100644 --- a/include/tvm/ir/transform.h +++ b/include/tvm/ir/transform.h @@ -253,10 +253,10 @@ class PassInfoNode : public Object { int opt_level; /*! \brief The name of an optimization/analysis pass. */ - std::string name; + String name; /*! \brief The passes that are required to perform the current pass. */ - Array required; + Array required; PassInfoNode() = default; @@ -407,7 +407,7 @@ class Sequential : public Pass { */ TVM_DLL Pass CreateModulePass(const runtime::TypedPackedFunc& pass_func, - int opt_level, const String& name, const Array& required); + int opt_level, String name, Array required); /*! * \brief A special trace pass that prints the header and IR to LOG(INFO). diff --git a/include/tvm/node/reflection.h b/include/tvm/node/reflection.h index 832053fb62d9..e8ff26be42b3 100644 --- a/include/tvm/node/reflection.h +++ b/include/tvm/node/reflection.h @@ -169,7 +169,7 @@ class ReflectionVTable { * \return The corresponding attribute value. * \note This function will throw an exception if the object does not contain the field. */ - TVM_DLL runtime::TVMRetValue GetAttr(Object* self, const std::string& attr_name) const; + TVM_DLL runtime::TVMRetValue GetAttr(Object* self, const String& attr_name) const; /*! * \brief List all the fields in the object. diff --git a/include/tvm/relay/op_strategy.h b/include/tvm/relay/op_strategy.h index c70da6221984..c5785369f8d5 100644 --- a/include/tvm/relay/op_strategy.h +++ b/include/tvm/relay/op_strategy.h @@ -46,7 +46,7 @@ class OpImplementationNode : public Object { /*! \brief Schedule function */ FTVMSchedule fschedule; /*! \brief Name of the implementation */ - std::string name; + String name; /*! \brief Priority level */ int plevel; diff --git a/include/tvm/relay/transform.h b/include/tvm/relay/transform.h index 39212d8fd1e5..b287c053e8a9 100644 --- a/include/tvm/relay/transform.h +++ b/include/tvm/relay/transform.h @@ -58,7 +58,7 @@ using Sequential = tvm::transform::Sequential; */ TVM_DLL Pass CreateFunctionPass( const runtime::TypedPackedFunc& pass_func, - int opt_level, const String& name, const tvm::Array& required); + int opt_level, String name, tvm::Array required); /*! \brief Remove expressions which does not effect the program result. * diff --git a/include/tvm/runtime/container.h b/include/tvm/runtime/container.h index 6753ec779b55..36e2e8f5f276 100644 --- a/include/tvm/runtime/container.h +++ b/include/tvm/runtime/container.h @@ -65,6 +65,11 @@ #include #include +namespace llvm { +// String to llvm object compatibility. +class StringRef; +} // namespace llvm + namespace tvm { struct ObjectEqual; @@ -1161,7 +1166,14 @@ class String : public ObjectRef { * \param other The value for the new String * */ - inline String operator=(std::string other); + inline String& operator=(std::string other); + + /*! + * \brief Change the value the reference object points to. + * + * \param other The value for the new String + */ + inline String& operator=(const char* other); /*! * \brief Compare is less than other std::string @@ -1304,12 +1316,20 @@ class String : public ObjectRef { const char* data() const { return get()->data; } /*! - * \brief Convert String to an std::sting object + * \brief Convert String to an std::string object * * \return std::string */ operator std::string() const { return std::string{get()->data, size()}; } + // LLVM compatibility function, implemented in src/target/llvm/llvm_common.h + /*! + * \brief Convert String to an llvm::StringRef object + * + * \return llvm::StringRef + */ + inline operator llvm::StringRef() const; + /*! * \brief Check if a TVMArgValue can be converted to String, i.e. it can be std::string or String * \param val The value to be checked @@ -1382,12 +1402,14 @@ inline String::String(std::string other) { data_ = std::move(ptr); } -inline String String::operator=(std::string other) { +inline String& String::operator=(std::string other) { String replace{std::move(other)}; data_.swap(replace.data_); - return Downcast(*this); + return *this; } +inline String& String::operator=(const char* other) { return operator=(std::string(other)); } + inline String operator+(const std::string lhs, const String& rhs) { return lhs + rhs.operator std::string(); } diff --git a/include/tvm/tir/buffer.h b/include/tvm/tir/buffer.h index 5b07cc5ce7d6..34b0155a07ac 100644 --- a/include/tvm/tir/buffer.h +++ b/include/tvm/tir/buffer.h @@ -65,9 +65,9 @@ class BufferNode : public Object { PrimExpr elem_offset; // Meta data /*! \brief optional name of the buffer */ - std::string name; + String name; /*! \brief storage scope of the buffer, if other than global */ - std::string scope; + String scope; /*! \brief Alignment requirement of data pointer in bytes. */ int data_alignment; /*! @@ -134,7 +134,7 @@ class Buffer : public ObjectRef { // User can specify data_alignment and offset_factor to be 0 // A default value will be picked. TVM_DLL Buffer(Var ptr, DataType dtype, Array shape, Array strides, - PrimExpr elem_offset, std::string name, std::string scope, int data_alignment, + PrimExpr elem_offset, String name, String scope, int data_alignment, int offset_factor, BufferType buffer_type); /*! @@ -187,7 +187,7 @@ class Buffer : public ObjectRef { * \sa Buffer for complete constructor. */ TVM_DLL Buffer decl_buffer(Array shape, DataType dtype = DataType::Float(32), - std::string name = "buffer"); + String name = "buffer"); /*! * \brief Base node for data producers. diff --git a/include/tvm/tir/data_layout.h b/include/tvm/tir/data_layout.h index f705247f6986..b7cb68688066 100644 --- a/include/tvm/tir/data_layout.h +++ b/include/tvm/tir/data_layout.h @@ -98,7 +98,7 @@ class LayoutAxis { class LayoutNode : public Object { public: /*! \brief string representation of layout, "" for scalar. */ - std::string name; + String name; /*! \brief specify each axis of the layout, * in which the variable name is the name of the axis. * The IterVar's extent indicates the size of the axis, diff --git a/include/tvm/tir/expr.h b/include/tvm/tir/expr.h index 4b6b28d52ee9..cfb7f1ef0d5a 100644 --- a/include/tvm/tir/expr.h +++ b/include/tvm/tir/expr.h @@ -51,7 +51,7 @@ using FloatImmNode = tvm::FloatImmNode; class StringImmNode : public PrimExprNode { public: /*! \brief The constant value content. */ - std::string value; + String value; void VisitAttrs(AttrVisitor* v) { v->Visit("dtype", &dtype); @@ -74,7 +74,7 @@ class StringImmNode : public PrimExprNode { */ class StringImm : public PrimExpr { public: - TVM_DLL StringImm(std::string value); + TVM_DLL StringImm(String value); TVM_DEFINE_OBJECT_REF_METHODS(StringImm, PrimExpr, StringImmNode); }; @@ -889,7 +889,7 @@ class CallNode : public PrimExprNode { PureIntrinsic = 5 }; /*! \brief The name of the function/intrinsic. */ - std::string name; + String name; /*! \brief The arguments. */ Array args; /*! \brief Type of calls. */ @@ -958,7 +958,7 @@ class Call : public PrimExpr { public: using CallType = CallNode::CallType; - TVM_DLL Call(DataType dtype, std::string name, Array args, CallType call_type); + TVM_DLL Call(DataType dtype, String name, Array args, CallType call_type); TVM_DEFINE_OBJECT_REF_METHODS(Call, PrimExpr, CallNode); }; diff --git a/include/tvm/tir/stmt.h b/include/tvm/tir/stmt.h index 2aaf79511dae..ee8e1ebbfbe5 100644 --- a/include/tvm/tir/stmt.h +++ b/include/tvm/tir/stmt.h @@ -109,7 +109,7 @@ class AttrStmtNode : public StmtNode { /*! \brief this is attribute about certain node */ ObjectRef node; /*! \brief the type key of the attribute */ - std::string attr_key; + String attr_key; /*! \brief The attribute value, value is well defined at current scope. */ PrimExpr value; /*! \brief The body statement to be executed */ @@ -144,7 +144,7 @@ class AttrStmtNode : public StmtNode { */ class AttrStmt : public Stmt { public: - TVM_DLL AttrStmt(ObjectRef node, std::string type_key, PrimExpr value, Stmt body); + TVM_DLL AttrStmt(ObjectRef node, String attr_key, PrimExpr value, Stmt body); TVM_DEFINE_OBJECT_REF_METHODS(AttrStmt, Stmt, AttrStmtNode); }; diff --git a/include/tvm/tir/transform.h b/include/tvm/tir/transform.h index 371277ad1f59..a794c12b55ee 100644 --- a/include/tvm/tir/transform.h +++ b/include/tvm/tir/transform.h @@ -54,7 +54,7 @@ using tvm::transform::Sequential; */ TVM_DLL Pass CreatePrimFuncPass( const runtime::TypedPackedFunc& pass_func, - int opt_level, const std::string& name, const tvm::Array& required); + int opt_level, String name, tvm::Array required); /*! * \brief Inject prefetch instructions into stmt. @@ -88,7 +88,7 @@ TVM_DLL Pass StorageFlatten(int cache_line_size, bool create_bound_attribute = f * Expr pad_value) * \return The pass. */ -TVM_DLL Pass InjectCopyIntrin(std::string pragma_key, runtime::PackedFunc fintrin); +TVM_DLL Pass InjectCopyIntrin(String pragma_key, runtime::PackedFunc fintrin); /*! * \brief Detect and insert sync points to co-processor. @@ -103,7 +103,7 @@ TVM_DLL Pass CoProcSync(); * \param attr_key The attribute key to be checked. * \return The pass. */ -TVM_DLL Pass LiftAttrScope(std::string attr_key); +TVM_DLL Pass LiftAttrScope(String attr_key); /*! * \brief partition loops in the stmt. @@ -222,7 +222,7 @@ TVM_DLL Pass MakePackedAPI(int num_unpacked_args); * * \return The pass. */ -TVM_DLL Pass RemapThreadAxis(Map axis_map); +TVM_DLL Pass RemapThreadAxis(Map axis_map); /*! * \brief Lower custom datatypes. @@ -260,7 +260,7 @@ TVM_DLL Pass SkipAssert(); * \param storage_scope The storage scope considered. * \return The pass. */ -TVM_DLL Pass ThreadSync(std::string storage_scope); +TVM_DLL Pass ThreadSync(String storage_scope); /*! * \brief Lower cross thread alleduce. diff --git a/include/tvm/tir/var.h b/include/tvm/tir/var.h index 9f098248b836..2a44909f531d 100644 --- a/include/tvm/tir/var.h +++ b/include/tvm/tir/var.h @@ -245,7 +245,7 @@ class IterVarNode : public Object { * \brief additional tag on the iteration variable, * set this if this is binded already to a known thread tag. */ - std::string thread_tag; + String thread_tag; void VisitAttrs(AttrVisitor* v) { v->Visit("dom", &dom); @@ -278,7 +278,7 @@ class IterVarNode : public Object { */ class IterVar : public ObjectRef { public: - TVM_DLL IterVar(Range dom, Var var, IterVarType iter_type, std::string thread_tag = ""); + TVM_DLL IterVar(Range dom, Var var, IterVarType iter_type, String thread_tag = ""); /*! * \return the corresponding var in the IterVar. */ diff --git a/python/tvm/ir/json_compact.py b/python/tvm/ir/json_compact.py index ccc13d4211af..94e9cf3e8213 100644 --- a/python/tvm/ir/json_compact.py +++ b/python/tvm/ir/json_compact.py @@ -123,6 +123,7 @@ def _convert(item, nodes): "relay.IncompleteType": _rename("IncompleteType"), "relay.TypeRelation": _rename("TypeRelation"), "relay.TypeCall": _rename("TypeCall"), + "relay.Constructor": [_update_from_std_str("name_hint")], "relay.Module": _rename("IRModule"), "relay.SourceName": _rename("SourceName"), "relay.Span": _rename("Span"), @@ -137,6 +138,11 @@ def _convert(item, nodes): # TIR "Variable": [_update_tir_var("tir.Var"), _update_from_std_str("name")], "SizeVar": [_update_tir_var("tir.SizeVar"), _update_from_std_str("name")], + "StringImm": [_update_from_std_str("value")], + "Call": [_update_from_std_str("name")], + "AttrStmt": [_update_from_std_str("attr_key")], + "Layout": [_update_from_std_str("name")], + "Buffer": [_update_from_std_str("name"), _update_from_std_str("scope")], } return create_updater(node_map, "0.6", "0.7") diff --git a/src/ir/op.cc b/src/ir/op.cc index 2c802b63c460..63d223050ff5 100644 --- a/src/ir/op.cc +++ b/src/ir/op.cc @@ -134,9 +134,8 @@ ObjectPtr CreateOp(const std::string& name) { return Op2ObjectPtr::Get(op); } -TVM_REGISTER_NODE_TYPE(OpNode).set_creator(CreateOp).set_repr_bytes([](const Object* n) { - return static_cast(n)->name; -}); +TVM_REGISTER_NODE_TYPE(OpNode).set_creator(CreateOp).set_repr_bytes( + [](const Object* n) -> std::string { return static_cast(n)->name; }); TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) .set_dispatch([](const ObjectRef& ref, ReprPrinter* p) { diff --git a/src/ir/transform.cc b/src/ir/transform.cc index 322c1ef59ac6..d74b95abebdb 100644 --- a/src/ir/transform.cc +++ b/src/ir/transform.cc @@ -376,8 +376,7 @@ IRModule SequentialNode::operator()(IRModule mod, const PassContext& pass_ctx) c } Pass CreateModulePass(const runtime::TypedPackedFunc& pass_func, - int opt_level, const String& name, - const tvm::Array& required) { + int opt_level, String name, tvm::Array required) { PassInfo pass_info = PassInfo(opt_level, name, required); return ModulePass(pass_func, pass_info); } @@ -385,7 +384,7 @@ Pass CreateModulePass(const runtime::TypedPackedFunc required) { + .set_body_typed([](int opt_level, String name, tvm::Array required) { return PassInfo(opt_level, name, required); }); diff --git a/src/node/reflection.cc b/src/node/reflection.cc index 249e2d251944..8de21da9a645 100644 --- a/src/node/reflection.cc +++ b/src/node/reflection.cc @@ -36,10 +36,10 @@ using runtime::TVMRetValue; // Attr getter. class AttrGetter : public AttrVisitor { public: - const std::string& skey; + const String& skey; TVMRetValue* ret; - AttrGetter(const std::string& skey, TVMRetValue* ret) : skey(skey), ret(ret) {} + AttrGetter(const String& skey, TVMRetValue* ret) : skey(skey), ret(ret) {} bool found_ref_object{false}; @@ -84,7 +84,7 @@ class AttrGetter : public AttrVisitor { } }; -runtime::TVMRetValue ReflectionVTable::GetAttr(Object* self, const std::string& field_name) const { +runtime::TVMRetValue ReflectionVTable::GetAttr(Object* self, const String& field_name) const { runtime::TVMRetValue ret; AttrGetter getter(field_name, &ret); diff --git a/src/relay/ir/transform.cc b/src/relay/ir/transform.cc index 184ee58009d7..b540dd47bcd9 100644 --- a/src/relay/ir/transform.cc +++ b/src/relay/ir/transform.cc @@ -145,7 +145,7 @@ bool FunctionPassNode::SkipFunction(const Function& func) const { Pass CreateFunctionPass( const runtime::TypedPackedFunc& pass_func, - int opt_level, const String& name, const tvm::Array& required) { + int opt_level, String name, tvm::Array required) { PassInfo pass_info = PassInfo(opt_level, name, required); return FunctionPass(pass_func, pass_info); } diff --git a/src/support/ffi_testing.cc b/src/support/ffi_testing.cc index 2b944cb2ddd6..839f52968b82 100644 --- a/src/support/ffi_testing.cc +++ b/src/support/ffi_testing.cc @@ -31,7 +31,7 @@ namespace tvm { // Attrs used to python API struct TestAttrs : public AttrsNode { int axis; - std::string name; + String name; Array padding; TypedEnvFunc func; diff --git a/src/target/llvm/codegen_cpu.cc b/src/target/llvm/codegen_cpu.cc index 9113c988acdd..6ad050ace9a3 100644 --- a/src/target/llvm/codegen_cpu.cc +++ b/src/target/llvm/codegen_cpu.cc @@ -316,7 +316,8 @@ llvm::Value* CodeGenCPU::CreateCallExtern(const CallNode* op) { } else { llvm::Function* f = module_->getFunction(op->name); if (f == nullptr) { - f = llvm::Function::Create(ftype, llvm::Function::ExternalLinkage, op->name, module_.get()); + f = llvm::Function::Create(ftype, llvm::Function::ExternalLinkage, + op->name.operator llvm::StringRef(), module_.get()); } #if TVM_LLVM_VERSION >= 90 auto ext_callee = llvm::FunctionCallee(f); @@ -408,7 +409,8 @@ void CodeGenCPU::CreateComputeScope(const AttrStmtNode* op) { } llvm::FunctionType* ftype = llvm::FunctionType::get(t_int_, arg_types, false); llvm::Function* fcompute = llvm::Function::Create( - ftype, llvm::Function::PrivateLinkage, op->value.as()->value, module_.get()); + ftype, llvm::Function::PrivateLinkage, + op->value.as()->value.operator llvm::StringRef(), module_.get()); BasicBlock* compute_call_end = CheckCallSuccess(builder_->CreateCall(fcompute, arg_values)); // setup compute fuinction. std::unordered_map new_vmap; diff --git a/src/target/llvm/codegen_llvm.cc b/src/target/llvm/codegen_llvm.cc index 3af9fc3f4519..85e3de5844fd 100644 --- a/src/target/llvm/codegen_llvm.cc +++ b/src/target/llvm/codegen_llvm.cc @@ -664,7 +664,8 @@ llvm::Value* CodeGenLLVM::CreateCallExtern(const CallNode* op) { llvm::FunctionType::get(GetLLVMType(GetRef(op)), arg_type, false); llvm::Function* f = module_->getFunction(op->name); if (f == nullptr) { - f = llvm::Function::Create(ftype, llvm::Function::ExternalLinkage, op->name, module_.get()); + f = llvm::Function::Create(ftype, llvm::Function::ExternalLinkage, + op->name.operator llvm::StringRef(), module_.get()); } llvm::CallInst* call = builder_->CreateCall(f, arg_value); return call; diff --git a/src/target/llvm/llvm_common.h b/src/target/llvm/llvm_common.h index 529ee7485fde..9a4ccfc9cf9c 100644 --- a/src/target/llvm/llvm_common.h +++ b/src/target/llvm/llvm_common.h @@ -32,6 +32,7 @@ #include #include #include +#include #if TVM_LLVM_VERSION >= 100 #include #include @@ -108,5 +109,11 @@ std::unique_ptr GetLLVMTargetMachine(const std::string& tar } // namespace codegen } // namespace tvm + +namespace tvm { +namespace runtime { +inline String::operator llvm::StringRef() const { return llvm::StringRef(get()->data, size()); } +} // namespace runtime +} // namespace tvm #endif // TVM_LLVM_VERSION #endif // TVM_TARGET_LLVM_LLVM_COMMON_H_ diff --git a/src/tir/ir/buffer.cc b/src/tir/ir/buffer.cc index 2b64f11ccc3c..4e433fc718b1 100644 --- a/src/tir/ir/buffer.cc +++ b/src/tir/ir/buffer.cc @@ -44,7 +44,7 @@ Array SimplifyArray(arith::Analyzer* ana, Array array) { return array; } -Buffer decl_buffer(Array shape, DataType dtype, std::string name) { +Buffer decl_buffer(Array shape, DataType dtype, String name) { return Buffer(Var(name, PointerType(PrimType(dtype))), dtype, shape, Array(), PrimExpr(), name, "", 0, 0, kDefault); } @@ -380,7 +380,7 @@ PrimExpr Buffer::access_ptr(int access_mask, DataType ptr_type, int content_lane } Buffer::Buffer(Var data, DataType dtype, Array shape, Array strides, - PrimExpr elem_offset, std::string name, std::string scope, int data_alignment, + PrimExpr elem_offset, String name, String scope, int data_alignment, int offset_factor, BufferType buffer_type) { auto n = make_object(); n->data = std::move(data); @@ -423,7 +423,7 @@ TVM_REGISTER_NODE_TYPE(BufferNode); TVM_REGISTER_GLOBAL("tir.Buffer").set_body([](TVMArgs args, TVMRetValue* ret) { CHECK_EQ(args.size(), 10); - auto buffer_type = args[9].operator std::string(); + auto buffer_type = args[9].operator String(); BufferType type = (buffer_type == "auto_broadcast") ? kAutoBroadcast : kDefault; *ret = Buffer(args[0], args[1], args[2], args[3], args[4], args[5], args[6], args[7], args[8], type); diff --git a/src/tir/ir/expr.cc b/src/tir/ir/expr.cc index 7959ebaebfe4..94efff5180fb 100644 --- a/src/tir/ir/expr.cc +++ b/src/tir/ir/expr.cc @@ -126,7 +126,7 @@ TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) }); // IterVar -IterVar::IterVar(Range dom, Var var, IterVarType t, std::string thread_tag) { +IterVar::IterVar(Range dom, Var var, IterVarType t, String thread_tag) { ObjectPtr n = make_object(); n->dom = dom; n->var = var; @@ -136,7 +136,7 @@ IterVar::IterVar(Range dom, Var var, IterVarType t, std::string thread_tag) { } TVM_REGISTER_GLOBAL("tir.IterVar") - .set_body_typed([](Range dom, Var var, int iter_type, std::string thread_tag) { + .set_body_typed([](Range dom, Var var, int iter_type, String thread_tag) { return IterVar(dom, var, static_cast(iter_type), thread_tag); }); @@ -159,16 +159,14 @@ TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) TVM_REGISTER_NODE_TYPE(IterVarNode); // StringImm -StringImm::StringImm(std::string value) { +StringImm::StringImm(String value) { ObjectPtr node = make_object(); node->dtype = DataType::Handle(); node->value = std::move(value); data_ = std::move(node); } -TVM_REGISTER_GLOBAL("tir.StringImm").set_body_typed([](std::string value) { - return StringImm(value); -}); +TVM_REGISTER_GLOBAL("tir.StringImm").set_body_typed([](String value) { return StringImm(value); }); TVM_REGISTER_NODE_TYPE(StringImmNode); @@ -700,7 +698,7 @@ TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) }); // Call -Call::Call(DataType dtype, std::string name, Array args, CallType call_type) { +Call::Call(DataType dtype, String name, Array args, CallType call_type) { for (size_t i = 0; i < args.size(); ++i) { CHECK(args[i].defined()); } @@ -743,7 +741,7 @@ bool CallNode::is_vectorizable() const { } TVM_REGISTER_GLOBAL("tir.Call") - .set_body_typed([](DataType type, std::string name, Array args, int call_type) { + .set_body_typed([](DataType type, String name, Array args, int call_type) { Array prim_expr_args; for (const auto& it : args) { CHECK(it->IsInstance() || it->IsInstance()); diff --git a/src/tir/ir/stmt.cc b/src/tir/ir/stmt.cc index 9bb1de427847..66497755c88a 100644 --- a/src/tir/ir/stmt.cc +++ b/src/tir/ir/stmt.cc @@ -57,7 +57,7 @@ TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) }); // AttrStmt -AttrStmt::AttrStmt(ObjectRef node, std::string attr_key, PrimExpr value, Stmt body) { +AttrStmt::AttrStmt(ObjectRef node, String attr_key, PrimExpr value, Stmt body) { auto n = make_object(); n->node = node; n->attr_key = std::move(attr_key); @@ -67,7 +67,7 @@ AttrStmt::AttrStmt(ObjectRef node, std::string attr_key, PrimExpr value, Stmt bo } TVM_REGISTER_GLOBAL("tir.AttrStmt") - .set_body_typed([](ObjectRef node, std::string attr_key, PrimExpr value, Stmt body) { + .set_body_typed([](ObjectRef node, String attr_key, PrimExpr value, Stmt body) { return AttrStmt(node, attr_key, value, body); }); diff --git a/src/tir/ir/transform.cc b/src/tir/ir/transform.cc index 30d5f0f50774..50106c90a5e5 100644 --- a/src/tir/ir/transform.cc +++ b/src/tir/ir/transform.cc @@ -118,7 +118,7 @@ IRModule PrimFuncPassNode::operator()(IRModule mod, const PassContext& pass_ctx) Pass CreatePrimFuncPass( const runtime::TypedPackedFunc& pass_func, - int opt_level, const std::string& name, const tvm::Array& required) { + int opt_level, String name, tvm::Array required) { PassInfo pass_info = PassInfo(opt_level, name, required); return PrimFuncPass(pass_func, pass_info); } diff --git a/src/tir/transforms/inject_copy_intrin.cc b/src/tir/transforms/inject_copy_intrin.cc index 416358cebce0..b27459f4bd45 100644 --- a/src/tir/transforms/inject_copy_intrin.cc +++ b/src/tir/transforms/inject_copy_intrin.cc @@ -183,7 +183,7 @@ Stmt InjectCopyIntrin(Stmt stmt, const std::string& pragma_key, namespace transform { -Pass InjectCopyIntrin(std::string pragma_key, PackedFunc flower_copy_fromto) { +Pass InjectCopyIntrin(String pragma_key, PackedFunc flower_copy_fromto) { auto pass_func = [=](PrimFunc f, IRModule m, PassContext ctx) { auto* n = f.CopyOnWrite(); n->body = CopyIntrinInjector(pragma_key, flower_copy_fromto)(std::move(n->body)); diff --git a/src/tir/transforms/lift_attr_scope.cc b/src/tir/transforms/lift_attr_scope.cc index ca4b39e569db..1a1279f0640a 100644 --- a/src/tir/transforms/lift_attr_scope.cc +++ b/src/tir/transforms/lift_attr_scope.cc @@ -181,7 +181,7 @@ Stmt LiftAttrScope(Stmt stmt, std::string attr_key) { namespace transform { -Pass LiftAttrScope(std::string attr_key) { +Pass LiftAttrScope(String attr_key) { auto pass_func = [=](PrimFunc f, IRModule m, PassContext ctx) { auto* n = f.CopyOnWrite(); n->body = AttrScopeLifter(attr_key).Lift(std::move(n->body)); diff --git a/src/tir/transforms/thread_storage_sync.cc b/src/tir/transforms/thread_storage_sync.cc index b8575d28c8ce..612efb092395 100644 --- a/src/tir/transforms/thread_storage_sync.cc +++ b/src/tir/transforms/thread_storage_sync.cc @@ -361,7 +361,7 @@ Stmt ThreadSync(Stmt stmt, std::string storage_scope) { namespace transform { -Pass ThreadSync(std::string storage_scope) { +Pass ThreadSync(String storage_scope) { auto pass_func = [storage_scope](PrimFunc f, IRModule m, PassContext ctx) { auto* n = f.CopyOnWrite(); n->body = ThreadSync(std::move(n->body), storage_scope);