[coll] Use loky for rabit op tests. (#10828)
trivialfis committed Sep 20, 2024
1 parent 15c6172 commit d5e1c41
Showing 2 changed files with 33 additions and 37 deletions.
9 changes: 7 additions & 2 deletions python-package/xgboost/testing/updater.py
@@ -218,8 +218,13 @@ def check_extmem_qdm(
         )

     booster_it = xgb.train({"device": device}, Xy_it, num_boost_round=8)
-    X, y, w = it.as_arrays()
-    Xy = xgb.QuantileDMatrix(X, y, weight=w)
+    it = tm.IteratorForTest(
+        *tm.make_batches(
+            n_samples_per_batch, n_features, n_batches, use_cupy=device != "cpu"
+        ),
+        cache=None,
+    )
+    Xy = xgb.QuantileDMatrix(it)
     booster = xgb.train({"device": device}, Xy, num_boost_round=8)

     if device == "cpu":
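
The updater.py hunk above swaps the materialized arrays (it.as_arrays()) for a fresh batch iterator that is handed straight to xgb.QuantileDMatrix. For readers unfamiliar with that construction path, here is a minimal, self-contained sketch of the same pattern using the public xgboost.DataIter interface; the NumpyBatchIter class and the synthetic data are illustrative stand-ins for the tm.IteratorForTest helper and are not part of this commit.

import numpy as np
import xgboost as xgb


class NumpyBatchIter(xgb.DataIter):
    """Feed pre-generated numpy batches to QuantileDMatrix one at a time."""

    def __init__(self, batches):
        self._batches = batches  # list of (X, y) tuples
        self._i = 0
        super().__init__(cache_prefix=None)

    def next(self, input_data) -> bool:
        # Return False once all batches have been consumed.
        if self._i == len(self._batches):
            return False
        X, y = self._batches[self._i]
        input_data(data=X, label=y)
        self._i += 1
        return True

    def reset(self) -> None:
        # Rewind so XGBoost can make multiple passes over the data.
        self._i = 0


rng = np.random.default_rng(0)
batches = [(rng.normal(size=(256, 4)), rng.normal(size=256)) for _ in range(3)]
Xy = xgb.QuantileDMatrix(NumpyBatchIter(batches))
booster = xgb.train({"device": "cpu"}, Xy, num_boost_round=8)
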
61 changes: 26 additions & 35 deletions tests/python/test_tracker.py
@@ -34,44 +34,48 @@ def test_socket_error():
         tracker.free()


-def run_rabit_ops(client, n_workers):
-    from xgboost.dask import CommunicatorContext, _get_dask_config, _get_rabit_args
-
-    workers = tm.get_client_workers(client)
-    rabit_args = client.sync(_get_rabit_args, len(workers), _get_dask_config(), client)
-    assert not collective.is_distributed()
-    n_workers_from_dask = len(workers)
-    assert n_workers == n_workers_from_dask
+def run_rabit_ops(pool, n_workers: int, address: str) -> None:
+    tracker = RabitTracker(host_ip=address, n_workers=n_workers)
+    tracker.start()
+    args = tracker.worker_args()

-    def local_test(worker_id):
-        with CommunicatorContext(**rabit_args):
+    def local_test(worker_id: int, rabit_args: dict) -> int:
+        with collective.CommunicatorContext(**rabit_args):
             a = 1
             assert collective.is_distributed()
-            a = np.array([a])
-            reduced = collective.allreduce(a, collective.Op.SUM)
+            arr = np.array([a])
+            reduced = collective.allreduce(arr, collective.Op.SUM)
             assert reduced[0] == n_workers

-            worker_id = np.array([worker_id])
-            reduced = collective.allreduce(worker_id, collective.Op.MAX)
+            arr = np.array([worker_id])
+            reduced = collective.allreduce(arr, collective.Op.MAX)
             assert reduced == n_workers - 1

             return 1

-    futures = client.map(local_test, range(len(workers)), workers=workers)
-    results = client.gather(futures)
+    fn = update_wrapper(partial(local_test, rabit_args=args), local_test)
+    results = pool.map(fn, range(n_workers))
     assert sum(results) == n_workers


-@pytest.mark.skipif(**tm.no_dask())
+@pytest.mark.skipif(**tm.no_loky())
 def test_rabit_ops():
-    from distributed import Client, LocalCluster
+    from loky import get_reusable_executor

-    n_workers = 3
-    with LocalCluster(n_workers=n_workers) as cluster:
-        with Client(cluster) as client:
-            run_rabit_ops(client, n_workers)
+    n_workers = 4
+    with get_reusable_executor(max_workers=n_workers) as pool:
+        run_rabit_ops(pool, n_workers, "127.0.0.1")
+
+
+@pytest.mark.skipif(**tm.no_ipv6())
+@pytest.mark.skipif(**tm.no_loky())
+def test_rabit_ops_ipv6():
+    from loky import get_reusable_executor
+
+    n_workers = 4
+    with get_reusable_executor(max_workers=n_workers) as pool:
+        run_rabit_ops(pool, n_workers, "::1")


 def run_allreduce(pool, n_workers: int) -> None:
     tracker = RabitTracker(host_ip="127.0.0.1", n_workers=n_workers)
@@ -133,19 +137,6 @@ def test_broadcast():
         run_broadcast(pool, n_workers)


-@pytest.mark.skipif(**tm.no_ipv6())
-@pytest.mark.skipif(**tm.no_dask())
-def test_rabit_ops_ipv6():
-    import dask
-    from distributed import Client, LocalCluster
-
-    n_workers = 3
-    with dask.config.set({"xgboost.scheduler_address": "[::1]"}):
-        with LocalCluster(n_workers=n_workers, host="[::1]") as cluster:
-            with Client(cluster) as client:
-                run_rabit_ops(client, n_workers)
-
-
 @pytest.mark.skipif(**tm.no_dask())
 def test_rank_assignment() -> None:
     from distributed import Client, LocalCluster
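
The diff above replaces the Dask-based fixture with a loky process pool: the parent process starts a RabitTracker, binds its worker_args() to the test function via functools.partial, and every pooled worker joins the tracker through collective.CommunicatorContext before running the allreduce checks. Below is a minimal, self-contained sketch of that pattern outside the test suite; the worker_fn/main names and the get_world_size() assertions are illustrative additions, not part of the commit.

from functools import partial, update_wrapper

import numpy as np
from loky import get_reusable_executor
from xgboost import collective
from xgboost.tracker import RabitTracker


def worker_fn(worker_id: int, rabit_args: dict) -> int:
    # Each loky process joins the tracker, then takes part in collective ops.
    with collective.CommunicatorContext(**rabit_args):
        total = collective.allreduce(np.array([1]), collective.Op.SUM)
        assert total[0] == collective.get_world_size()
        biggest = collective.allreduce(np.array([worker_id]), collective.Op.MAX)
        assert biggest[0] == collective.get_world_size() - 1
    return 1


def main() -> None:
    n_workers = 4
    tracker = RabitTracker(host_ip="127.0.0.1", n_workers=n_workers)
    tracker.start()
    args = tracker.worker_args()

    # Bind the tracker args; update_wrapper keeps the original function's
    # metadata on the partial object so the pool can pickle and report it.
    fn = update_wrapper(partial(worker_fn, rabit_args=args), worker_fn)
    with get_reusable_executor(max_workers=n_workers) as pool:
        results = list(pool.map(fn, range(n_workers)))
    assert sum(results) == n_workers


if __name__ == "__main__":
    main()
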
