[fed] Fixes for the encrypted GRPC backend. #10503

Merged on Jul 2, 2024 (6 commits)
9 changes: 5 additions & 4 deletions plugin/federated/federated_comm.cc
@@ -1,5 +1,5 @@
 /**
- * Copyright 2023, XGBoost contributors
+ * Copyright 2023-2024, XGBoost contributors
  */
 #include "federated_comm.h"

@@ -11,6 +11,7 @@
 #include <string>  // for string, stoi
 
 #include "../../src/common/common.h"      // for Split
+#include "../../src/common/io.h"          // for ReadAll
 #include "../../src/common/json_utils.h"  // for OptionalArg
 #include "xgboost/json.h"                 // for Json
 #include "xgboost/logging.h"
@@ -46,9 +47,9 @@ void FederatedComm::Init(std::string const& host, std::int32_t port, std::int32_
   } else {
     stub_ = [&] {
       grpc::SslCredentialsOptions options;
-      options.pem_root_certs = server_cert;
-      options.pem_private_key = client_key;
-      options.pem_cert_chain = client_cert;
+      options.pem_root_certs = common::ReadAll(server_cert);
+      options.pem_private_key = common::ReadAll(client_key);
+      options.pem_cert_chain = common::ReadAll(client_cert);
       grpc::ChannelArguments args;
       args.SetMaxReceiveMessageSize(std::numeric_limits<std::int32_t>::max());
       auto channel = grpc::CreateCustomChannel(host + ":" + std::to_string(port),
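The core fix in this file: the fields of grpc::SslCredentialsOptions hold PEM contents, but the communicator was assigning the certificate file paths directly, so an encrypted channel could never be established. The patch reads each file with common::ReadAll first. For readers more at home on the Python side, a minimal sketch of the same idea using the grpc package (the helper name and arguments are illustrative, not part of this PR):

import grpc


def make_secure_channel(
    host: str, port: int, root_cert: str, client_key: str, client_cert: str
) -> grpc.Channel:
    """Open an encrypted channel. The arguments are file paths; gRPC wants
    the PEM bytes, so read each file first (the analogue of common::ReadAll)."""

    def read_all(path: str) -> bytes:
        with open(path, "rb") as fd:
            return fd.read()

    credentials = grpc.ssl_channel_credentials(
        root_certificates=read_all(root_cert),  # the contents, not the path
        private_key=read_all(client_key),
        certificate_chain=read_all(client_cert),
    )
    return grpc.secure_channel(f"{host}:{port}", credentials)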
14 changes: 10 additions & 4 deletions python-package/xgboost/federated.py
@@ -39,9 +39,9 @@ def __init__(  # pylint: disable=R0913, W0231
         n_workers: int,
         port: int,
         secure: bool,
-        server_key_path: str = "",
-        server_cert_path: str = "",
-        client_cert_path: str = "",
+        server_key_path: Optional[str] = None,
+        server_cert_path: Optional[str] = None,
+        client_cert_path: Optional[str] = None,
         timeout: int = 300,
     ) -> None:
         handle = ctypes.c_void_p()
@@ -84,7 +84,13 @@ def run_federated_server(  # pylint: disable=too-many-arguments
         for path in [server_key_path, server_cert_path, client_cert_path]
     )
     tracker = FederatedTracker(
-        n_workers=n_workers, port=port, secure=secure, timeout=timeout
+        n_workers=n_workers,
+        port=port,
+        secure=secure,
+        timeout=timeout,
+        server_key_path=server_key_path,
+        server_cert_path=server_cert_path,
+        client_cert_path=client_cert_path,
     )
     tracker.start()
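The Python-side counterpart of the fix: run_federated_server derived the secure flag from the three paths but never passed them on to FederatedTracker, so a secure server could not pick up its key material. With the patch applied, starting an encrypted server looks like this (PEM file names as generated by the tests below; the server runs until the process is terminated, which is why the test helper launches it in a subprocess):

import xgboost.federated

# Two workers on port 9091 over an encrypted channel.
xgboost.federated.run_federated_server(
    2,     # n_workers
    9091,  # port
    server_key_path="server-key.pem",
    server_cert_path="server-cert.pem",
    client_cert_path="client-cert.pem",
)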
153 changes: 153 additions & 0 deletions python-package/xgboost/testing/federated.py
@@ -0,0 +1,153 @@
# pylint: disable=unbalanced-tuple-unpacking, too-many-locals
"""Tests for federated learning."""

import multiprocessing
import os
import subprocess
import tempfile
import time
from typing import List, cast

from sklearn.datasets import dump_svmlight_file, load_svmlight_file
from sklearn.model_selection import train_test_split

import xgboost as xgb
import xgboost.federated
from xgboost import testing as tm
from xgboost.training import TrainingCallback

SERVER_KEY = "server-key.pem"
SERVER_CERT = "server-cert.pem"
CLIENT_KEY = "client-key.pem"
CLIENT_CERT = "client-cert.pem"


def run_server(port: int, world_size: int, with_ssl: bool) -> None:
    """Run the federated server for tests."""
    if with_ssl:
        xgboost.federated.run_federated_server(
            world_size,
            port,
            server_key_path=SERVER_KEY,
            server_cert_path=SERVER_CERT,
            client_cert_path=CLIENT_CERT,
        )
    else:
        xgboost.federated.run_federated_server(world_size, port)


def run_worker(
    port: int, world_size: int, rank: int, with_ssl: bool, device: str
) -> None:
    """Run a federated client worker for tests."""
    communicator_env = {
        "dmlc_communicator": "federated",
        "federated_server_address": f"localhost:{port}",
        "federated_world_size": world_size,
        "federated_rank": rank,
    }
    if with_ssl:
        communicator_env["federated_server_cert_path"] = SERVER_CERT
        communicator_env["federated_client_key_path"] = CLIENT_KEY
        communicator_env["federated_client_cert_path"] = CLIENT_CERT

    cpu_count = os.cpu_count()
    assert cpu_count is not None
    n_threads = cpu_count // world_size

    # Always call this before using the distributed module.
    with xgb.collective.CommunicatorContext(**communicator_env):
        # Load the files; they are not sharded for us in federated mode.
        X, y = load_svmlight_file(f"agaricus.txt-{rank}.train")
        dtrain = xgb.DMatrix(X, y)
        X, y = load_svmlight_file(f"agaricus.txt-{rank}.test")
        dtest = xgb.DMatrix(X, y)

        # Specify parameters via a dict; the definitions are the same as in
        # the C++ version.
        param = {
            "max_depth": 2,
            "eta": 1,
            "objective": "binary:logistic",
            "nthread": n_threads,
            "tree_method": "hist",
            "device": device,
        }

        # Specify the validation sets to watch performance.
        watchlist = [(dtest, "eval"), (dtrain, "train")]
        num_round = 20

        # Run training; all the features of the training API are available.
        results: TrainingCallback.EvalsLog = {}
        bst = xgb.train(
            param,
            dtrain,
            num_round,
            evals=watchlist,
            early_stopping_rounds=2,
            evals_result=results,
        )
        assert tm.non_increasing(cast(List[float], results["train"]["logloss"]))
        assert tm.non_increasing(cast(List[float], results["eval"]["logloss"]))

        # Save the model; only ask process 0 to save it.
        if xgb.collective.get_rank() == 0:
            with tempfile.TemporaryDirectory() as tmpdir:
                bst.save_model(os.path.join(tmpdir, "model.json"))
        xgb.collective.communicator_print("Finished training\n")


def run_federated(world_size: int, with_ssl: bool, use_gpu: bool) -> None:
    """Launcher for the clients and the server."""
    port = 9091

    server = multiprocessing.Process(
        target=run_server, args=(port, world_size, with_ssl)
    )
    server.start()
    time.sleep(1)
    if not server.is_alive():
        raise ValueError("Error starting Federated Learning server")

    workers = []
    for rank in range(world_size):
        device = f"cuda:{rank}" if use_gpu else "cpu"
        worker = multiprocessing.Process(
            target=run_worker, args=(port, world_size, rank, with_ssl, device)
        )
        workers.append(worker)
        worker.start()
    for worker in workers:
        worker.join()
    server.terminate()


def run_federated_learning(with_ssl: bool, use_gpu: bool, test_path: str) -> None:
    """Run the federated learning tests."""
    n_workers = 2

    if with_ssl:
        command = "openssl req -x509 -newkey rsa:2048 -days 7 -nodes -keyout {part}-key.pem -out {part}-cert.pem -subj /C=US/CN=localhost"  # pylint: disable=line-too-long
        server_key = command.format(part="server").split()
        subprocess.check_call(server_key)
        client_key = command.format(part="client").split()
        subprocess.check_call(client_key)

    train_path = os.path.join(tm.data_dir(test_path), "agaricus.txt.train")
    test_path = os.path.join(tm.data_dir(test_path), "agaricus.txt.test")

    X_train, y_train = load_svmlight_file(train_path)
    X_test, y_test = load_svmlight_file(test_path)

    X0, X1, y0, y1 = train_test_split(X_train, y_train, test_size=0.5)
    X0_valid, X1_valid, y0_valid, y1_valid = train_test_split(
        X_test, y_test, test_size=0.5
    )

    dump_svmlight_file(X0, y0, "agaricus.txt-0.train")
    dump_svmlight_file(X0_valid, y0_valid, "agaricus.txt-0.test")

    dump_svmlight_file(X1, y1, "agaricus.txt-1.train")
    dump_svmlight_file(X1_valid, y1_valid, "agaricus.txt-1.test")

    run_federated(world_size=n_workers, with_ssl=with_ssl, use_gpu=use_gpu)
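The helper above generates self-signed certificates with openssl, splits the agaricus demo data into one shard per worker in the current working directory, and then launches the server and the workers. The pytest wrappers below are the intended entry points, but a direct call is just as simple (assuming scikit-learn, the openssl CLI, and the bundled demo data are available):

from xgboost.testing.federated import run_federated_learning

# Two CPU workers over an encrypted channel; the path argument is only used
# to locate the agaricus demo data via xgboost.testing.data_dir.
run_federated_learning(with_ssl=True, use_gpu=False, test_path=__file__)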
5 changes: 4 additions & 1 deletion src/context.cc
@@ -191,8 +191,11 @@ DeviceOrd CUDAOrdinal(DeviceOrd device, bool) {
   }
   if (device.IsCUDA()) {
     device = CUDAOrdinal(device, fail_on_invalid_gpu_id);
+    if (!device.IsCUDA()) {
+      // We allow loading a GPU-based pickle on a CPU-only machine.
+      LOG(WARNING) << "XGBoost is not compiled with CUDA support.";
+    }
   }
 
   return device;
 }
 }  // namespace
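As the new comment says, this lets a model pickled on a GPU machine be loaded by a CPU-only build: an unresolvable CUDA ordinal now degrades to a warning instead of a hard failure. A hedged sketch of the scenario (the pickle file name is hypothetical):

import pickle

# On a CPU-only build, unpickling a booster whose saved context says "cuda"
# now merely logs "XGBoost is not compiled with CUDA support."
with open("gpu_trained_model.pkl", "rb") as fd:
    booster = pickle.load(fd)
booster.set_param({"device": "cpu"})  # continue on the CPU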
4 changes: 4 additions & 0 deletions tests/ci_build/lint_python.py
@@ -34,6 +34,8 @@ class LintersPaths:
         "tests/python/test_with_pandas.py",
         "tests/python-gpu/",
         "tests/python-sycl/",
+        "tests/test_distributed/test_federated/",
+        "tests/test_distributed/test_gpu_federated/",
         "tests/test_distributed/test_with_dask/",
         "tests/test_distributed/test_gpu_with_dask/",
         "tests/test_distributed/test_with_spark/",
@@ -94,6 +96,8 @@ class LintersPaths:
         "tests/python-gpu/load_pickle.py",
         "tests/python-gpu/test_gpu_training_continuation.py",
         "tests/python/test_model_io.py",
+        "tests/test_distributed/test_federated/",
+        "tests/test_distributed/test_gpu_federated/",
         "tests/test_distributed/test_with_spark/test_data.py",
         "tests/test_distributed/test_gpu_with_spark/test_data.py",
         "tests/test_distributed/test_gpu_with_dask/test_gpu_with_dask.py",
2 changes: 2 additions & 0 deletions tests/ci_build/test_python.sh
@@ -70,6 +70,7 @@ case "$suite" in
     pytest -v -s -rxXs --fulltrace --durations=0 -m "mgpu" ${args} tests/python-gpu
     pytest -v -s -rxXs --fulltrace --durations=0 -m "mgpu" ${args} tests/test_distributed/test_gpu_with_dask
     pytest -v -s -rxXs --fulltrace --durations=0 -m "mgpu" ${args} tests/test_distributed/test_gpu_with_spark
+    pytest -v -s -rxXs --fulltrace --durations=0 -m "mgpu" ${args} tests/test_distributed/test_gpu_federated
     unset_pyspark_envs
     uninstall_xgboost
     set +x
@@ -84,6 +85,7 @@
     pytest -v -s -rxXs --fulltrace --durations=0 ${args} tests/python
     pytest -v -s -rxXs --fulltrace --durations=0 ${args} tests/test_distributed/test_with_dask
     pytest -v -s -rxXs --fulltrace --durations=0 ${args} tests/test_distributed/test_with_spark
+    pytest -v -s -rxXs --fulltrace --durations=0 ${args} tests/test_distributed/test_federated
     unset_pyspark_envs
     uninstall_xgboost
     set +x
17 changes: 0 additions & 17 deletions tests/test_distributed/test_federated/runtests-federated.sh

This file was deleted.

88 changes: 5 additions & 83 deletions tests/test_distributed/test_federated/test_federated.py
@@ -1,86 +1,8 @@
-#!/usr/bin/python
-import multiprocessing
-import sys
-import time
+import pytest
 
-import xgboost as xgb
-import xgboost.federated
+from xgboost.testing.federated import run_federated_learning
 
-SERVER_KEY = 'server-key.pem'
-SERVER_CERT = 'server-cert.pem'
-CLIENT_KEY = 'client-key.pem'
-CLIENT_CERT = 'client-cert.pem'
-
-
-def run_server(port: int, world_size: int, with_ssl: bool) -> None:
-    if with_ssl:
-        xgboost.federated.run_federated_server(port, world_size, SERVER_KEY, SERVER_CERT,
-                                               CLIENT_CERT)
-    else:
-        xgboost.federated.run_federated_server(port, world_size)
-
-
-def run_worker(port: int, world_size: int, rank: int, with_ssl: bool, with_gpu: bool) -> None:
-    communicator_env = {
-        'xgboost_communicator': 'federated',
-        'federated_server_address': f'localhost:{port}',
-        'federated_world_size': world_size,
-        'federated_rank': rank
-    }
-    if with_ssl:
-        communicator_env['federated_server_cert'] = SERVER_CERT
-        communicator_env['federated_client_key'] = CLIENT_KEY
-        communicator_env['federated_client_cert'] = CLIENT_CERT
-
-    # Always call this before using distributed module
-    with xgb.collective.CommunicatorContext(**communicator_env):
-        # Load file, file will not be sharded in federated mode.
-        dtrain = xgb.DMatrix('agaricus.txt.train-%02d?format=libsvm' % rank)
-        dtest = xgb.DMatrix('agaricus.txt.test-%02d?format=libsvm' % rank)
-
-        # Specify parameters via map, definition are same as c++ version
-        param = {'max_depth': 2, 'eta': 1, 'objective': 'binary:logistic'}
-        if with_gpu:
-            param['tree_method'] = 'hist'
-            param['device'] = f"cuda:{rank}"
-
-        # Specify validations set to watch performance
-        watchlist = [(dtest, 'eval'), (dtrain, 'train')]
-        num_round = 20
-
-        # Run training, all the features in training API is available.
-        bst = xgb.train(param, dtrain, num_round, evals=watchlist,
-                        early_stopping_rounds=2)
-
-        # Save the model, only ask process 0 to save the model.
-        if xgb.collective.get_rank() == 0:
-            bst.save_model("test.model.json")
-        xgb.collective.communicator_print("Finished training\n")
-
-
-def run_federated(with_ssl: bool = True, with_gpu: bool = False) -> None:
-    port = 9091
-    world_size = int(sys.argv[1])
-
-    server = multiprocessing.Process(target=run_server, args=(port, world_size, with_ssl))
-    server.start()
-    time.sleep(1)
-    if not server.is_alive():
-        raise Exception("Error starting Federated Learning server")
-
-    workers = []
-    for rank in range(world_size):
-        worker = multiprocessing.Process(target=run_worker,
-                                         args=(port, world_size, rank, with_ssl, with_gpu))
-        workers.append(worker)
-        worker.start()
-    for worker in workers:
-        worker.join()
-    server.terminate()
-
-
-if __name__ == '__main__':
-    run_federated(with_ssl=True, with_gpu=False)
-    run_federated(with_ssl=False, with_gpu=False)
-    run_federated(with_ssl=True, with_gpu=True)
-    run_federated(with_ssl=False, with_gpu=True)
+
+
+@pytest.mark.parametrize("with_ssl", [True, False])
+def test_federated_learning(with_ssl: bool) -> None:
+    run_federated_learning(with_ssl, False, __file__)
9 changes: 9 additions & 0 deletions tests/test_distributed/test_gpu_federated/test_gpu_federated.py
@@ -0,0 +1,9 @@
import pytest

from xgboost.testing.federated import run_federated_learning


@pytest.mark.parametrize("with_ssl", [True, False])
@pytest.mark.mgpu
def test_federated_learning(with_ssl: bool) -> None:
    run_federated_learning(with_ssl, True, __file__)