// Review summary of the patch below (git unified diff touching two files).
//
// include/mxnet/lib_api.h
//   * MX_LIBRARY_VERSION changes from 1 to 2.
//   * enum MXDType gains kUNSET = 100; the default MXTensor constructor now
//     initializes dtype to kUNSET and verID to 0.
//   * MXTensor gains a size_t verID member. The value constructor takes it as
//     a new fourth argument (vID); the three-argument form is removed.
//   * New MXTensor::setTensor(dptr, type, dims, ndims, vID) assigns data_ptr,
//     dtype and verID, clears shape and refills it from dims[0..ndims), then
//     calls setDLTensor().
//   * New MXTensor::isSame(oth) returns true only when data_ptr, dtype, verID
//     and shape all compare equal.
//   * opCallFComp_t and opCallFStatefulComp_t each gain two size_t*
//     parameters, in the positions of inIDs/outIDs in the definitions below.
//   * MXLIB_OPCALLBKWD_STR and the opCallBkwd_t typedef are removed.
//   * _opCallFCompute and _opCallFStatefulCompute accept inIDs/outIDs and fill
//     each tensor through setTensor() instead of assigning the fields inline.
//
// src/c_api/c_api.cc (hunks inside MXLoadLib)
//   * Two call sites collect inputs[i].version() / outputs[i].version() into
//     in_verIDs / out_verIDs and pass their data() to callFComp and
//     callFStatefulComp.
//
// NOTE(review): this copy of the patch has lost its line breaks, and text in
//   angle brackets looks stripped (`#include #include`, `std::vector &shape`,
//   `reinterpret_cast(state_op)`); it presumably will not apply as-is --
//   recover the original commit before using it.
// NOTE(review): the element type of in_verIDs/out_verIDs is not visible here;
//   it has to be compatible with the size_t* parameters -- TODO confirm.
// NOTE(review): the exported entry-point signatures change, presumably the
//   reason for the version bump; confirm the loader rejects libraries built
//   against version 1.
// NOTE(review): the new vID constructor argument has no default, so any
//   three-argument MXTensor construction outside this diff would stop
//   compiling -- TODO check callers.
diff --git a/include/mxnet/lib_api.h b/include/mxnet/lib_api.h index f21e484216ea..c7887aad378f 100644 --- a/include/mxnet/lib_api.h +++ b/include/mxnet/lib_api.h @@ -37,7 +37,7 @@ #include #include -#define MX_LIBRARY_VERSION 1 +#define MX_LIBRARY_VERSION 2 /* * Import from DLPack https://github.com/dmlc/dlpack/blob/master/include/dlpack/dlpack.h @@ -198,6 +198,7 @@ enum MXDType { kInt32 = 4, kInt8 = 5, kInt64 = 6, + kUNSET = 100, }; enum MXReturnValue { @@ -209,10 +210,22 @@ enum MXReturnValue { * \brief Tensor data structure used by custom operator */ struct MXTensor { - MXTensor() : data_ptr(NULL) {} - - MXTensor(void *data_ptr, const std::vector &shape, MXDType dtype) - : data_ptr(data_ptr), shape(shape), dtype(dtype) {} + MXTensor() : data_ptr(NULL), dtype(kUNSET), verID(0) {} + + MXTensor(void *data_ptr, const std::vector &shape, MXDType dtype, + size_t vID) + : data_ptr(data_ptr), shape(shape), dtype(dtype), verID(vID) {} + + /*! \brief populate internal tensor fields */ + void setTensor(void *dptr, MXDType type, const int64_t* dims, + int ndims, size_t vID) { + data_ptr = dptr; dtype = type; verID = vID; + shape.clear(); + for (int j = 0; j < ndims; j++) { + shape.push_back(dims[j]); + } + setDLTensor(); + } /*! \brief populate DLTensor fields */ void setDLTensor() { @@ -277,6 +290,14 @@ struct MXTensor { return size; } + /*! 
\brief helper function to compare two MXTensors */ + inline bool isSame(const MXTensor &oth) const { + return data_ptr == oth.data_ptr && + dtype == oth.dtype && + verID == oth.verID && + shape == oth.shape; + } + // data is flatten 1D repr of tensor, elements are in continuous memory // user can access each element using the shape of tensor void *data_ptr; @@ -287,6 +308,9 @@ struct MXTensor { // type can only be MXDType enum types MXDType dtype; + // version number updated if the tensor has changed since the last use by custom op + size_t verID; + // corresponding DLTensor repr of MXTensor // easy way to reuse functions taking DLTensor DLTensor dltensor; @@ -684,15 +708,9 @@ typedef int (*opCallInferType_t)(inferType_t, const char* const*, const char* co #define MXLIB_OPCALLFCOMP_STR "_opCallFCompute" typedef int (*opCallFComp_t)(fcomp_t, const char* const*, const char* const*, int, - const int64_t**, int*, void**, int*, int, - const int64_t**, int*, void**, int*, int, - xpu_malloc_t, void*); - -#define MXLIB_OPCALLBKWD_STR "_opCallBackward" -typedef int (*opCallBkwd_t)(fcomp_t, const char* const*, const char* const*, int, - const int64_t**, int*, void**, int*, int, - const int64_t**, int*, void**, int*, int, - xpu_malloc_t, void*); + const int64_t**, int*, void**, int*, size_t*, int, + const int64_t**, int*, void**, int*, size_t*, int, + xpu_malloc_t, void*); #define MXLIB_OPCALLMUTATEINPUTS_STR "_opCallMutateInputs" typedef int (*opCallMutateInputs_t)(mutateInputs_t, const char* const*, const char* const*, int, @@ -703,9 +721,9 @@ typedef int (*opCallCreateOpState_t)(createOpState_t, const char* const*, const void**); #define MXLIB_OPCALLFSTATEFULCOMP_STR "_opCallFStatefulCompute" -typedef int (*opCallFStatefulComp_t)(bool, void*, const int64_t**, int*, void**, int*, int, - const int64_t**, int*, void**, int*, int, - xpu_malloc_t, void*); +typedef int (*opCallFStatefulComp_t)(bool, void*, const int64_t**, int*, void**, int*, size_t*, + int, const int64_t**, 
int*, void**, int*, size_t*, + int, xpu_malloc_t, void*); #define MXLIB_INITIALIZE_STR "initialize" typedef int (*initialize_t)(int); @@ -876,9 +894,9 @@ extern "C" { _opCallFCompute(fcomp_t fcomp, const char* const* keys, const char* const* vals, int num, const int64_t** inshapes, int* indims, - void** indata, int* intypes, int num_in, + void** indata, int* intypes, size_t* inIDs, int num_in, const int64_t** outshapes, int* outdims, - void** outdata, int* outtypes, int num_out, + void** outdata, int* outtypes, size_t* outIDs, int num_out, xpu_malloc_t cpu_malloc, void* cpu_alloc) { // create map of attributes from list std::map attrs; @@ -889,23 +907,14 @@ extern "C" { // create a vector of tensors for inputs std::vector inputs(num_in); for (int i = 0; i < num_in; i++) { - inputs[i].data_ptr = indata[i]; - inputs[i].dtype = (MXDType)intypes[i]; - for (int j = 0; j < indims[i]; j++) { - inputs[i].shape.push_back(inshapes[i][j]); - } - inputs[i].setDLTensor(); + inputs[i].setTensor(indata[i], (MXDType)intypes[i], inshapes[i], indims[i], inIDs[i]); } // create a vector of tensors for outputs std::vector outputs(num_out); for (int i = 0; i < num_out; i++) { - outputs[i].data_ptr = outdata[i]; - outputs[i].dtype = (MXDType) outtypes[i]; - for (int j = 0; j < outdims[i]; j++) { - outputs[i].shape.push_back(outshapes[i][j]); - } - outputs[i].setDLTensor(); + outputs[i].setTensor(outdata[i], (MXDType)outtypes[i], outshapes[i], outdims[i], + outIDs[i]); } OpResource res(cpu_malloc, cpu_alloc); @@ -973,30 +982,21 @@ extern "C" { #endif _opCallFStatefulCompute(bool is_forward, void* state_op, const int64_t** inshapes, int* indims, - void** indata, int* intypes, int num_in, + void** indata, int* intypes, size_t* inIDs, int num_in, const int64_t** outshapes, int* outdims, - void** outdata, int* outtypes, int num_out, + void** outdata, int* outtypes, size_t* outIDs, int num_out, xpu_malloc_t cpu_malloc, void* cpu_alloc) { // create a vector of tensors for inputs std::vector 
inputs(num_in); for (int i = 0; i < num_in; i++) { - inputs[i].data_ptr = indata[i]; - inputs[i].dtype = (MXDType)intypes[i]; - for (int j = 0; j < indims[i]; j++) { - inputs[i].shape.push_back(inshapes[i][j]); - } - inputs[i].setDLTensor(); + inputs[i].setTensor(indata[i], (MXDType)intypes[i], inshapes[i], indims[i], inIDs[i]); } // create a vector of tensors for outputs std::vector outputs(num_out); for (int i = 0; i < num_out; i++) { - outputs[i].data_ptr = outdata[i]; - outputs[i].dtype = (MXDType) outtypes[i]; - for (int j = 0; j < outdims[i]; j++) { - outputs[i].shape.push_back(outshapes[i][j]); - } - outputs[i].setDLTensor(); + outputs[i].setTensor(outdata[i], (MXDType)outtypes[i], outshapes[i], outdims[i], + outIDs[i]); } OpResource res(cpu_malloc, cpu_alloc); CustomStatefulOp* op_ptr = reinterpret_cast(state_op); diff --git a/src/c_api/c_api.cc b/src/c_api/c_api.cc index ca39ef22b20a..bb98d96f733a 100644 --- a/src/c_api/c_api.cc +++ b/src/c_api/c_api.cc @@ -395,6 +395,7 @@ int MXLoadLib(const char *path) { std::vector in_shapes, out_shapes; std::vector in_dims, out_dims; std::vector in_types, out_types; + std::vector in_verIDs, out_verIDs; // convert input tensors to constituent parts for (size_t i = 0; i < inputs.size(); i++) { @@ -402,6 +403,7 @@ int MXLoadLib(const char *path) { in_shapes.push_back(inputs[i].shape().data()); in_dims.push_back(inputs[i].shape().ndim()); in_types.push_back(inputs[i].dtype()); + in_verIDs.push_back(inputs[i].version()); } // convert output tensors to constituent parts @@ -410,6 +412,7 @@ int MXLoadLib(const char *path) { out_shapes.push_back(outputs[i].shape().data()); out_dims.push_back(outputs[i].shape().ndim()); out_types.push_back(outputs[i].dtype()); + out_verIDs.push_back(outputs[i].version()); } // get memory resource @@ -438,9 +441,10 @@ int MXLoadLib(const char *path) { // call fcompute function CHECK(callFComp(fcomp_fp, attr_keys.data(), attr_vals.data(), attr_keys.size(), in_shapes.data(), in_dims.data(), 
in_data.data(), - in_types.data(), in_data.size(), + in_types.data(), in_verIDs.data(), in_data.size(), out_shapes.data(), out_dims.data(), out_data.data(), - out_types.data(), out_data.size(), cpu_malloc, &cpu_alloc)) + out_types.data(), out_verIDs.data(), out_data.size(), + cpu_malloc, &cpu_alloc)) << "Error calling FCompute for custom operator '" << name_str << "'"; // return type void @@ -570,6 +574,7 @@ int MXLoadLib(const char *path) { std::vector in_shapes, out_shapes; std::vector in_dims, out_dims; std::vector in_types, out_types; + std::vector in_verIDs, out_verIDs; // convert input tensors to constituent parts for (size_t i = 0; i < inputs.size(); i++) { @@ -577,6 +582,7 @@ int MXLoadLib(const char *path) { in_shapes.push_back(inputs[i].shape().data()); in_dims.push_back(inputs[i].shape().ndim()); in_types.push_back(inputs[i].dtype()); + in_verIDs.push_back(inputs[i].version()); } // convert output tensors to constituent parts @@ -585,6 +591,7 @@ int MXLoadLib(const char *path) { out_shapes.push_back(outputs[i].shape().data()); out_dims.push_back(outputs[i].shape().ndim()); out_types.push_back(outputs[i].dtype()); + out_verIDs.push_back(outputs[i].version()); } // get memory resource @@ -618,9 +625,9 @@ int MXLoadLib(const char *path) { // call fcompute function CHECK(callFStatefulComp(is_forward, state_op_inst, in_shapes.data(), in_dims.data(), - in_data.data(), in_types.data(), in_data.size(), - out_shapes.data(), out_dims.data(), out_data.data(), - out_types.data(), out_data.size(), cpu_malloc, &cpu_alloc)) + in_data.data(), in_types.data(), in_verIDs.data(), in_data.size(), + out_shapes.data(), out_dims.data(), out_data.data(), out_types.data(), + out_verIDs.data(), out_data.size(), cpu_malloc, &cpu_alloc)) << "Error calling FStatefulCompute for custom operator '" << name_str << "'"; };