Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[RUNTIME][NDArray] Allowing External Libraries to Subclass NDArrays #2613

Merged
merged 9 commits into from
Feb 21, 2019
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion apps/extension/python/tvm_ext/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,7 +55,7 @@ class NDSubClass(tvm.nd.NDArrayBase):
leverage TVM's FFI without any modification.
"""
# Should be consistent with the type-trait set in the backend
_array_type_index = 1
_array_type_info = 1

@staticmethod
def create(addtional_info):
Expand Down
12 changes: 6 additions & 6 deletions apps/extension/src/tvm_ext.cc
Original file line number Diff line number Diff line change
Expand Up @@ -19,11 +19,11 @@ class NDSubClass;
namespace tvm {
namespace runtime {
template<>
struct extension_class_info<tvm_ext::IntVector> {
struct extension_type_info<tvm_ext::IntVector> {
static const int code = 17;
};
template<>
struct array_type_index<tvm_ext::NDSubClass> {
struct array_type_info<tvm_ext::NDSubClass> {
static const int code = 1;
};
} // namespace tvm
Expand All @@ -39,13 +39,13 @@ namespace tvm_ext {
* To use this extension, an external library should
*
* 1) Inherit TVM's NDArray and NDArray container,
* and define the trait `array_type_index` for this class.
* and define the trait `array_type_info` for this class.
*
* 2) Define a constructor in the inherited class that accepts
* a pointer to TVM's Container, which is nullable.
*
* 3) On Python frontend, inherit `tvm.nd.NDArrayBase`,
* define the class attribute `_array_type_index` consistent to
* define the class attribute `_array_type_info` consistent to
* the C++ type trait, and register the subclass using `tvm.register_extension`.
*/
class NDSubClass : public tvm::runtime::NDArray {
Expand All @@ -54,11 +54,11 @@ class NDSubClass : public tvm::runtime::NDArray {
public:
SubContainer(int addtional_info) :
addtional_info_(addtional_info) {
array_type_index_ = array_type_index<NDSubClass>::code;
array_type_info_ = array_type_info<NDSubClass>::code;
}
static bool Is(NDArray::Container *container) {
SubContainer *c = static_cast<SubContainer*>(container);
return c->array_type_index_ == array_type_index<NDSubClass>::code;
return c->array_type_info_ == array_type_info<NDSubClass>::code;
}
int addtional_info_{0};
};
Expand Down
8 changes: 4 additions & 4 deletions include/tvm/runtime/ndarray.h
Original file line number Diff line number Diff line change
Expand Up @@ -191,14 +191,14 @@ class NDArray {
* All subclasses of NDArray should override code > 0.
*/
template<typename T>
struct array_type_index {
struct array_type_info {
/*! \brief the value of the traits */
static const int code = -1;
};

// Overrides the type trait for tvm's NDArray.
template<>
struct array_type_index<NDArray> {
struct array_type_info<NDArray> {
static const int code = 0;
};

Expand Down Expand Up @@ -257,10 +257,10 @@ class NDArray::Container {
* Default value 0 means normal NDArray::Conatainer.
*
* We can extend a more specialized NDArray::Container
* and use the array_type_index_ to indicate
* and use the array_type_info_ to indicate
* the specific array subclass.
*/
int32_t array_type_index_{0};
int32_t array_type_info_{0};
/*! \brief The internal reference counter */
std::atomic<int> ref_counter_{0};
/*!
Expand Down
26 changes: 13 additions & 13 deletions include/tvm/runtime/packed_func.h
Original file line number Diff line number Diff line change
Expand Up @@ -362,7 +362,7 @@ inline std::string TVMType2String(TVMType t);
* \tparam T the typename
*/
template<typename T>
struct extension_class_info {
struct extension_type_info {
static const int code = 0;
};

Expand Down Expand Up @@ -461,7 +461,7 @@ class TVMPODValue_ {
TNDArray AsNDArray() const {
if (type_code_ == kNull) return TNDArray(nullptr);
auto *container = static_cast<NDArray::Container*>(value_.v_handle);
CHECK_EQ(container->array_type_index_, array_type_index<TNDArray>::code);
CHECK_EQ(container->array_type_info_, array_type_info<TNDArray>::code);
return TNDArray(container);
}
template<typename TExtension>
Expand Down Expand Up @@ -736,10 +736,10 @@ class TVMRetValue : public TVMPODValue_ {
}
template<typename T,
typename = typename std::enable_if<
extension_class_info<T>::code != 0>::type>
extension_type_info<T>::code != 0>::type>
TVMRetValue& operator=(const T& other) {
this->SwitchToClass<T>(
extension_class_info<T>::code, other);
extension_type_info<T>::code, other);
return *this;
}
/*!
Expand Down Expand Up @@ -1103,7 +1103,7 @@ class TVMArgsSetter {
// extension
template<typename T,
typename = typename std::enable_if<
extension_class_info<T>::code != 0>::type>
extension_type_info<T>::code != 0>::type>
inline void operator()(size_t i, const T& value) const;
// NodeRef related extenstions: in tvm/packed_func_ext.h
inline void operator()(size_t i, const NodeRef& other) const; // NOLINT(*)
Expand Down Expand Up @@ -1249,25 +1249,25 @@ template<typename T, typename>
inline TVMArgValue::operator T() const {
return detail::
TVMValueCast<T, TVMArgValue,
(extension_class_info<T>::code != 0),
(array_type_index<T>::code > 0)>
(extension_type_info<T>::code != 0),
(array_type_info<T>::code > 0)>
::Apply(this);
}

template<typename T, typename>
inline TVMRetValue::operator T() const {
return detail::
TVMValueCast<T, TVMRetValue,
(extension_class_info<T>::code != 0),
(array_type_index<T>::code > 0)>
(extension_type_info<T>::code != 0),
(array_type_info<T>::code > 0)>
::Apply(this);
}

template<typename T, typename>
inline void TVMArgsSetter::operator()(size_t i, const T& value) const {
static_assert(extension_class_info<T>::code != 0,
static_assert(extension_type_info<T>::code != 0,
"Need to have extesion code");
type_codes_[i] = extension_class_info<T>::code;
type_codes_[i] = extension_type_info<T>::code;
values_[i].v_handle = const_cast<T*>(&value);
}

Expand All @@ -1284,9 +1284,9 @@ struct ExtTypeInfo {

template<typename T>
inline ExtTypeVTable* ExtTypeVTable::Register_() {
const int code = extension_class_info<T>::code;
const int code = extension_type_info<T>::code;
static_assert(code != 0,
"require extension_class_info traits to be declared with non-zero code");
"require extension_type_info traits to be declared with non-zero code");
ExtTypeVTable vt;
vt.clone = ExtTypeInfo<T>::clone;
vt.destroy = ExtTypeInfo<T>::destroy;
Expand Down
2 changes: 1 addition & 1 deletion include/tvm/runtime/registry.h
Original file line number Diff line number Diff line change
Expand Up @@ -133,7 +133,7 @@ class Registry {
/*!
* \brief Macro to register extension type.
* This must be registered in a cc file
* after the trait extension_class_info is defined.
* after the trait extension_type_info is defined.
*/
#define TVM_REGISTER_EXT_TYPE(T) \
TVM_STR_CONCAT(TVM_TYPE_REG_VAR_DEF, __COUNTER__) = \
Expand Down
6 changes: 3 additions & 3 deletions nnvm/include/nnvm/compiler/packed_func_ext.h
Original file line number Diff line number Diff line change
Expand Up @@ -40,17 +40,17 @@ namespace tvm {
namespace runtime {

template<>
struct extension_class_info<nnvm::Symbol> {
struct extension_type_info<nnvm::Symbol> {
static const int code = 16;
};

template<>
struct extension_class_info<nnvm::Graph> {
struct extension_type_info<nnvm::Graph> {
static const int code = 17;
};

template<>
struct extension_class_info<nnvm::compiler::AttrDict> {
struct extension_type_info<nnvm::compiler::AttrDict> {
static const int code = 18;
};

Expand Down
4 changes: 2 additions & 2 deletions nnvm/src/compiler/packed_func_ext.cc
Original file line number Diff line number Diff line change
Expand Up @@ -76,8 +76,8 @@ TVM_REGISTER_GLOBAL("nnvm.compiler._register_alter_op_layout")
if (ret.type_code() == TVMTypeCode::kNull) {
return false;
}
CHECK_EQ(ret.type_code(), tvm::runtime::extension_class_info<Symbol>::code)
<< " expected " << "Symbol (code = " << tvm::runtime::extension_class_info<Symbol>::code
CHECK_EQ(ret.type_code(), tvm::runtime::extension_type_info<Symbol>::code)
<< " expected " << "Symbol (code = " << tvm::runtime::extension_type_info<Symbol>::code
<< ") but get code = " << ret.type_code();
*ret_symbol = *(static_cast<Symbol*>(ret.value().v_handle));
return true;
Expand Down
8 changes: 4 additions & 4 deletions python/tvm/_ffi/_ctypes/ndarray.py
Original file line number Diff line number Diff line change
Expand Up @@ -82,9 +82,9 @@ def _make_array(handle, is_view, is_container):
handle = ctypes.cast(handle, TVMArrayHandle)
fcreate = _CLASS_NDARRAY
if is_container and _TVM_ND_CLS:
array_type_index = ctypes.cast(handle, TVMNDArrayContainerHandle).array_type_index.value
if array_type_index > 0:
fcreate = _TVM_ND_CLS[array_type_index]
array_type_info = ctypes.cast(handle, TVMNDArrayContainerHandle).array_type_info.value
if array_type_info > 0:
fcreate = _TVM_ND_CLS[array_type_info]
return fcreate(handle, is_view)

_TVM_COMPATS = ()
Expand All @@ -101,7 +101,7 @@ def _reg_extension(cls, fcreate):

def _reg_ndarray(cls, fcreate):
global _TVM_ND_CLS
_TVM_ND_CLS[cls._array_type_index] = fcreate
_TVM_ND_CLS[cls._array_type_info] = fcreate

_CLASS_NDARRAY = None

Expand Down
2 changes: 1 addition & 1 deletion python/tvm/_ffi/_cython/base.pxi
Original file line number Diff line number Diff line change
Expand Up @@ -65,7 +65,7 @@ ctypedef struct TVMNDArrayContainer:
DLTensor dl_tensor
void* manager_ctx
void (*deleter)(DLManagedTensor* self)
int32_t array_type_index
int32_t array_type_info

ctypedef TVMNDArrayContainer* TVMNDArrayContainerHandle

Expand Down
10 changes: 5 additions & 5 deletions python/tvm/_ffi/_cython/ndarray.pxi
Original file line number Diff line number Diff line change
Expand Up @@ -75,12 +75,12 @@ cdef class NDArrayBase:

cdef c_make_array(void* chandle, is_view, is_container):
global _TVM_ND_CLS
cdef int32_t array_type_index
cdef int32_t array_type_info
fcreate = _CLASS_NDARRAY
if is_container and len(_TVM_ND_CLS) > 0:
array_type_index = (<TVMNDArrayContainerHandle>chandle).array_type_index
if array_type_index > 0:
fcreate = _TVM_ND_CLS[array_type_index]
array_type_info = (<TVMNDArrayContainerHandle>chandle).array_type_info
if array_type_info > 0:
fcreate = _TVM_ND_CLS[array_type_info]
ret = fcreate(None, is_view)
(<NDArrayBase>ret).chandle = <DLTensor*>chandle
return ret
Expand All @@ -100,7 +100,7 @@ cdef _TVM_ND_CLS = {}

def _reg_ndarray(cls, fcreate):
global _TVM_ND_CLS
_TVM_ND_CLS[cls._array_type_index] = fcreate
_TVM_ND_CLS[cls._array_type_info] = fcreate

def _make_array(handle, is_view, is_container):
cdef unsigned long long ptr
Expand Down
6 changes: 3 additions & 3 deletions python/tvm/_ffi/ndarray.py
Original file line number Diff line number Diff line change
Expand Up @@ -318,11 +318,11 @@ def register_extension(cls, fcreate=None):
The registered class is requires one property: _tvm_handle.

If the registered class is a subclass of NDArray,
it is required to have a class attribute _array_type_index.
it is required to have a class attribute _array_type_info.
Otherwise, it is required to have a class attribute _tvm_tcode.

- ```_tvm_handle``` returns integer represents the address of the handle.
- ```_tvm_tcode``` or ```_array_type_index``` gives integer represents type
- ```_tvm_tcode``` or ```_array_type_info``` gives integer represents type
code of the class.

Returns
Expand Down Expand Up @@ -350,7 +350,7 @@ def _tvm_handle(self):
"""
if issubclass(cls, _NDArrayBase):
assert fcreate is not None
assert hasattr(cls, "_array_type_index")
assert hasattr(cls, "_array_type_info")
_reg_ndarray(cls, fcreate)
else:
assert hasattr(cls, "_tvm_tcode")
Expand Down
2 changes: 1 addition & 1 deletion python/tvm/_ffi/runtime_ctypes.py
Original file line number Diff line number Diff line change
Expand Up @@ -246,6 +246,6 @@ class TVMNDArrayContainer(ctypes.Structure):
_fields_ = [("dl_tensor", TVMArray),
("manager_ctx", ctypes.c_void_p),
("deleter", ctypes.c_void_p),
("array_type_index", ctypes.c_int32)]
("array_type_info", ctypes.c_int32)]

TVMNDArrayContainerHandle = ctypes.POINTER(TVMNDArrayContainer)
2 changes: 1 addition & 1 deletion tests/cpp/packed_func_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -168,7 +168,7 @@ namespace tvm {
namespace runtime {

template<>
struct extension_class_info<test::IntVector> {
struct extension_type_info<test::IntVector> {
static const int code = kExtBegin + 1;
};
} // runtime
Expand Down