Skip to content
This repository has been archived by the owner on Nov 17, 2023. It is now read-only.

Commit

Permalink
auto select update on kvstore
Browse files Browse the repository at this point in the history
  • Loading branch information
tqchen committed Sep 28, 2015
1 parent e373355 commit e335a26
Showing 1 changed file with 15 additions and 5 deletions.
20 changes: 15 additions & 5 deletions python/mxnet/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -122,7 +122,7 @@ def _train_multi_device(symbol, ctx, input_shape,
begin_round, end_round, optimizer,
train_data, eval_data=None, eval_metric=None,
iter_end_callback=None, epoch_end_callback=None,
update_on_kvstore=False,
update_on_kvstore=None,
logger=None):
"""Internal training function on multiple devices.
Expand Down Expand Up @@ -183,8 +183,9 @@ def _train_multi_device(symbol, ctx, input_shape,
-----
- This function will in-place update the NDArrays in arg_params and aux_states.
- Turning update_on_kvstore on and off can affect speed of multi-gpu training.
- update_on_kvstore=True works well for inception type nets that contains many small weights.
- update_on_kvstore=False works better for Alexnet style net with bulk weights.
- It is auto selected by default.
- update_on_kvstore=True works well for inception type nets that contain many small weights.
- update_on_kvstore=False works better for Alexnet style net with bulk weights.
"""
if logger is None:
logger = logging
Expand All @@ -210,10 +211,17 @@ def _train_multi_device(symbol, ctx, input_shape,

for texec in train_execs:
texec.copy_params_from(arg_params, aux_params)

# key value store
kv = kvstore.create() if num_device != 1 else None
if kv is None:
update_on_kvstore = False
else:
# auto decide update_on_kvstore
if update_on_kvstore is None:
max_size = max(np.prod(param.shape) for param in arg_params.values())
update_on_kvstore = max_size < 1024 * 1024 * 16
logging.info('Auto-select update_on_kvstore=%s', str(update_on_kvstore))

opt_state_blocks = []
# If there are multiple devices, initialize the weights.
Expand Down Expand Up @@ -586,7 +594,7 @@ def predict(self, X):

def fit(self, X, y=None, eval_data=None, eval_metric='acc',
iter_end_callback=None, epoch_end_callback=None,
update_on_kvstore=False, logger=None):
update_on_kvstore=None, logger=None):
"""Fit the model.
Parameters
Expand Down Expand Up @@ -618,6 +626,7 @@ def fit(self, X, y=None, eval_data=None, eval_metric='acc',
update_on_kvstore: boolean, optional
Whether to perform parameter update on kvstore instead of training device.
By default, the trainer will automatically decide the policy.
logger : logging logger, optional
When not specified, default logger will be used.
Expand Down Expand Up @@ -711,7 +720,7 @@ def load(prefix, iteration, ctx=None):
def create(symbol, X, y=None, ctx=None,
num_round=None, optimizer='sgd', initializer=Xavier(),
eval_data=None, eval_metric='acc', iter_end_callback=None,
update_on_kvstore=False, logger=None, **kwargs):
update_on_kvstore=None, logger=None, **kwargs):
"""Functional style to create a model.
This function will be more consistent with functional
Expand Down Expand Up @@ -755,6 +764,7 @@ def create(symbol, X, y=None, ctx=None,
update_on_kvstore: boolean, optional
Whether to perform parameter update on kvstore instead of training device.
By default, the trainer will automatically decide the policy.
logger : logging logger, optional
"""
Expand Down

0 comments on commit e335a26

Please sign in to comment.