[C/PyTorch] Userbuffers and comm+GEMM overlap algorithms refactored and moved to TE/common #1067

Open · wants to merge 23 commits into base: main · showing changes from 1 commit

Commits (23)
e911bac: moved userbuffers code to TE/common (denera, Aug 16, 2024)
4842566: moved comm+GEMM overlap code to TE/common (denera, Aug 23, 2024)
c587e76: removed PyTorch dependency from comm+GEMM overlap in TE/common (denera, Aug 26, 2024)
4cc258b: added TE/PyTorch wrappers for refactored comm+GEMM overlap code in TE… (denera, Aug 26, 2024)
b9370a0: updated TE/PyTorch Python API to match the refactored comm+GEMM overl… (denera, Aug 26, 2024)
b03cf2d: updated unit tests to work with refactored comm+GEMM overlap code (denera, Aug 27, 2024)
9994989: added a pylint exception to comm+GEMM overlap test runner (denera, Aug 27, 2024)
8c54738: [pre-commit.ci] auto fixes from pre-commit.com hooks (pre-commit-ci[bot], Aug 27, 2024)
82a18c0: fixed linting errors (denera, Aug 27, 2024)
29fe3bd: [pre-commit.ci] auto fixes from pre-commit.com hooks (pre-commit-ci[bot], Aug 27, 2024)
64ffbbf: added documentation for te.initialize_ub (denera, Aug 27, 2024)
d840201: [pre-commit.ci] auto fixes from pre-commit.com hooks (pre-commit-ci[bot], Aug 27, 2024)
69ee948: fixed compile errors when building with NVTE_UB_WITH_MPI=1 (denera, Aug 27, 2024)
f787c4b: [pre-commit.ci] auto fixes from pre-commit.com hooks (pre-commit-ci[bot], Aug 28, 2024)
3237517: fixed default bootstrap backend (denera, Aug 28, 2024)
2e6da4d: switched default bootstrap backend priority to MPI > Gloo > NCCL (denera, Aug 28, 2024)
aaca26e: [pre-commit.ci] auto fixes from pre-commit.com hooks (pre-commit-ci[bot], Aug 28, 2024)
a04d85a: updated bootstrap backend documentation (denera, Aug 28, 2024)
d6f1225: close UB bootstrap socket to avoid interfering with CUDA Multicast sh… (denera, Aug 29, 2024)
271cbf7: added torch::Tensor wrappers for communication buffer and atomic coun… (denera, Aug 29, 2024)
4586653: [pre-commit.ci] auto fixes from pre-commit.com hooks (pre-commit-ci[bot], Sep 6, 2024)
23f7dca: automated handling of world, local and node ranks/sizes within C++ Co… (denera, Sep 6, 2024)
620c1f9: [pre-commit.ci] auto fixes from pre-commit.com hooks (pre-commit-ci[bot], Sep 6, 2024)
7 changes: 7 additions & 0 deletions transformer_engine/pytorch/csrc/extensions.h
@@ -529,6 +529,10 @@ class CommOverlapHelper : torch::CustomClassHolder {
};

class CommOverlap : torch::CustomClassHolder, public transformer_engine::CommOverlapBase {
private:
torch::Tensor _ubuf_torch;
torch::Tensor _ubuf_counter;
anderson101866 commented (Sep 9, 2024):

Does the _ubuf_counter become redundant now? It seems to be instantiated only in the constructor, with no further use.


public:
CommOverlap(const std::vector<size_t> &buffer_shape, at::ScalarType buffer_dtype,
CommOverlapHelper *helper, int tp_size, int num_splits = 3,
@@ -588,6 +592,9 @@ class CommOverlap : torch::CustomClassHolder, public transformer_engine::CommOve
}; // CommOverlap

class CommOverlapP2P : torch::CustomClassHolder, public transformer_engine::CommOverlapP2PBase {
private:
torch::Tensor _ubuf_torch;
torch::Tensor _ubuf_counter;
public:
CommOverlapP2P(const std::vector<size_t> &buffer_shape, at::ScalarType buffer_dtype,
CommOverlapHelper *helper, int tp_size,
31 changes: 26 additions & 5 deletions transformer_engine/pytorch/csrc/extensions/comm_gemm_overlap.cpp
@@ -194,6 +194,17 @@ CommOverlap::CommOverlap(const std::vector<size_t> &buffer_shape, at::ScalarType
std::bind(&CommOverlapHelper::ub_allgather, helper, _1, _2, _3, _4, _5),
std::bind(&CommOverlapHelper::ub_barrier, helper, _1), num_splits,
num_max_streams, comm_cga_size, num_comm_sm, set_sm_margin, atomic_gemm) {
// Even though we never use these PyTorch tensor wrappers directly, they're still necessary
// for PyTorch to factor externally allocated memory into its memory pool and garbage
// collection threshold calculation.
_ubuf_torch = torch::from_blob(
_ubuf.dptr(), {static_cast<int64_t>(_ubuf.size(0)), static_cast<int64_t>(_ubuf.size(1))},
at::device(torch::kCUDA).dtype(buffer_dtype));
if (_atomic_gemm) {
_ubuf_counter = torch::from_blob(
_counter.dptr(), {static_cast<int64_t>(_num_splits * 2)},
at::device(torch::kCUDA).dtype(torch::kInt32));
}
}

/*
@@ -228,8 +239,7 @@ std::vector<at::Tensor> CommOverlap::bulk_overlap(
(comm_type == te::CommOverlapType::AG) ? _ubuf.size(0) : _ubuf.size(0) / _tp_size;
int output_c_dim1 = _ubuf.size(1);
auto output_tensor =
torch::from_blob(ubuf_wt_ptr, {output_c_dim0, output_c_dim1},
torch::device(torch::kCUDA).dtype(GetATenDType(_ubuf.dtype())));
torch::from_blob(ubuf_wt_ptr, {output_c_dim0, output_c_dim1}, _ubuf_torch.options());

return {D, output_tensor};
} // CommOverlap::bulk_overlap
@@ -332,7 +342,19 @@ CommOverlapP2P::CommOverlapP2P(const std::vector<size_t> &buffer_shape, at::Scal
helper->mylocal, helper->numlocal, helper->mynode, helper->numnodes, tp_size,
std::bind(&CommOverlapHelper::ub_allgather, helper, _1, _2, _3, _4, _5),
std::bind(&CommOverlapHelper::ub_barrier, helper, _1), comm_type, num_max_streams,
comm_cga_size, num_comm_sm, set_sm_margin, use_ce, atomic_gemm, aggregate) {}
comm_cga_size, num_comm_sm, set_sm_margin, use_ce, atomic_gemm, aggregate) {
// Even though we never use these PyTorch tensor wrappers directly, they're still necessary
// for PyTorch to factor externally allocated memory into its memory pool and garbage
// collection threshold calculation.
_ubuf_torch = torch::from_blob(
_ubuf.dptr(), {static_cast<int64_t>(_ubuf.size(0)), static_cast<int64_t>(_ubuf.size(1))},
at::device(torch::kCUDA).dtype(buffer_dtype));
if (_atomic_gemm) {
_ubuf_counter = torch::from_blob(
_counter.dptr(), {static_cast<int64_t>(_num_splits * 2)},
at::device(torch::kCUDA).dtype(torch::kInt32));
}
}

/*
** Split AllGather + AtomicGEMM using P2P communication
@@ -457,6 +479,5 @@ torch::Tensor CommOverlapP2P::get_ubuf_output(int comm_type) {
int output_c_dim0 =
(_comm_type == te::CommOverlapType::AG) ? _ubuf.size(0) : _ubuf.size(0) / _tp_size;
int output_c_dim1 = _ubuf.size(1);
return torch::from_blob(ubuf_wt_ptr, {output_c_dim0, output_c_dim1},
torch::device(torch::kCUDA).dtype(GetATenDType(_ubuf.dtype())));
return torch::from_blob(ubuf_wt_ptr, {output_c_dim0, output_c_dim1}, _ubuf_torch.options());
}
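
For context on the diff above: the constructors wrap the externally allocated Userbuffers region in a torch::Tensor via torch::from_blob, so the buffer can be viewed (and its TensorOptions reused) without any copy. The standalone sketch below mirrors that wrapping pattern only; the cudaMalloc allocation, shape, and BFloat16 dtype are illustrative assumptions, not values taken from the PR.

#include <torch/torch.h>
#include <cuda_runtime.h>

int main() {
  // Externally allocated CUDA buffer standing in for the Userbuffers
  // communication buffer (_ubuf.dptr() in the PR); sizes here are made up.
  const int64_t rows = 1024, cols = 1024;
  void *raw = nullptr;
  cudaMalloc(&raw, rows * cols * sizeof(at::BFloat16));

  // Wrap the existing allocation as a torch::Tensor view; from_blob makes
  // no copy and does not take ownership of the memory.
  auto opts = at::device(torch::kCUDA).dtype(torch::kBFloat16);
  torch::Tensor ubuf = torch::from_blob(raw, {rows, cols}, opts);

  // Views over chunks of the same buffer can reuse the wrapper's options,
  // mirroring how bulk_overlap() builds its output from _ubuf_torch.options().
  torch::Tensor half_chunk = torch::from_blob(raw, {rows / 2, cols}, ubuf.options());

  cudaFree(raw);
  return 0;
}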