Skip to content

Commit

Permalink
[MXNET-23] Adding support to profile kvstore server during distribute…
Browse files Browse the repository at this point in the history
…d training (apache#11215)

* server profiling

merge with master

cleanup old code

added a check and better info message

add functions for C compatibility

fix doc

lint fixes

fix compile issues

lint fix

build error

update function signatures to preserve compatibility

fix comments

lint

* add part1 of test

* add integration test
  • Loading branch information
rahul003 authored and eric-haibin-lin committed Aug 4, 2018
1 parent 9d57660 commit 7f6099e
Show file tree
Hide file tree
Showing 12 changed files with 434 additions and 85 deletions.
1 change: 1 addition & 0 deletions ci/docker/runtime_functions.sh
Original file line number Diff line number Diff line change
Expand Up @@ -731,6 +731,7 @@ integrationtest_ubuntu_cpu_dist_kvstore() {
../../tools/launch.py -n 7 --launcher local python dist_sync_kvstore.py --no-multiprecision
../../tools/launch.py -n 7 --launcher local python dist_sync_kvstore.py --type=compressed_cpu
../../tools/launch.py -n 7 --launcher local python dist_sync_kvstore.py --type=compressed_cpu --no-multiprecision
../../tools/launch.py -n 3 --launcher local python test_server_profiling.py
}

integrationtest_ubuntu_gpu_scala() {
Expand Down
23 changes: 22 additions & 1 deletion example/image-classification/common/fit.py
Original file line number Diff line number Diff line change
Expand Up @@ -135,6 +135,12 @@ def add_fit_args(parser):
help='the epochs to ramp-up lr to scaled large-batch value')
train.add_argument('--warmup-strategy', type=str, default='linear',
help='the ramping-up strategy for large batch sgd')
train.add_argument('--profile-worker-suffix', type=str, default='',
help='profile workers actions into this file. During distributed training\
filename saved will be rank1_ followed by this suffix')
train.add_argument('--profile-server-suffix', type=str, default='',
help='profile server actions into a file with name like rank1_ followed by this suffix \
during distributed training')
return train


Expand All @@ -150,6 +156,17 @@ def fit(args, network, data_loader, **kwargs):
if args.gc_type != 'none':
kv.set_gradient_compression({'type': args.gc_type,
'threshold': args.gc_threshold})
if args.profile_server_suffix:
mx.profiler.set_config(filename=args.profile_server_suffix, profile_all=True, profile_process='server')
mx.profiler.set_state(state='run', profile_process='server')

if args.profile_worker_suffix:
if kv.num_workers > 1:
filename = 'rank' + str(kv.rank) + '_' + args.profile_worker_suffix
else:
filename = args.profile_worker_suffix
mx.profiler.set_config(filename=filename, profile_all=True, profile_process='worker')
mx.profiler.set_state(state='run', profile_process='worker')

# logging
head = '%(asctime)-15s Node[' + str(kv.rank) + '] %(message)s'
Expand Down Expand Up @@ -180,7 +197,6 @@ def fit(args, network, data_loader, **kwargs):
logging.info('Batch [%d]\tSpeed: %.2f samples/sec', i,
args.disp_batches * args.batch_size / (time.time() - tic))
tic = time.time()

return

# load model
Expand Down Expand Up @@ -314,3 +330,8 @@ def fit(args, network, data_loader, **kwargs):
epoch_end_callback=checkpoint,
allow_missing=True,
monitor=monitor)

if args.profile_server_suffix:
mx.profiler.set_state(state='run', profile_process='server')
if args.profile_worker_suffix:
mx.profiler.set_state(state='run', profile_process='worker')
59 changes: 52 additions & 7 deletions include/mxnet/c_api.h
Original file line number Diff line number Diff line change
Expand Up @@ -230,7 +230,19 @@ MXNET_DLL int MXRandomSeedContext(int seed, int dev_type, int dev_id);
MXNET_DLL int MXNotifyShutdown();

/*!
* \brief Set up configuration of profiler
* \brief Set up configuration of profiler for the process passed as profile_process in keys
* \param num_params Number of parameters
* \param keys array of parameter keys
* \param vals array of parameter values
* \param kvstoreHandle handle to kvstore
* \return 0 when success, -1 when failure happens.
*/
MXNET_DLL int MXSetProcessProfilerConfig(int num_params, const char* const* keys,
const char* const* vals,
KVStoreHandle kvstoreHandle);

/*!
* \brief Set up configuration of profiler for worker/current process
* \param num_params Number of parameters
* \param keys array of parameter keys
* \param vals array of parameter values
Expand All @@ -239,7 +251,21 @@ MXNET_DLL int MXNotifyShutdown();
MXNET_DLL int MXSetProfilerConfig(int num_params, const char* const* keys, const char* const* vals);

/*!
* \brief Set up state of profiler
* \brief Set up state of profiler for either worker or server process
* \param state indicate the working state of profiler,
* profiler not running when state == 0,
* profiler running when state == 1
* \param profile_process an int,
* when 0 command is for worker/current process,
* when 1 command is for server process
* \param kvstoreHandle handle to kvstore, needed for server process profiling
* \return 0 when success, -1 when failure happens.
*/
MXNET_DLL int MXSetProcessProfilerState(int state, int profile_process,
KVStoreHandle kvStoreHandle);

/*!
* \brief Set up state of profiler for current process
* \param state indicate the working state of profiler,
* profiler not running when state == 0,
* profiler running when state == 1
Expand All @@ -250,11 +276,22 @@ MXNET_DLL int MXSetProfilerState(int state);
/*!
* \brief Save profile and stop profiler
* \param finished true if stat output should stop after this point
* \param profile_process an int,
* when 0 command is for worker/current process,
* when 1 command is for server process
* \param kvstoreHandle handle to kvstore
* \return 0 when success, -1 when failure happens.
*/
MXNET_DLL int MXDumpProfile(int finished);
MXNET_DLL int MXDumpProcessProfile(int finished, int profile_process, KVStoreHandle kvStoreHandle);


/*!
* \brief Save profile and stop profiler for worker/current process
* \param finished true if stat output should stop after this point
* \return 0 when success, -1 when failure happens.
*/
MXNET_DLL int MXDumpProfile(int finished);

/*!
* \brief Print aggregate stats to the a string
* \param out_str Will receive a pointer to the output string
Expand All @@ -267,6 +304,16 @@ MXNET_DLL int MXAggregateProfileStatsPrint(const char **out_str, int reset);
/*!
* \brief Pause profiler tuning collection
* \param paused If nonzero, profiling pauses. Otherwise, profiling resumes/continues
* \param profile_process integer which denotes whether to process worker or server process
* \param kvstoreHandle handle to kvstore
* \return 0 when success, -1 when failure happens.
* \note pausing and resuming is global and not recursive
*/
MXNET_DLL int MXProcessProfilePause(int paused, int profile_process, KVStoreHandle kvStoreHandle);

/*!
* \brief Pause profiler tuning collection for worker/current process
* \param paused If nonzero, profiling pauses. Otherwise, profiling resumes/continues
* \return 0 when success, -1 when failure happens.
* \note pausing and resuming is global and not recursive
*/
Expand Down Expand Up @@ -2145,8 +2192,7 @@ typedef void (MXKVStoreServerController)(int head,
void *controller_handle);

/**
* \return Run as server (or scheduler)
*
* \brief Run as server (or scheduler)
* \param handle handle to the KVStore
* \param controller the user-defined server controller
* \param controller_handle helper handle for implementing controller
Expand All @@ -2157,8 +2203,7 @@ MXNET_DLL int MXKVStoreRunServer(KVStoreHandle handle,
void *controller_handle);

/**
* \return Send a command to all server nodes
*
* \brief Send a command to all server nodes
* \param handle handle to the KVStore
* \param cmd_id the head of the command
* \param cmd_body the body of the command
Expand Down
26 changes: 26 additions & 0 deletions include/mxnet/kvstore.h
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,18 @@
#endif // MXNET_USE_DIST_KVSTORE

namespace mxnet {

/*!
* \brief enum to denote types of commands kvstore sends to server regarding profiler
* kSetConfig sets profiler configs. Similar to mx.profiler.set_config()
* kState allows changing state of profiler to stop or run
* kPause allows pausing and resuming of profiler
* kDump asks profiler to dump output
*/
enum class KVStoreServerProfilerCommand {
kSetConfig, kState, kPause, kDump
};

/*!
* \brief distributed key-value store
*
Expand Down Expand Up @@ -364,6 +376,20 @@ class KVStore {
*/
virtual void SendCommandToServers(int cmd_id, const std::string& cmd_body) { }

/**
* \brief Sends server profiler commands to all server nodes
* Only the worker with rank=0 sends the command which will be received by all servers
* \param type ProfilerCommand type
* \param params parameters for that command in the form of a string
*/
virtual void SetServerProfilerCommand(const KVStoreServerProfilerCommand type,
const std::string& params) {
LOG(INFO) << "Unable to pass server the profiler command. If you are using "
<< "distributed kvstore, you need to compile with USE_DIST_KVSTORE=1."
<< "If you are training on single machine, then there is no server process"
<< "to profile. Please profile the worker process instead.";
}

/**
* \brief the prototype of a server controller
*/
Expand Down
8 changes: 6 additions & 2 deletions python/mxnet/kvstore.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@
from .base import check_call, string_types, mx_uint, py_str
from .base import NDArrayHandle, KVStoreHandle
from . import optimizer as opt
from .profiler import set_kvstore_handle

def _ctype_key_value(keys, vals):
"""
Expand Down Expand Up @@ -88,7 +89,8 @@ def _get_kvstore_server_command_type(command):
'kSetMultiPrecision': 1,
'kStopServer': 2,
'kSyncMode': 3,
'kSetGradientCompression': 4}
'kSetGradientCompression': 4,
'kSetProfilerParams': 5}
assert (command in command_types), "Unknown command type to send to server"
return command_types[command]

Expand Down Expand Up @@ -670,4 +672,6 @@ def create(name='local'):
handle = KVStoreHandle()
check_call(_LIB.MXKVStoreCreate(c_str(name),
ctypes.byref(handle)))
return KVStore(handle)
kv = KVStore(handle)
set_kvstore_handle(kv.handle)
return kv
79 changes: 63 additions & 16 deletions python/mxnet/profiler.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,8 +22,13 @@
from __future__ import absolute_import
import ctypes
import warnings
from .base import _LIB, check_call, c_str, ProfileHandle, c_str_array, py_str
from .base import _LIB, check_call, c_str, ProfileHandle, c_str_array, py_str, KVStoreHandle

profiler_kvstore_handle = KVStoreHandle()

def set_kvstore_handle(handle):
global profiler_kvstore_handle
profiler_kvstore_handle = handle

def set_config(**kwargs):
"""Set up the configure of profiler (only accepts keyword arguments).
Expand All @@ -49,12 +54,17 @@ def set_config(**kwargs):
aggregate_stats : boolean,
whether to maintain aggregate stats in memory for console
dump. Has some negative performance impact.
profile_process : string
whether to profile kvstore `server` or `worker`.
server can only be profiled when kvstore is of type dist.
if this is not passed, defaults to `worker`
"""
kk = kwargs.keys()
vv = kwargs.values()
check_call(_LIB.MXSetProfilerConfig(len(kwargs),
c_str_array([key for key in kk]),
c_str_array([str(val) for val in vv])))
check_call(_LIB.MXSetProcessProfilerConfig(len(kwargs),
c_str_array([key for key in kk]),
c_str_array([str(val) for val in vv]),
profiler_kvstore_handle))


def profiler_set_config(mode='symbolic', filename='profile.json'):
Expand All @@ -73,20 +83,27 @@ def profiler_set_config(mode='symbolic', filename='profile.json'):
keys = c_str_array([key for key in ["profile_" + mode, "filename"]])
values = c_str_array([str(val) for val in [True, filename]])
assert len(keys) == len(values)
check_call(_LIB.MXSetProfilerConfig(len(keys), keys, values))
check_call(_LIB.MXSetProcessProfilerConfig(len(keys), keys, values, profiler_kvstore_handle))


def set_state(state='stop'):
def set_state(state='stop', profile_process='worker'):
"""Set up the profiler state to 'run' or 'stop'.
Parameters
----------
state : string, optional
Indicates whether to run the profiler, can
be 'stop' or 'run'. Default is `stop`.
profile_process : string
whether to profile kvstore `server` or `worker`.
server can only be profiled when kvstore is of type dist.
if this is not passed, defaults to `worker`
"""
state2int = {'stop': 0, 'run': 1}
check_call(_LIB.MXSetProfilerState(ctypes.c_int(state2int[state])))
profile_process2int = {'worker': 0, 'server': 1}
check_call(_LIB.MXSetProcessProfilerState(ctypes.c_int(state2int[state]),
profile_process2int[profile_process],
profiler_kvstore_handle))


def profiler_set_state(state='stop'):
Expand All @@ -102,7 +119,7 @@ def profiler_set_state(state='stop'):
'Please use profiler.set_state() instead')
set_state(state)

def dump(finished=True):
def dump(finished=True, profile_process='worker'):
"""Dump profile and stop profiler. Use this to save profile
in advance in case your program cannot exit normally.
Expand All @@ -111,9 +128,16 @@ def dump(finished=True):
finished : boolean
Indicates whether to stop statistic output (dumping) after this dump.
Default is True
profile_process : string
whether to profile kvstore `server` or `worker`.
server can only be profiled when kvstore is of type dist.
if this is not passed, defaults to `worker`
"""
fin = 1 if finished is True else False
check_call(_LIB.MXDumpProfile(fin))
fin = 1 if finished is True else 0
profile_process2int = {'worker': 0, 'server': 1}
check_call(_LIB.MXDumpProcessProfile(fin,
profile_process2int[profile_process],
profiler_kvstore_handle))


def dump_profile():
Expand All @@ -138,14 +162,37 @@ def dumps(reset=False):
return py_str(debug_str.value)


def pause():
"""Pause profiling."""
check_call(_LIB.MXProfilePause(int(1)))
def pause(profile_process='worker'):
"""Pause profiling.
Parameters
----------
profile_process : string
whether to profile kvstore `server` or `worker`.
server can only be profiled when kvstore is of type dist.
if this is not passed, defaults to `worker`
"""
profile_process2int = {'worker': 0, 'server': 1}
check_call(_LIB.MXProcessProfilePause(int(1),
profile_process2int[profile_process],
profiler_kvstore_handle))


def resume(profile_process='worker'):
"""
Resume paused profiling.
def resume():
"""Resume paused profiling."""
check_call(_LIB.MXProfilePause(int(0)))
Parameters
----------
profile_process : string
whether to profile kvstore `server` or `worker`.
server can only be profiled when kvstore is of type dist.
if this is not passed, defaults to `worker`
"""
profile_process2int = {'worker': 0, 'server': 1}
check_call(_LIB.MXProcessProfilePause(int(0),
profile_process2int[profile_process],
profiler_kvstore_handle))


class Domain(object):
Expand Down
Loading

0 comments on commit 7f6099e

Please sign in to comment.