diff --git a/include/tvm/runtime/object.h b/include/tvm/runtime/object.h index 4d9bc6a2e83f..bba9e14aeeb1 100644 --- a/include/tvm/runtime/object.h +++ b/include/tvm/runtime/object.h @@ -318,16 +318,20 @@ struct ClosureCell; */ class Object { public: - ObjectPtr ptr; - explicit Object(ObjectPtr ptr) : ptr(ptr) {} - explicit Object(ObjectCell* ptr) : ptr(ptr) {} - Object() : ptr() {} - Object(const Object& obj) : ptr(obj.ptr) {} - ObjectCell* operator->() { return this->ptr.operator->(); } - + ObjectPtr ptr_; + explicit Object(ObjectPtr ptr) : ptr_(ptr) {} + explicit Object(ObjectCell* ptr) : ptr_(ptr) {} + Object() : ptr_() {} + Object(const Object& obj) : ptr_(obj.ptr_) {} + ObjectCell* operator->() { return this->ptr_.operator->(); } + + /*! \brief Construct a tensor object. */ static Object Tensor(const NDArray& data); + /*! \brief Construct a datatype object. */ static Object Datatype(size_t tag, const std::vector& fields); + /*! \brief Construct a tuple object. */ static Object Tuple(const std::vector& fields); + /*! \brief Construct a closure object. */ static Object Closure(size_t func_index, const std::vector& free_vars); ObjectPtr AsTensor() const; @@ -335,27 +339,36 @@ class Object { ObjectPtr AsClosure() const; }; +/*! \brief An object containing an NDArray. */ struct TensorCell : public ObjectCell { + /*! \brief The NDArray. */ NDArray data; explicit TensorCell(const NDArray& data) : ObjectCell(ObjectTag::kTensor), data(data) {} }; +/*! \brief An object representing a structure or enumeration. */ struct DatatypeCell : public ObjectCell { + /*! \brief The tag representing the constructor used. */ size_t tag; + /*! \brief The fields of the structure. */ std::vector fields; DatatypeCell(size_t tag, const std::vector& fields) : ObjectCell(ObjectTag::kDatatype), tag(tag), fields(fields) {} }; +/*! \brief An object representing a closure. */ struct ClosureCell : public ObjectCell { + /*! \brief The index into the VM function table. */ size_t func_index; + /*! \brief The free variables of the closure. */ std::vector free_vars; ClosureCell(size_t func_index, const std::vector& free_vars) : ObjectCell(ObjectTag::kClosure), func_index(func_index), free_vars(free_vars) {} }; +/*! \brief Extract the NDArray from a tensor object. */ NDArray ToNDArray(const Object& obj); /*! diff --git a/include/tvm/runtime/packed_func.h b/include/tvm/runtime/packed_func.h index 5800ba23a531..9fcefcbbe4b1 100644 --- a/include/tvm/runtime/packed_func.h +++ b/include/tvm/runtime/packed_func.h @@ -745,8 +745,8 @@ class TVMRetValue : public TVMPODValue_ { TVMRetValue& operator=(Object other) { this->Clear(); type_code_ = kObject; - value_.v_handle = other.ptr.data_; - other.ptr.data_ = nullptr; + value_.v_handle = other.ptr_.data_; + other.ptr_.data_ = nullptr; return *this; } TVMRetValue& operator=(PackedFunc f) { diff --git a/src/lang/reflection.cc b/src/lang/reflection.cc index 73cf828e8e12..bc3d2895b811 100644 --- a/src/lang/reflection.cc +++ b/src/lang/reflection.cc @@ -87,7 +87,7 @@ class NodeIndexer : public AttrVisitor { } void Visit(const char* key, Object* value) final { - ObjectCell* ptr = value->ptr.get(); + ObjectCell* ptr = value->ptr_.get(); if (vm_obj_index.count(ptr)) return; CHECK_EQ(vm_obj_index.size(), vm_obj_list.size()); vm_obj_index[ptr] = vm_obj_list.size(); @@ -214,7 +214,7 @@ class JSONAttrGetter : public AttrVisitor { } void Visit(const char* key, Object* value) final { node_->attrs[key] = std::to_string( - vm_obj_index_->at(value->ptr.get())); + vm_obj_index_->at(value->ptr_.get())); } // Get the node void Get(Node* node) {