Skip to content

Commit

Permalink
Work with IPv6 in the new tracker. (#10125)
Browse files Browse the repository at this point in the history
  • Loading branch information
trivialfis authored Mar 19, 2024
1 parent 53fc175 commit ca4801f
Show file tree
Hide file tree
Showing 10 changed files with 113 additions and 59 deletions.
42 changes: 26 additions & 16 deletions include/xgboost/collective/socket.h
Original file line number Diff line number Diff line change
Expand Up @@ -436,28 +436,38 @@ class TCPSocket {
* \brief Accept new connection, returns a new TCP socket for the new connection.
*/
TCPSocket Accept() {
HandleT newfd = accept(Handle(), nullptr, nullptr);
SockAddress addr;
TCPSocket newsock;
auto rc = this->Accept(&newsock, &addr);
SafeColl(rc);
return newsock;
}

[[nodiscard]] Result Accept(TCPSocket *out, SockAddress *addr) {
#if defined(_WIN32)
auto interrupt = WSAEINTR;
#else
auto interrupt = EINTR;
#endif
if (newfd == InvalidSocket() && system::LastError() != interrupt) {
system::ThrowAtError("accept");
}
TCPSocket newsock{newfd};
return newsock;
}

[[nodiscard]] Result Accept(TCPSocket *out, SockAddrV4 *addr) {
struct sockaddr_in caddr;
socklen_t caddr_len = sizeof(caddr);
HandleT newfd = accept(Handle(), reinterpret_cast<sockaddr *>(&caddr), &caddr_len);
if (newfd == InvalidSocket()) {
return system::FailWithCode("Failed to accept.");
if (this->Domain() == SockDomain::kV4) {
struct sockaddr_in caddr;
socklen_t caddr_len = sizeof(caddr);
HandleT newfd = accept(Handle(), reinterpret_cast<sockaddr *>(&caddr), &caddr_len);
if (newfd == InvalidSocket() && system::LastError() != interrupt) {
return system::FailWithCode("Failed to accept.");
}
*addr = SockAddress{SockAddrV4{caddr}};
*out = TCPSocket{newfd};
} else {
struct sockaddr_in6 caddr;
socklen_t caddr_len = sizeof(caddr);
HandleT newfd = accept(Handle(), reinterpret_cast<sockaddr *>(&caddr), &caddr_len);
if (newfd == InvalidSocket() && system::LastError() != interrupt) {
return system::FailWithCode("Failed to accept.");
}
*addr = SockAddress{SockAddrV6{caddr}};
*out = TCPSocket{newfd};
}
*addr = SockAddrV4{caddr};
*out = TCPSocket{newfd};
return Success();
}

Expand Down
4 changes: 2 additions & 2 deletions python-package/xgboost/testing/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -429,8 +429,8 @@ def make_categorical(
categories = np.arange(0, n_categories)
for col in df.columns:
if rng.binomial(1, cat_ratio, size=1)[0] == 1:
df[col] = df[col].astype("category")
df[col] = df[col].cat.set_categories(categories)
df.loc[:, col] = df[col].astype("category")
df.loc[:, col] = df[col].cat.set_categories(categories)

if sparsity > 0.0:
for i in range(n_features):
Expand Down
11 changes: 7 additions & 4 deletions src/collective/coll.cc
Original file line number Diff line number Diff line change
@@ -1,12 +1,13 @@
/**
* Copyright 2023, XGBoost Contributors
* Copyright 2023-2024, XGBoost Contributors
*/
#include "coll.h"

#include <algorithm> // for min, max, copy_n
#include <cstddef> // for size_t
#include <cstdint> // for int8_t, int64_t
#include <functional> // for bit_and, bit_or, bit_xor, plus
#include <string> // for string
#include <type_traits> // for is_floating_point_v, is_same_v
#include <utility> // for move

Expand Down Expand Up @@ -56,6 +57,8 @@ bool constexpr IsFloatingPointV() {
return cpu_impl::RingAllreduce(comm, data, erased_fn, type);
};

std::string msg{"Floating point is not supported for bit wise collective operations."};

auto rc = DispatchDType(type, [&](auto t) {
using T = decltype(t);
switch (op) {
Expand All @@ -70,21 +73,21 @@ bool constexpr IsFloatingPointV() {
}
case Op::kBitwiseAND: {
if constexpr (IsFloatingPointV<T>()) {
return Fail("Invalid type.");
return Fail(msg);
} else {
return fn(std::bit_and<>{}, t);
}
}
case Op::kBitwiseOR: {
if constexpr (IsFloatingPointV<T>()) {
return Fail("Invalid type.");
return Fail(msg);
} else {
return fn(std::bit_or<>{}, t);
}
}
case Op::kBitwiseXOR: {
if constexpr (IsFloatingPointV<T>()) {
return Fail("Invalid type.");
return Fail(msg);
} else {
return fn(std::bit_xor<>{}, t);
}
Expand Down
25 changes: 17 additions & 8 deletions src/collective/comm.cc
Original file line number Diff line number Diff line change
Expand Up @@ -75,9 +75,11 @@ Result ConnectTrackerImpl(proto::PeerInfo info, std::chrono::seconds timeout, st
} << [&] {
return next->NonBlocking(true);
} << [&] {
SockAddrV4 addr;
SockAddress addr;
return listener->Accept(prev.get(), &addr);
} << [&] { return prev->NonBlocking(true); };
} << [&] {
return prev->NonBlocking(true);
};
if (!rc.OK()) {
return rc;
}
Expand Down Expand Up @@ -157,10 +159,13 @@ Result ConnectTrackerImpl(proto::PeerInfo info, std::chrono::seconds timeout, st
}

for (std::int32_t r = 0; r < comm.Rank(); ++r) {
SockAddrV4 addr;
auto peer = std::shared_ptr<TCPSocket>(TCPSocket::CreatePtr(comm.Domain()));
rc = std::move(rc) << [&] { return listener->Accept(peer.get(), &addr); }
<< [&] { return peer->RecvTimeout(timeout); };
rc = std::move(rc) << [&] {
SockAddress addr;
return listener->Accept(peer.get(), &addr);
} << [&] {
return peer->RecvTimeout(timeout);
};
if (!rc.OK()) {
return rc;
}
Expand All @@ -187,7 +192,9 @@ RabitComm::RabitComm(std::string const& host, std::int32_t port, std::chrono::se
: HostComm{std::move(host), port, timeout, retry, std::move(task_id)},
nccl_path_{std::move(nccl_path)} {
auto rc = this->Bootstrap(timeout_, retry_, task_id_);
CHECK(rc.OK()) << rc.Report();
if (!rc.OK()) {
SafeColl(Fail("Failed to bootstrap the communication group.", std::move(rc)));
}
}

#if !defined(XGBOOST_USE_NCCL)
Expand Down Expand Up @@ -247,18 +254,20 @@ Comm* RabitComm::MakeCUDAVar(Context const*, std::shared_ptr<Coll>) const {
// get ring neighbors
std::string snext;
tracker.Recv(&snext);
if (!rc.OK()) {
return Fail("Failed to receive the rank for the next worker.", std::move(rc));
}
auto jnext = Json::Load(StringView{snext});

proto::PeerInfo ninfo{jnext};

// get the rank of this worker
this->rank_ = BootstrapPrev(ninfo.rank, world);
this->tracker_.rank = rank_;

std::vector<std::shared_ptr<TCPSocket>> workers;
rc = ConnectWorkers(*this, &listener, lport, ninfo, timeout, retry, &workers);
if (!rc.OK()) {
return rc;
return Fail("Failed to connect to other workers.", std::move(rc));
}

CHECK(this->channels_.empty());
Expand Down
41 changes: 25 additions & 16 deletions src/collective/tracker.cc
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
/**
* Copyright 2023, XGBoost Contributors
* Copyright 2023-2024, XGBoost Contributors
*/
#if defined(__unix__) || defined(__APPLE__)
#include <netdb.h> // gethostbyname
Expand Down Expand Up @@ -27,12 +27,14 @@
#include "tracker.h"
#include "xgboost/collective/result.h" // for Result, Fail, Success
#include "xgboost/collective/socket.h" // for GetHostName, FailWithCode, MakeSockAddress, ...
#include "xgboost/json.h"
#include "xgboost/json.h" // for Json

namespace xgboost::collective {
Tracker::Tracker(Json const& config)
: n_workers_{static_cast<std::int32_t>(
RequiredArg<Integer const>(config, "n_workers", __func__))},
: sortby_{static_cast<SortBy>(
OptionalArg<Integer const>(config, "sortby", static_cast<Integer::Int>(SortBy::kHost)))},
n_workers_{
static_cast<std::int32_t>(RequiredArg<Integer const>(config, "n_workers", __func__))},
port_{static_cast<std::int32_t>(OptionalArg<Integer const>(config, "port", Integer::Int{0}))},
timeout_{std::chrono::seconds{OptionalArg<Integer const>(
config, "timeout", static_cast<std::int64_t>(collective::DefaultTimeoutSec()))}} {}
Expand All @@ -56,13 +58,15 @@ Result Tracker::WaitUntilReady() const {
return Success();
}

RabitTracker::WorkerProxy::WorkerProxy(std::int32_t world, TCPSocket sock, SockAddrV4 addr)
RabitTracker::WorkerProxy::WorkerProxy(std::int32_t world, TCPSocket sock, SockAddress addr)
: sock_{std::move(sock)} {
std::int32_t rank{0};
Json jcmd;
std::int32_t port{0};

rc_ = Success() << [&] { return proto::Magic{}.Verify(&sock_); } << [&] {
rc_ = Success() << [&] {
return proto::Magic{}.Verify(&sock_);
} << [&] {
return proto::Connect{}.TrackerRecv(&sock_, &world_, &rank, &task_id_);
} << [&] {
std::string cmd;
Expand All @@ -83,28 +87,33 @@ RabitTracker::WorkerProxy::WorkerProxy(std::int32_t world, TCPSocket sock, SockA
}
return Success();
} << [&] {
auto host = addr.Addr();
info_ = proto::PeerInfo{host, port, rank};
if (addr.IsV4()) {
auto host = addr.V4().Addr();
info_ = proto::PeerInfo{host, port, rank};
} else {
auto host = addr.V6().Addr();
info_ = proto::PeerInfo{host, port, rank};
}
return Success();
};
}

RabitTracker::RabitTracker(Json const& config) : Tracker{config} {
std::string self;
auto rc = collective::GetHostAddress(&self);
auto host = OptionalArg<String>(config, "host", self);
host_ = OptionalArg<String>(config, "host", self);

host_ = host;
listener_ = TCPSocket::Create(SockDomain::kV4);
rc = listener_.Bind(host, &this->port_);
CHECK(rc.OK()) << rc.Report();
auto addr = MakeSockAddress(xgboost::StringView{host_}, 0);
listener_ = TCPSocket::Create(addr.IsV4() ? SockDomain::kV4 : SockDomain::kV6);
rc = listener_.Bind(host_, &this->port_);
SafeColl(rc);
listener_.Listen();
}

Result RabitTracker::Bootstrap(std::vector<WorkerProxy>* p_workers) {
auto& workers = *p_workers;

std::sort(workers.begin(), workers.end(), WorkerCmp{});
std::sort(workers.begin(), workers.end(), WorkerCmp{this->sortby_});

std::vector<std::thread> bootstrap_threads;
for (std::int32_t r = 0; r < n_workers_; ++r) {
Expand Down Expand Up @@ -224,7 +233,7 @@ Result RabitTracker::Bootstrap(std::vector<WorkerProxy>* p_workers) {

while (state.ShouldContinue()) {
TCPSocket sock;
SockAddrV4 addr;
SockAddress addr;
this->ready_ = true;
auto rc = listener_.Accept(&sock, &addr);
if (!rc.OK()) {
Expand Down Expand Up @@ -291,7 +300,7 @@ Result RabitTracker::Bootstrap(std::vector<WorkerProxy>* p_workers) {

[[nodiscard]] Json RabitTracker::WorkerArgs() const {
auto rc = this->WaitUntilReady();
CHECK(rc.OK()) << rc.Report();
SafeColl(rc);

Json args{Object{}};
args["DMLC_TRACKER_URI"] = String{host_};
Expand Down
23 changes: 18 additions & 5 deletions src/collective/tracker.h
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
/**
* Copyright 2023, XGBoost Contributors
* Copyright 2023-2024, XGBoost Contributors
*/
#pragma once
#include <chrono> // for seconds
Expand Down Expand Up @@ -36,6 +36,16 @@ namespace xgboost::collective {
* signal an error to the tracker and the tracker will notify other workers.
*/
class Tracker {
protected:
// How to sort the workers, either by host name or by task ID. When using a multi-GPU
// setting, multiple workers can occupy the same host, in which case one should sort
// workers by task. Due to compatibility reason, the task ID is not always available, so
// we use host as the default.
enum class SortBy : std::int8_t {
kHost = 0,
kTask = 1,
} sortby_;

protected:
std::int32_t n_workers_{0};
std::int32_t port_{-1};
Expand Down Expand Up @@ -76,7 +86,7 @@ class RabitTracker : public Tracker {
Result rc_;

public:
explicit WorkerProxy(std::int32_t world, TCPSocket sock, SockAddrV4 addr);
explicit WorkerProxy(std::int32_t world, TCPSocket sock, SockAddress addr);
WorkerProxy(WorkerProxy const& that) = delete;
WorkerProxy(WorkerProxy&& that) = default;
WorkerProxy& operator=(WorkerProxy const&) = delete;
Expand All @@ -96,11 +106,14 @@ class RabitTracker : public Tracker {

void Send(StringView value) { this->sock_.Send(value); }
};
// provide an ordering for workers, this helps us get deterministic topology.
// Provide an ordering for workers, this helps us get deterministic topology.
struct WorkerCmp {
SortBy sortby;
explicit WorkerCmp(SortBy sortby) : sortby{sortby} {}

[[nodiscard]] bool operator()(WorkerProxy const& lhs, WorkerProxy const& rhs) {
auto const& lh = lhs.Host();
auto const& rh = rhs.Host();
auto const& lh = sortby == Tracker::SortBy::kHost ? lhs.Host() : lhs.TaskID();
auto const& rh = sortby == Tracker::SortBy::kHost ? rhs.Host() : rhs.TaskID();

if (lh != rh) {
return lh < rh;
Expand Down
1 change: 0 additions & 1 deletion src/learner.cc
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,6 @@
#include <cstdint> // for int32_t, uint32_t, int64_t, uint64_t
#include <cstdlib> // for atoi
#include <cstring> // for memcpy, size_t, memset
#include <functional> // for less
#include <iomanip> // for operator<<, setiosflags
#include <iterator> // for back_insert_iterator, distance, back_inserter
#include <limits> // for numeric_limits
Expand Down
2 changes: 1 addition & 1 deletion tests/ci_build/Dockerfile.gpu
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@ RUN \
mamba create -n gpu_test -c rapidsai-nightly -c rapidsai -c nvidia -c conda-forge -c defaults \
python=3.10 cudf=$RAPIDS_VERSION_ARG* rmm=$RAPIDS_VERSION_ARG* cudatoolkit=$CUDA_VERSION_ARG \
nccl>=$(cut -d "-" -f 1 << $NCCL_VERSION_ARG) \
dask \
dask=2024.1.1 \
dask-cuda=$RAPIDS_VERSION_ARG* dask-cudf=$RAPIDS_VERSION_ARG* cupy \
numpy pytest pytest-timeout scipy scikit-learn pandas matplotlib wheel python-kubernetes urllib3 graphviz hypothesis \
pyspark>=3.4.0 cloudpickle cuda-python && \
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -252,7 +252,7 @@ def test_categorical(self, local_cuda_client: Client) -> None:

X_onehot, _ = make_categorical(local_cuda_client, 10000, 30, 13, True)
X_onehot = dask_cudf.from_dask_dataframe(X_onehot)
run_categorical(local_cuda_client, "gpu_hist", X, X_onehot, y)
run_categorical(local_cuda_client, "hist", "cuda", X, X_onehot, y)

@given(
params=hist_parameter_strategy,
Expand Down
Loading

0 comments on commit ca4801f

Please sign in to comment.