This repository has been archived by the owner on Nov 17, 2023. It is now read-only.

Commit b60b496

Merge pull request #5 from dmlc/master
merge dmlc/master
mli committed Sep 22, 2015
2 parents d4f2674 + 3e625ae commit b60b496
Showing 2 changed files with 5 additions and 3 deletions.
5 changes: 3 additions & 2 deletions python/mxnet/model.py
@@ -207,8 +207,9 @@ def _train_multi_device(symbol, ctx, input_shape,
     # If there are multiple devices, initialize the weights.
     for index, pair in enumerate(zip(arg_blocks, grad_blocks)):
         arg_list, grad_list = pair
-        if kv and grad_list[0] is not None:
-            kv.init(index, arg_list[0])
+        if grad_list[0] is not None:
+            if kv:
+                kv.init(index, arg_list[0])
         # attach state direct to weight
         opt_list = [optimizer.create_state(index, w) for w in arg_list]
         opt_state_blocks.append(opt_list)
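
The rewritten guard checks the gradient first and only then consults the kvstore, so the per-parameter branch no longer depends on `kv` being set; `kv.init` is still called only when both conditions hold. A minimal, self-contained sketch of the reordered control flow, using toy stand-ins (`FakeKVStore`, `init_params`, and the small `arg_blocks`/`grad_blocks` lists are illustrative, not part of MXNet):

```python
# Toy sketch of the reordered guard; kv=None models local training
# without a kvstore, as in a single-machine run.

class FakeKVStore:
    """Records which parameter indices were pushed via init()."""
    def __init__(self):
        self.inited = []

    def init(self, index, value):
        self.inited.append(index)


def init_params(arg_blocks, grad_blocks, kv=None):
    """Mirrors the new control flow: check the gradient first, then kv."""
    for index, (arg_list, grad_list) in enumerate(zip(arg_blocks, grad_blocks)):
        if grad_list[0] is not None:
            if kv:
                kv.init(index, arg_list[0])


# Two parameters: the first is trainable, the second has no gradient.
arg_blocks = [[1.0], [2.0]]
grad_blocks = [[0.5], [None]]

kv = FakeKVStore()
init_params(arg_blocks, grad_blocks, kv=kv)
print(kv.inited)                       # [0] -- only the trainable weight is pushed
init_params(arg_blocks, grad_blocks)   # kv=None: no kvstore calls, no error
```
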
3 changes: 2 additions & 1 deletion tests/python/train/test_mlp.py
@@ -18,7 +18,8 @@
 
 num_round = 4
 prefix = './mlp'
-model = mx.model.FeedForward(softmax, [mx.cpu()] * 2,
+model = mx.model.FeedForward(softmax,
+                             [mx.cpu(i) for i in range(2)],
                              num_round=num_round,
                              learning_rate=0.01, wd=0.0004,
                              momentum=0.9)
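
The test now builds the context list with distinct device ids rather than repeating the default CPU context. A short illustrative snippet of the difference, assuming only that `mxnet` is importable as in the test above:

```python
import mxnet as mx

# [mx.cpu()] * 2 repeats one default context twice, while the
# comprehension from the updated test enumerates device ids.
repeated = [mx.cpu()] * 2                  # two references to cpu(0)
distinct = [mx.cpu(i) for i in range(2)]   # cpu(0) and cpu(1)
print(repeated, distinct)
```

Passing per-device contexts is presumably what lets this test exercise the multi-device path touched in `_train_multi_device` above.
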
