From 88d45456ec0772b6126d3d664006fc4eae590329 Mon Sep 17 00:00:00 2001 From: tqchen Date: Mon, 21 Sep 2015 16:31:51 -0700 Subject: [PATCH] BUGFIX --- python/mxnet/model.py | 5 +++-- tests/python/train/test_mlp.py | 3 ++- 2 files changed, 5 insertions(+), 3 deletions(-) diff --git a/python/mxnet/model.py b/python/mxnet/model.py index 3466bf9362d9..a84244a2a777 100644 --- a/python/mxnet/model.py +++ b/python/mxnet/model.py @@ -207,8 +207,9 @@ def _train_multi_device(symbol, ctx, input_shape, # If there are multiple devices, initialize the weights. for index, pair in enumerate(zip(arg_blocks, grad_blocks)): arg_list, grad_list = pair - if kv and grad_list[0] is not None: - kv.init(index, arg_list[0]) + if grad_list[0] is not None: + if kv: + kv.init(index, arg_list[0]) # attach state direct to weight opt_list = [optimizer.create_state(index, w) for w in arg_list] opt_state_blocks.append(opt_list) diff --git a/tests/python/train/test_mlp.py b/tests/python/train/test_mlp.py index b0849a3e81d9..bd635e980297 100644 --- a/tests/python/train/test_mlp.py +++ b/tests/python/train/test_mlp.py @@ -18,7 +18,8 @@ num_round = 4 prefix = './mlp' -model = mx.model.FeedForward(softmax, [mx.cpu()] * 2, +model = mx.model.FeedForward(softmax, + [mx.cpu(i) for i in range(2)], num_round=num_round, learning_rate=0.01, wd=0.0004, momentum=0.9)