diff --git a/taichi/ir/type.cpp b/taichi/ir/type.cpp index fd65cbf1cbe40..7a1fe21309a16 100644 --- a/taichi/ir/type.cpp +++ b/taichi/ir/type.cpp @@ -10,7 +10,7 @@ TLANG_NAMESPACE_BEGIN // This part doesn't look good, but we will remove it soon anyway. #define PER_TYPE(x) \ DataType PrimitiveType::x = \ - DataType(Program::get_type_factory().get_primitive_type( \ + DataType(TypeFactory::get_instance().get_primitive_type( \ PrimitiveType::primitive_type::x)); #include "taichi/inc/data_type.inc.h" @@ -46,7 +46,7 @@ bool DataType::is_pointer() const { void DataType::set_is_pointer(bool is_ptr) { if (is_ptr && !ptr_->is()) { - ptr_ = Program::get_type_factory().get_pointer_type(ptr_); + ptr_ = TypeFactory::get_instance().get_pointer_type(ptr_); } if (!is_ptr && ptr_->is()) { ptr_ = ptr_->cast()->get_pointee_type(); @@ -70,7 +70,7 @@ std::string PrimitiveType::to_string() const { DataType LegacyVectorType(int width, DataType data_type, bool is_pointer) { TI_ASSERT(width == 1); if (is_pointer) { - return Program::get_type_factory().get_pointer_type(data_type.get_ptr()); + return TypeFactory::get_instance().get_pointer_type(data_type.get_ptr()); } else { return data_type; } diff --git a/taichi/ir/type_factory.cpp b/taichi/ir/type_factory.cpp index aa1d8250a7a40..de39096c0e2ab 100644 --- a/taichi/ir/type_factory.cpp +++ b/taichi/ir/type_factory.cpp @@ -2,6 +2,11 @@ TLANG_NAMESPACE_BEGIN +TypeFactory &TypeFactory::get_instance() { + static TypeFactory *type_factory = new TypeFactory; + return *type_factory; +} + Type *TypeFactory::get_primitive_type(PrimitiveType::primitive_type id) { std::lock_guard _(mut_); @@ -28,4 +33,7 @@ Type *TypeFactory::get_pointer_type(Type *element) { return pointer_types_[key].get(); } +TypeFactory::TypeFactory() { +} + TLANG_NAMESPACE_END diff --git a/taichi/ir/type_factory.h b/taichi/ir/type_factory.h index 8a74b4ddc8e7f..7700cd86a9c55 100644 --- a/taichi/ir/type_factory.h +++ b/taichi/ir/type_factory.h @@ -8,6 +8,8 @@ TLANG_NAMESPACE_BEGIN class TypeFactory { public: + static TypeFactory &get_instance(); + Type *get_primitive_type(PrimitiveType::primitive_type id); Type *get_vector_type(int num_elements, Type *element); @@ -15,6 +17,8 @@ class TypeFactory { Type *get_pointer_type(Type *element); private: + TypeFactory(); + std::unordered_map> primitive_types_; diff --git a/taichi/lang_util.cpp b/taichi/lang_util.cpp index 1f7300857c3b3..363079ad4deeb 100644 --- a/taichi/lang_util.cpp +++ b/taichi/lang_util.cpp @@ -346,7 +346,7 @@ class TypePromotionMapping { DataType query(DataType x, DataType y) { auto primitive = mapping[std::make_pair(to_primitive_type(x), to_primitive_type(y))]; - return Program::get_type_factory().get_primitive_type(primitive); + return TypeFactory::get_instance().get_primitive_type(primitive); } private: diff --git a/taichi/program/program.cpp b/taichi/program/program.cpp index 9bb3777fb6be2..3bdc06cee359b 100644 --- a/taichi/program/program.cpp +++ b/taichi/program/program.cpp @@ -197,9 +197,10 @@ Program::Program(Arch desired_arch) { } TypeFactory &Program::get_type_factory() { - // type_factory should never be destroyed, hence the raw new operator. - static TypeFactory *type_factory = new TypeFactory; - return *type_factory; + TI_WARN( + "Program::get_type_factory() will be deprecated, Please use " + "TypeFactory::get_instance()"); + return TypeFactory::get_instance(); } FunctionType Program::compile(Kernel &kernel) {