Skip to content
This repository has been archived by the owner on Nov 17, 2023. It is now read-only.

Commit

Permalink
C Api for simplebind, fix comment for trigoops, add atol to assert
Browse files Browse the repository at this point in the history
  • Loading branch information
ChaiBapchya committed Oct 23, 2019
1 parent 06b86da commit 36b5f24
Show file tree
Hide file tree
Showing 4 changed files with 220 additions and 39 deletions.
38 changes: 38 additions & 0 deletions include/mxnet/c_api.h
Original file line number Diff line number Diff line change
Expand Up @@ -2255,6 +2255,44 @@ MXNET_DLL int MXExecutorSimpleBindEx(SymbolHandle symbol_handle,
NDArrayHandle** aux_states,
ExecutorHandle shared_exec_handle,
ExecutorHandle* out);


MXNET_DLL int MXExecutorSimpleBindEx64(SymbolHandle symbol_handle,
int dev_type,
int dev_id,
const uint32_t num_g2c_keys,
const char** g2c_keys,
const int* g2c_dev_types,
const int* g2c_dev_ids,
const uint32_t provided_grad_req_list_len,
const char** provided_grad_req_names,
const char** provided_grad_req_types,
const uint32_t num_provided_arg_shapes,
const char** provided_arg_shape_names,
const int64_t* provided_arg_shape_data,
const uint32_t* provided_arg_shape_idx,
const uint32_t num_provided_arg_dtypes,
const char** provided_arg_dtype_names,
const int* provided_arg_dtypes,
const uint32_t num_provided_arg_stypes,
const char** provided_arg_stype_names,
const int* provided_arg_stypes,
const uint32_t num_shared_arg_names,
const char** shared_arg_name_list,
int* shared_buffer_len,
const char** shared_buffer_name_list,
NDArrayHandle* shared_buffer_handle_list,
const char*** updated_shared_buffer_name_list,
NDArrayHandle** updated_shared_buffer_handle_list,
uint32_t* num_in_args,
NDArrayHandle** in_args,
NDArrayHandle** arg_grads,
uint32_t* num_aux_states,
NDArrayHandle** aux_states,
ExecutorHandle shared_exec_handle,
ExecutorHandle* out);


/*!
* \brief DEPRECATED. Use MXExecutorReshapeEx instead.
* Return a new executor with the same symbol and shared memory,
Expand Down
110 changes: 74 additions & 36 deletions python/mxnet/symbol/symbol.py
Original file line number Diff line number Diff line change
Expand Up @@ -1695,42 +1695,80 @@ def simple_bind(self, ctx, grad_req='write', type_dict=None, stype_dict=None,
aux_state_handles = ctypes.POINTER(NDArrayHandle)()

try:
check_call(_LIB.MXExecutorSimpleBindEx(self.handle,
ctypes.c_int(ctx.device_typeid),
ctypes.c_int(ctx.device_id),
num_ctx_map_keys,
ctx_map_keys,
ctx_map_dev_types,
ctx_map_dev_ids,
mx_uint(provided_req_type_list_len),
provided_grad_req_names,
provided_grad_req_types,
mx_uint(len(provided_arg_shape_names)),
c_str_array(provided_arg_shape_names),
c_array_buf(mx_int,
array('I', provided_arg_shape_data)),
c_array_buf(mx_uint,
array('i', provided_arg_shape_idx)),
num_provided_arg_types,
provided_arg_type_names,
provided_arg_type_data,
num_provided_arg_stypes,
provided_arg_stype_names,
provided_arg_stype_data,
mx_uint(len(shared_arg_name_list)),
c_str_array(shared_arg_name_list),
ctypes.byref(shared_buffer_len),
shared_buffer_names,
shared_buffer_handles,
ctypes.byref(updated_shared_buffer_names),
ctypes.byref(updated_shared_buffer_handles),
ctypes.byref(num_in_args),
ctypes.byref(in_arg_handles),
ctypes.byref(arg_grad_handles),
ctypes.byref(num_aux_states),
ctypes.byref(aux_state_handles),
shared_exec_handle,
ctypes.byref(exe_handle)))
if sys.version_info[0] > 2 and _int64_enabled():
check_call(_LIB.MXExecutorSimpleBindEx64(self.handle,
ctypes.c_int(ctx.device_typeid),
ctypes.c_int(ctx.device_id),
num_ctx_map_keys,
ctx_map_keys,
ctx_map_dev_types,
ctx_map_dev_ids,
mx_uint(provided_req_type_list_len),
provided_grad_req_names,
provided_grad_req_types,
mx_uint(len(provided_arg_shape_names)),
c_str_array(provided_arg_shape_names),
c_array_buf(mx_int64,
array('q', provided_arg_shape_data)),
c_array_buf(mx_uint,
array('i', provided_arg_shape_idx)),
num_provided_arg_types,
provided_arg_type_names,
provided_arg_type_data,
num_provided_arg_stypes,
provided_arg_stype_names,
provided_arg_stype_data,
mx_uint(len(shared_arg_name_list)),
c_str_array(shared_arg_name_list),
ctypes.byref(shared_buffer_len),
shared_buffer_names,
shared_buffer_handles,
ctypes.byref(updated_shared_buffer_names),
ctypes.byref(updated_shared_buffer_handles),
ctypes.byref(num_in_args),
ctypes.byref(in_arg_handles),
ctypes.byref(arg_grad_handles),
ctypes.byref(num_aux_states),
ctypes.byref(aux_state_handles),
shared_exec_handle,
ctypes.byref(exe_handle)))
else:
check_call(_LIB.MXExecutorSimpleBindEx(self.handle,
ctypes.c_int(ctx.device_typeid),
ctypes.c_int(ctx.device_id),
num_ctx_map_keys,
ctx_map_keys,
ctx_map_dev_types,
ctx_map_dev_ids,
mx_uint(provided_req_type_list_len),
provided_grad_req_names,
provided_grad_req_types,
mx_uint(len(provided_arg_shape_names)),
c_str_array(provided_arg_shape_names),
c_array_buf(mx_int,
array('I', provided_arg_shape_data)),
c_array_buf(mx_uint,
array('i', provided_arg_shape_idx)),
num_provided_arg_types,
provided_arg_type_names,
provided_arg_type_data,
num_provided_arg_stypes,
provided_arg_stype_names,
provided_arg_stype_data,
mx_uint(len(shared_arg_name_list)),
c_str_array(shared_arg_name_list),
ctypes.byref(shared_buffer_len),
shared_buffer_names,
shared_buffer_handles,
ctypes.byref(updated_shared_buffer_names),
ctypes.byref(updated_shared_buffer_handles),
ctypes.byref(num_in_args),
ctypes.byref(in_arg_handles),
ctypes.byref(arg_grad_handles),
ctypes.byref(num_aux_states),
ctypes.byref(aux_state_handles),
shared_exec_handle,
ctypes.byref(exe_handle)))
except MXNetError as e:
error_msg = "simple_bind error. Arguments:\n"
for k, v in kwargs.items():
Expand Down
105 changes: 105 additions & 0 deletions src/c_api/c_api_executor.cc
Original file line number Diff line number Diff line change
Expand Up @@ -586,6 +586,111 @@ int MXExecutorSimpleBindEx(SymbolHandle symbol_handle,
NDArrayHandle** aux_states,
ExecutorHandle shared_exec_handle,
ExecutorHandle* out) {
return SimpleBindExMaster(symbol_handle,
dev_type, dev_id,
num_g2c_keys, g2c_keys, g2c_dev_types, g2c_dev_ids,
provided_grad_req_list_len, provided_grad_req_names, provided_grad_req_types,
num_provided_arg_shapes, provided_arg_shape_names,
provided_arg_shape_data, provided_arg_shape_idx,
num_provided_arg_dtypes, provided_arg_dtype_names, provided_arg_dtypes,
num_provided_arg_stypes, provided_arg_stype_names, provided_arg_stypes,
num_shared_arg_names, shared_arg_name_list,
shared_buffer_len, shared_buffer_name_list,
shared_buffer_handle_list, updated_shared_buffer_name_list,
updated_shared_buffer_handle_list,
num_in_args, in_args, arg_grads,
num_aux_states, aux_states,
shared_exec_handle, out)
}


int MXExecutorSimpleBindEx64(SymbolHandle symbol_handle,
int dev_type,
int dev_id,
const uint32_t num_g2c_keys,
const char** g2c_keys,
const int* g2c_dev_types,
const int* g2c_dev_ids,
const uint32_t provided_grad_req_list_len,
const char** provided_grad_req_names,
const char** provided_grad_req_types,
const uint32_t num_provided_arg_shapes,
const char** provided_arg_shape_names,
const int64_t* provided_arg_shape_data,
const uint32_t* provided_arg_shape_idx,
const uint32_t num_provided_arg_dtypes,
const char** provided_arg_dtype_names,
const int* provided_arg_dtypes,
const uint32_t num_provided_arg_stypes,
const char** provided_arg_stype_names,
const int* provided_arg_stypes,
const uint32_t num_shared_arg_names,
const char** shared_arg_name_list,
int* shared_buffer_len,
const char** shared_buffer_name_list,
NDArrayHandle* shared_buffer_handle_list,
const char*** updated_shared_buffer_name_list,
NDArrayHandle** updated_shared_buffer_handle_list,
uint32_t* num_in_args,
NDArrayHandle** in_args,
NDArrayHandle** arg_grads,
uint32_t* num_aux_states,
NDArrayHandle** aux_states,
ExecutorHandle shared_exec_handle,
ExecutorHandle* out) {
return SimpleBindExMaster(symbol_handle,
dev_type, dev_id,
num_g2c_keys, g2c_keys, g2c_dev_types, g2c_dev_ids,
provided_grad_req_list_len, provided_grad_req_names, provided_grad_req_types,
num_provided_arg_shapes, provided_arg_shape_names,
provided_arg_shape_data, provided_arg_shape_idx,
num_provided_arg_dtypes, provided_arg_dtype_names, provided_arg_dtypes,
num_provided_arg_stypes, provided_arg_stype_names, provided_arg_stypes,
num_shared_arg_names, shared_arg_name_list,
shared_buffer_len, shared_buffer_name_list,
shared_buffer_handle_list, updated_shared_buffer_name_list,
updated_shared_buffer_handle_list,
num_in_args, in_args, arg_grads,
num_aux_states, aux_states,
shared_exec_handle, out)
}


template<typename DType>
int SimpleBindExMaster(SymbolHandle symbol_handle,
int dev_type,
int dev_id,
const uint32_t num_g2c_keys,
const char** g2c_keys,
const int* g2c_dev_types,
const int* g2c_dev_ids,
const uint32_t provided_grad_req_list_len,
const char** provided_grad_req_names,
const char** provided_grad_req_types,
const uint32_t num_provided_arg_shapes,
const char** provided_arg_shape_names,
const DType* provided_arg_shape_data,
const uint32_t* provided_arg_shape_idx,
const uint32_t num_provided_arg_dtypes,
const char** provided_arg_dtype_names,
const int* provided_arg_dtypes,
const uint32_t num_provided_arg_stypes,
const char** provided_arg_stype_names,
const int* provided_arg_stypes,
const uint32_t num_shared_arg_names,
const char** shared_arg_name_list,
int* shared_buffer_len,
const char** shared_buffer_name_list,
NDArrayHandle* shared_buffer_handle_list,
const char*** updated_shared_buffer_name_list,
NDArrayHandle** updated_shared_buffer_handle_list,
uint32_t* num_in_args,
NDArrayHandle** in_args,
NDArrayHandle** arg_grads,
uint32_t* num_aux_states,
NDArrayHandle** aux_states,
ExecutorHandle shared_exec_handle,
ExecutorHandle* out) {
MXAPIThreadLocalEntry<> *ret = MXAPIThreadLocalStore<>::Get();
API_BEGIN();
nnvm::Symbol *sym = static_cast<nnvm::Symbol*>(symbol_handle);
Expand Down
6 changes: 3 additions & 3 deletions tests/nightly/test_large_array.py
Original file line number Diff line number Diff line change
Expand Up @@ -1295,17 +1295,17 @@ def check_trunc():


def create_input_for_trigonometric_ops(vals):
# Creates large vector input of size(LARGE_X*10, SMALL_Y/10) from vals using tile operator
# Creates large vector input of size(LARGE_X*10, SMALL_Y/10) from vals using broadcast_to operator
inp = nd.array(vals).reshape(1, 5)
inp = nd.broadcast_to(inp, (LARGE_X*10, SMALL_Y//10))
return inp


def assert_correctness_of_trigonometric_ops(output, expected_vals):
def assert_correctness_of_trigonometric_ops(output, expected_vals, atol=1e-3):
# checks verifies 5 values at positions(0, 1, -3, -2, -1) of the input vector
output_idx_to_inspect = [0, 1, -3, -2, -1]
for i in range(len(output_idx_to_inspect)):
assert np.abs(output[1][output_idx_to_inspect[i]].asnumpy()-expected_vals[i]) <= 1e-3
assert np.abs(output[1][output_idx_to_inspect[i]].asnumpy()-expected_vals[i]) <= atol


def test_trigonometric_ops():
Expand Down

0 comments on commit 36b5f24

Please sign in to comment.