Skip to content

Commit

Permalink
Final tweaks
Browse files Browse the repository at this point in the history
  • Loading branch information
jroesch committed May 1, 2019
1 parent 8f61d3d commit 64c3fef
Show file tree
Hide file tree
Showing 3 changed files with 24 additions and 11 deletions.
27 changes: 20 additions & 7 deletions include/tvm/runtime/object.h
Original file line number Diff line number Diff line change
Expand Up @@ -318,44 +318,57 @@ struct ClosureCell;
*/
class Object {
public:
ObjectPtr<ObjectCell> ptr;
explicit Object(ObjectPtr<ObjectCell> ptr) : ptr(ptr) {}
explicit Object(ObjectCell* ptr) : ptr(ptr) {}
Object() : ptr() {}
Object(const Object& obj) : ptr(obj.ptr) {}
ObjectCell* operator->() { return this->ptr.operator->(); }

ObjectPtr<ObjectCell> ptr_;
explicit Object(ObjectPtr<ObjectCell> ptr) : ptr_(ptr) {}
explicit Object(ObjectCell* ptr) : ptr_(ptr) {}
Object() : ptr_() {}
Object(const Object& obj) : ptr_(obj.ptr_) {}
ObjectCell* operator->() { return this->ptr_.operator->(); }

/*! \brief Construct a tensor object. */
static Object Tensor(const NDArray& data);
/*! \brief Construct a datatype object. */
static Object Datatype(size_t tag, const std::vector<Object>& fields);
/*! \brief Construct a tuple object. */
static Object Tuple(const std::vector<Object>& fields);
/*! \brief Construct a closure object. */
static Object Closure(size_t func_index, const std::vector<Object>& free_vars);

ObjectPtr<TensorCell> AsTensor() const;
ObjectPtr<DatatypeCell> AsDatatype() const;
ObjectPtr<ClosureCell> AsClosure() const;
};

/*! \brief An object containing an NDArray. */
struct TensorCell : public ObjectCell {
/*! \brief The NDArray. */
NDArray data;
explicit TensorCell(const NDArray& data) : ObjectCell(ObjectTag::kTensor), data(data) {}
};

/*! \brief An object representing a structure or enumeration. */
struct DatatypeCell : public ObjectCell {
/*! \brief The tag representing the constructor used. */
size_t tag;
/*! \brief The fields of the structure. */
std::vector<Object> fields;

DatatypeCell(size_t tag, const std::vector<Object>& fields)
: ObjectCell(ObjectTag::kDatatype), tag(tag), fields(fields) {}
};

/*! \brief An object representing a closure. */
struct ClosureCell : public ObjectCell {
/*! \brief The index into the VM function table. */
size_t func_index;
/*! \brief The free variables of the closure. */
std::vector<Object> free_vars;

ClosureCell(size_t func_index, const std::vector<Object>& free_vars)
: ObjectCell(ObjectTag::kClosure), func_index(func_index), free_vars(free_vars) {}
};

/*! \brief Extract the NDArray from a tensor object. */
NDArray ToNDArray(const Object& obj);

/*!
Expand Down
4 changes: 2 additions & 2 deletions include/tvm/runtime/packed_func.h
Original file line number Diff line number Diff line change
Expand Up @@ -745,8 +745,8 @@ class TVMRetValue : public TVMPODValue_ {
TVMRetValue& operator=(Object other) {
this->Clear();
type_code_ = kObject;
value_.v_handle = other.ptr.data_;
other.ptr.data_ = nullptr;
value_.v_handle = other.ptr_.data_;
other.ptr_.data_ = nullptr;
return *this;
}
TVMRetValue& operator=(PackedFunc f) {
Expand Down
4 changes: 2 additions & 2 deletions src/lang/reflection.cc
Original file line number Diff line number Diff line change
Expand Up @@ -87,7 +87,7 @@ class NodeIndexer : public AttrVisitor {
}

void Visit(const char* key, Object* value) final {
ObjectCell* ptr = value->ptr.get();
ObjectCell* ptr = value->ptr_.get();
if (vm_obj_index.count(ptr)) return;
CHECK_EQ(vm_obj_index.size(), vm_obj_list.size());
vm_obj_index[ptr] = vm_obj_list.size();
Expand Down Expand Up @@ -214,7 +214,7 @@ class JSONAttrGetter : public AttrVisitor {
}
void Visit(const char* key, Object* value) final {
node_->attrs[key] = std::to_string(
vm_obj_index_->at(value->ptr.get()));
vm_obj_index_->at(value->ptr_.get()));
}
// Get the node
void Get(Node* node) {
Expand Down

0 comments on commit 64c3fef

Please sign in to comment.