Skip to content
This repository has been archived by the owner on Nov 17, 2023. It is now read-only.

[Backport][v1.5.x] Fix the bug of MXEnginePushAsyncND and MXEnginePushSyncND #15775

Closed
wants to merge 2 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
22 changes: 11 additions & 11 deletions include/mxnet/c_api.h
Original file line number Diff line number Diff line change
Expand Up @@ -2863,12 +2863,12 @@ MXNET_DLL int MXEnginePushSync(EngineSyncFunc sync_func, void* func_param,
* \param wait Whether this is a WaitForVar operation.
*/
MXNET_DLL int MXEnginePushAsyncND(EngineAsyncFunc async_func, void* func_param,
EngineFuncParamDeleter deleter, ContextHandle ctx_handle,
NDArrayHandle const_nds_handle, int num_const_nds,
NDArrayHandle mutable_nds_handle, int num_mutable_nds,
EngineFnPropertyHandle prop_handle DEFAULT(NULL),
int priority DEFAULT(0), const char* opr_name DEFAULT(NULL),
bool wait DEFAULT(false));
EngineFuncParamDeleter deleter, ContextHandle ctx_handle,
NDArrayHandle* const_nds_handle, int num_const_nds,
NDArrayHandle* mutable_nds_handle, int num_mutable_nds,
EngineFnPropertyHandle prop_handle DEFAULT(NULL),
int priority DEFAULT(0), const char* opr_name DEFAULT(NULL),
bool wait DEFAULT(false));

/*!
* \brief Push a synchronous operation to the engine.
Expand All @@ -2886,11 +2886,11 @@ MXNET_DLL int MXEnginePushAsyncND(EngineAsyncFunc async_func, void* func_param,
* \param opr_name The operation name.
*/
MXNET_DLL int MXEnginePushSyncND(EngineSyncFunc sync_func, void* func_param,
EngineFuncParamDeleter deleter, ContextHandle ctx_handle,
NDArrayHandle const_nds_handle, int num_const_nds,
NDArrayHandle mutable_nds_handle, int num_mutable_nds,
EngineFnPropertyHandle prop_handle DEFAULT(NULL),
int priority DEFAULT(0), const char* opr_name DEFAULT(NULL));
EngineFuncParamDeleter deleter, ContextHandle ctx_handle,
NDArrayHandle* const_nds_handle, int num_const_nds,
NDArrayHandle* mutable_nds_handle, int num_mutable_nds,
EngineFnPropertyHandle prop_handle DEFAULT(NULL),
int priority DEFAULT(0), const char* opr_name DEFAULT(NULL));

#ifdef __cplusplus
}
Expand Down
40 changes: 20 additions & 20 deletions src/c_api/c_api.cc
Original file line number Diff line number Diff line change
Expand Up @@ -1535,18 +1535,18 @@ int MXEnginePushSync(EngineSyncFunc sync_func, void* func_param,
}

int MXEnginePushAsyncND(EngineAsyncFunc async_func, void* func_param,
EngineFuncParamDeleter deleter, ContextHandle ctx_handle,
NDArrayHandle const_nds_handle, int num_const_nds,
NDArrayHandle mutable_nds_handle, int num_mutable_nds,
EngineFnPropertyHandle prop_handle, int priority,
const char* opr_name, bool wait) {
API_BEGIN();
NDArray* const_nds = static_cast<NDArray*>(const_nds_handle);
NDArray* mutable_nds = static_cast<NDArray*>(mutable_nds_handle);
EngineFuncParamDeleter deleter, ContextHandle ctx_handle,
NDArrayHandle* const_nds_handle, int num_const_nds,
NDArrayHandle* mutable_nds_handle, int num_mutable_nds,
EngineFnPropertyHandle prop_handle, int priority,
const char* opr_name, bool wait) {
API_BEGIN();
NDArray** const_nds = reinterpret_cast<NDArray**>(const_nds_handle);
NDArray** mutable_nds = reinterpret_cast<NDArray**>(mutable_nds_handle);
std::vector<VarHandle> const_var_vec(num_const_nds);
for (int i = 0; i < num_const_nds; ++i) const_var_vec[i] = (const_nds+i)->var();
for (int i = 0; i < num_const_nds; ++i) const_var_vec[i] = const_nds[i]->var();
std::vector<VarHandle> mutable_var_vec(num_mutable_nds);
for (int i = 0; i < num_mutable_nds; ++i) mutable_var_vec[i] = (mutable_nds+i)->var();
for (int i = 0; i < num_mutable_nds; ++i) mutable_var_vec[i] = mutable_nds[i]->var();
return MXEnginePushAsync(async_func, func_param, deleter, ctx_handle,
const_var_vec.data(), num_const_nds,
mutable_var_vec.data(), num_mutable_nds,
Expand All @@ -1555,18 +1555,18 @@ int MXEnginePushAsyncND(EngineAsyncFunc async_func, void* func_param,
}

int MXEnginePushSyncND(EngineSyncFunc sync_func, void* func_param,
EngineFuncParamDeleter deleter, ContextHandle ctx_handle,
NDArrayHandle const_nds_handle, int num_const_nds,
NDArrayHandle mutable_nds_handle, int num_mutable_nds,
EngineFnPropertyHandle prop_handle, int priority,
const char* opr_name) {
API_BEGIN();
NDArray* const_nds = static_cast<NDArray*>(const_nds_handle);
NDArray* mutable_nds = static_cast<NDArray*>(mutable_nds_handle);
EngineFuncParamDeleter deleter, ContextHandle ctx_handle,
NDArrayHandle* const_nds_handle, int num_const_nds,
NDArrayHandle* mutable_nds_handle, int num_mutable_nds,
EngineFnPropertyHandle prop_handle, int priority,
const char* opr_name) {
API_BEGIN();
NDArray** const_nds = reinterpret_cast<NDArray**>(const_nds_handle);
NDArray** mutable_nds = reinterpret_cast<NDArray**>(mutable_nds_handle);
std::vector<VarHandle> const_var_vec(num_const_nds);
for (int i = 0; i < num_const_nds; ++i) const_var_vec[i] = (const_nds+i)->var();
for (int i = 0; i < num_const_nds; ++i) const_var_vec[i] = const_nds[i]->var();
std::vector<VarHandle> mutable_var_vec(num_mutable_nds);
for (int i = 0; i < num_mutable_nds; ++i) mutable_var_vec[i] = (mutable_nds+i)->var();
for (int i = 0; i < num_mutable_nds; ++i) mutable_var_vec[i] = mutable_nds[i]->var();
return MXEnginePushSync(sync_func, func_param, deleter, ctx_handle,
const_var_vec.data(), num_const_nds,
mutable_var_vec.data(), num_mutable_nds,
Expand Down
117 changes: 74 additions & 43 deletions tests/cpp/engine/threaded_engine_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -257,49 +257,80 @@ TEST(Engine, PushFunc) {

TEST(Engine, PushFuncND) {
auto ctx = mxnet::Context{};
mxnet::NDArray nd(ctx);

// Test #1
LOG(INFO) << "===== Test #1: PushAsyncND param and deleter =====";
int* a = new int(100);
int res = MXEnginePushAsyncND(FooAsyncFunc, a, FooFuncDeleter, &ctx, &nd, 1, nullptr, 0);
EXPECT_EQ(res, 0);

// Test #2
LOG(INFO) << "===== Test #2: PushAsyncND NULL param and NULL deleter =====";
res = MXEnginePushAsyncND(FooAsyncFunc, nullptr, nullptr, &ctx, nullptr, 0, &nd, 0);
EXPECT_EQ(res, 0);

// Test #3
LOG(INFO) << "===== Test #3: PushAsyncND invalid number of const nds =====";
res = MXEnginePushAsyncND(FooAsyncFunc, nullptr, nullptr, &ctx, &nd, -1, nullptr, 0);
EXPECT_EQ(res, -1);

// Test #4
LOG(INFO) << "===== Test #4: PushAsyncND invalid number of mutable nds =====";
res = MXEnginePushAsyncND(FooAsyncFunc, nullptr, nullptr, &ctx, nullptr, 0, &nd, -1);
EXPECT_EQ(res, -1);

// Test #5
LOG(INFO) << "===== Test #5: PushSyncND param and deleter =====";
int* b = new int(101);
res = MXEnginePushSyncND(FooSyncFunc, b, FooFuncDeleter, &ctx, &nd, 1, nullptr, 0);
EXPECT_EQ(res, 0);

// Test #6
LOG(INFO) << "===== Test #6: PushSyncND NULL param and NULL deleter =====";
res = MXEnginePushSyncND(FooSyncFunc, nullptr, nullptr, &ctx, nullptr, 0, &nd, 1);
EXPECT_EQ(res, 0);

// Test #7
LOG(INFO) << "===== Test #7: PushSyncND invalid number of const nds =====";
res = MXEnginePushSyncND(FooSyncFunc, nullptr, nullptr, &ctx, &nd, -1, nullptr, 0);
EXPECT_EQ(res, -1);

// Test #8
LOG(INFO) << "===== Test #8: PushSyncND invalid number of mutable nds =====";
res = MXEnginePushSyncND(FooSyncFunc, nullptr, nullptr, &ctx, nullptr, 0, &nd, -1);
EXPECT_EQ(res, -1);
std::vector<mxnet::NDArray*> nds;
const int num_nds = 5;
for (int i = 0; i < num_nds; ++i) {
mxnet::NDArray *pnd = new mxnet::NDArray(ctx);
nds.push_back(pnd);
}
for (int num_const_nds = 0; num_const_nds <= num_nds; ++num_const_nds) {
int num_mutable_nds = num_nds - num_const_nds;
void** const_nds_handle = num_const_nds > 0 ?
reinterpret_cast<void**>(nds.data()) : nullptr;
void** mutable_nds_handle = num_mutable_nds > 0 ?
reinterpret_cast<void**>(nds.data() + num_const_nds) : nullptr;

// Test #1
LOG(INFO) << "===== Test #1: PushAsyncND param and deleter =====";
int* a = new int(100);
int res = MXEnginePushAsyncND(FooAsyncFunc, a, FooFuncDeleter, &ctx,
const_nds_handle, num_const_nds,
mutable_nds_handle, num_mutable_nds);
EXPECT_EQ(res, 0);

// Test #2
LOG(INFO) << "===== Test #2: PushAsyncND NULL param and NULL deleter =====";
res = MXEnginePushAsyncND(FooAsyncFunc, nullptr, nullptr, &ctx,
const_nds_handle, num_const_nds,
mutable_nds_handle, num_mutable_nds);
EXPECT_EQ(res, 0);

// Test #3
LOG(INFO) << "===== Test #3: PushAsyncND invalid number of const nds =====";
res = MXEnginePushAsyncND(FooAsyncFunc, nullptr, nullptr, &ctx,
const_nds_handle, -1,
mutable_nds_handle, num_mutable_nds);
EXPECT_EQ(res, -1);

// Test #4
LOG(INFO) << "===== Test #4: PushAsyncND invalid number of mutable nds =====";
res = MXEnginePushAsyncND(FooAsyncFunc, nullptr, nullptr, &ctx,
const_nds_handle, num_const_nds,
mutable_nds_handle, -1);
EXPECT_EQ(res, -1);

// Test #5
LOG(INFO) << "===== Test #5: PushSyncND param and deleter =====";
int* b = new int(101);
res = MXEnginePushSyncND(FooSyncFunc, b, FooFuncDeleter, &ctx,
const_nds_handle, num_const_nds,
mutable_nds_handle, num_mutable_nds);
EXPECT_EQ(res, 0);

// Test #6
LOG(INFO) << "===== Test #6: PushSyncND NULL param and NULL deleter =====";
res = MXEnginePushSyncND(FooSyncFunc, nullptr, nullptr, &ctx,
const_nds_handle, num_const_nds,
mutable_nds_handle, num_mutable_nds);
EXPECT_EQ(res, 0);

// Test #7
LOG(INFO) << "===== Test #7: PushSyncND invalid number of const nds =====";
res = MXEnginePushSyncND(FooSyncFunc, nullptr, nullptr, &ctx,
const_nds_handle, -1,
mutable_nds_handle, num_mutable_nds);
EXPECT_EQ(res, -1);

// Test #8
LOG(INFO) << "===== Test #8: PushSyncND invalid number of mutable nds =====";
res = MXEnginePushSyncND(FooSyncFunc, nullptr, nullptr, &ctx,
const_nds_handle, num_const_nds,
mutable_nds_handle, -1);
EXPECT_EQ(res, -1);
}
for (mxnet::NDArray* pnd : nds) {
delete pnd;
}
}

TEST(Engine, basics) {
Expand Down