From fa6ea1a08c832b63c801a61491a6f8b47f60b478 Mon Sep 17 00:00:00 2001 From: Lin Yuan Date: Sun, 15 Mar 2020 05:29:03 +0000 Subject: [PATCH] Add horovod test to CI --- ci/docker/runtime_functions.sh | 2 + python/mxnet/kvstore/horovod.py | 18 +++-- .../dist_device_sync_kvstore_horovod.py | 78 +++++++++++++++++++ tests/python/unittest/test_kvstore_horovod.py | 35 ++------- 4 files changed, 98 insertions(+), 35 deletions(-) create mode 100644 tests/nightly/dist_device_sync_kvstore_horovod.py diff --git a/ci/docker/runtime_functions.sh b/ci/docker/runtime_functions.sh index c5ecdd70c523..d272867f3808 100755 --- a/ci/docker/runtime_functions.sh +++ b/ci/docker/runtime_functions.sh @@ -1318,11 +1318,13 @@ integrationtest_ubuntu_gpu_dist_kvstore() { export MXNET_STORAGE_FALLBACK_LOG_VERBOSE=0 export MXNET_SUBGRAPH_VERBOSE=0 export DMLC_LOG_STACK_TRACE_DEPTH=10 + pip3 install --no-cache-dir horovod cd tests/nightly/ python3 ../../tools/launch.py -n 4 --launcher local python3 dist_device_sync_kvstore.py python3 ../../tools/launch.py -n 4 --launcher local python3 dist_device_sync_kvstore_custom.py python3 ../../tools/launch.py --p3 -n 4 --launcher local python3 dist_device_sync_kvstore_custom.py python3 ../../tools/launch.py -n 4 --launcher local python3 dist_sync_kvstore.py --type=init_gpu + mpirun -n 2 -H localhost:2 python3 dist_device_sync_kvstore_horovod.py popd } diff --git a/python/mxnet/kvstore/horovod.py b/python/mxnet/kvstore/horovod.py index 96bc4d9b2e20..776655a2898f 100644 --- a/python/mxnet/kvstore/horovod.py +++ b/python/mxnet/kvstore/horovod.py @@ -34,21 +34,20 @@ def __init__(self): def type(self): return 'horovod' - def broadcast(self, key, value, out=None, priority=0): + def broadcast(self, key, value, out, priority=0): """ Broadcast the `value` NDArray at rank 0 to all ranks Parameters ---------- key : str, or int - The key. + The key is used to name the tensor for allreduce. Its + usage is different from that of parameter servers. 
value : NDArray - The value corresponding to the key to broadcast. If `out` is not specified, - `value` NDArray will be updated in-place. + The tensor that is to be broadcasted. out : NDArray, list of NDArray Output tensor that receives value broadcasted from root process - If not specified, output will be written to `value` priority : int, optional The priority of the operation. @@ -123,11 +122,14 @@ def pushpull(self, key, value, out=None, priority=0): if out is None: value = value if isinstance(value, list) else [value] for v in value: - hvd.allreduce_(v, average=False, name=str(key), priority=priority) + hvd.allreduce_(v, average=False, name=str(key), + priority=priority) else: out = out if isinstance(out, list) else [out] - for o in out: - o[:] = hvd.allreduce(value, average=False, name=str(key), priority=priority) + value = value if isinstance(value, list) else [value] + for o, v in zip(out, value): + o[:] = hvd.allreduce(v, average=False, name=str(key), + priority=priority) def set_optimizer(self, optimizer): pass diff --git a/tests/nightly/dist_device_sync_kvstore_horovod.py b/tests/nightly/dist_device_sync_kvstore_horovod.py new file mode 100644 index 000000000000..ac44d5a2b104 --- /dev/null +++ b/tests/nightly/dist_device_sync_kvstore_horovod.py @@ -0,0 +1,78 @@ +#!/usr/bin/env python + +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. 
You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +import sys +sys.path.insert(0, "../../python/") +import mxnet as mx +import numpy as np +import numpy.random as rnd +import time +import argparse + +# parser +parser = argparse.ArgumentParser(description='kvstore test') +args = parser.parse_args() + + +def check_diff_to_scalar(A, x, rank=None): + """ assert A == x""" + assert(np.sum(np.abs((A - x).asnumpy())) == 0), (rank, A.asnumpy(), x) + + +# setup +keys = ['3', '5', '7'] +init_test_keys = [str(i) for i in range(200,300)] +init_test_keys_big = [str(i) for i in range(300,400)] +init_test_keys_device = [str(i) for i in range(400,500)] +init_test_keys_device_big = [str(i) for i in range(500,600)] + +shape = (2, 3) +big_shape = (1200, 1200) # bigger than MXNET_KVSTORE_BIGARRAY_BOUND + +kv = mx.kv.create('horovod') +my_rank = kv.rank +my_num_workers = kv.num_workers + + +def test_pushpull(): + ctx = mx.gpu(kv.local_rank) if mx.context.num_gpus() > 0 else mx.cpu(kv.local_rank) + scale = kv.rank + 1 + tensor = mx.nd.ones(shape, ctx) * scale + kv.pushpull('3', tensor) + + expected = (kv.num_workers + 1) * kv.num_workers / 2 + check_diff_to_scalar(tensor, expected) + + +def test_broadcast(): + ctx = mx.gpu(kv.local_rank) if mx.context.num_gpus() > 0 else mx.cpu(kv.local_rank) + val = mx.nd.zeros(shape, ctx) + kv.broadcast('0', mx.nd.ones(shape), out=val) + expected = 1 + check_diff_to_scalar(val, expected, kv.rank) + + +def test_type(): + assert kv.type == 'horovod' + + +if __name__ == "__main__": + test_type() + test_broadcast() + test_pushpull() diff --git a/tests/python/unittest/test_kvstore_horovod.py 
b/tests/python/unittest/test_kvstore_horovod.py index e99a0db6c9ab..3f801a814f14 100644 --- a/tests/python/unittest/test_kvstore_horovod.py +++ b/tests/python/unittest/test_kvstore_horovod.py @@ -25,33 +25,17 @@ def test_horovod_basic(): kv = mx.kv.create('horovod') assert kv.type == 'horovod' - print('TEST num_worker: {}'.format(kv.num_workers)) - print('TEST rank: {}'.format(kv.rank)) - print('TEST local_rank: {}'.format(kv.local_rank)) - # assert kv.num_workers == 1 - # assert kv.rank == 0 - # assert kv.local_rank == 0 + assert kv.num_workers == 1 + assert kv.rank == 0 + assert kv.local_rank == 0 def test_horovod_broadcast(): - # broadcast a single key-value pair - kv = mx.kv.create('horovod') - a = mx.nd.ones(shape) * kv.rank - expected = np.zeros(shape) - kv.broadcast('1', value=a) - if kv.rank != 0: - print('TEST broadcast value: \n{}'.format(a.asnumpy())) - # assert a.asnumpy().all() == expected.all() - - -def test_horovod_broadcast_inplace(): kv = mx.kv.create('horovod') a = mx.nd.ones(shape) * kv.rank b = mx.nd.zeros(shape) kv.broadcast('1', value=a, out=b) - if kv.rank != 0: - print('TEST broadcast inplace value: \n{}'.format(b.asnumpy())) - # assert a.asnumpy().all() == b.asnumpy().all() + assert a.asnumpy().all() == b.asnumpy().all() def test_horovod_allreduce(): @@ -59,12 +43,9 @@ def test_horovod_allreduce(): nworker = kv.num_workers a = mx.nd.ones(shape) kv.pushpull('1', a) - print('TEST allreduce: \n{}'.format(a.asnumpy())) + assert a.asnumpy().all() == np.ones(shape).all() -test_horovod_basic() -test_horovod_broadcast() -test_horovod_allreduce() -# if __name__ == '__main__': -# import nose -# nose.runmodule() +if __name__ == '__main__': + import nose + nose.runmodule()