
Commit 2e55bb2

removed compile-time link to NVRTC and cuDNN, fixed PyBind11 definitions for TE/common enums, recovered UB_SKIPMC fix for P2P overlaps

Signed-off-by: Alp Dener <adener@nvidia.com>
denera committed Aug 1, 2024
1 parent 7c0cc8d commit 2e55bb2
Showing 3 changed files with 11 additions and 37 deletions.
transformer_engine/common/CMakeLists.txt (5 changes: 1 addition & 4 deletions)
@@ -83,10 +83,7 @@ set(transformer_engine_LIBS)
 list(APPEND transformer_engine_LIBS
      CUDA::cublas
      CUDA::cuda_driver
-     CUDA::cudart
-     CUDA::nvrtc
-     CUDA::nvToolsExt
-     CUDNN::cudnn)
+     CUDA::cudart)
 target_include_directories(transformer_engine PRIVATE
                            ${CMAKE_CUDA_TOOLKIT_INCLUDE_DIRECTORIES}
                            ${CUDNN_FRONTEND_INCLUDE_DIR})
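The library now links only against cuBLAS, the CUDA driver, and the CUDA runtime at build time; NVRTC, NVTX, and cuDNN are no longer compile-time link dependencies, though the cuDNN frontend headers stay on the include path. For context, one common way to keep using such a library without a hard link-time dependency is to resolve its entry points at runtime with dlopen/dlsym. The sketch below is purely illustrative and is not Transformer Engine's actual loading mechanism; it assumes libnvrtc.so is discoverable on the loader path.

// Illustrative only: resolve an NVRTC entry point at runtime instead of
// linking against CUDA::nvrtc at build time. Link with -ldl if needed.
#include <dlfcn.h>
#include <cstdio>

// nvrtcVersion(int*, int*) is part of the public NVRTC API; it returns
// NVRTC_SUCCESS (0) on success.
using nvrtcVersion_t = int (*)(int *, int *);

int main() {
  void *handle = dlopen("libnvrtc.so", RTLD_LAZY);  // loaded only if present
  if (!handle) {
    std::fprintf(stderr, "NVRTC not available: %s\n", dlerror());
    return 1;
  }
  auto get_version = reinterpret_cast<nvrtcVersion_t>(dlsym(handle, "nvrtcVersion"));
  int major = 0, minor = 0;
  if (get_version && get_version(&major, &minor) == 0) {
    std::printf("NVRTC %d.%d loaded at runtime\n", major, minor);
  }
  dlclose(handle);
  return 0;
}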
transformer_engine/common/util/pybind_helper.h (26 changes: 7 additions & 19 deletions)
@@ -15,11 +15,6 @@
 #include <transformer_engine/fused_attn.h>
 #include <transformer_engine/transformer_engine.h>
 
-// NOTE: transformer_engine::DType is a strongly-typed enum class so it should not have
-// `.export_values()`. NVTE_ prefixed enums are classic C enums to make core headers
-// includable in external pure-C libraries, and they need `.export_values()` to work
-// correctly in Python.
-
 #define NVTE_ADD_COMMON_PYBIND11_BINDINGS(m) \
   do { \
     pybind11::enum_<transformer_engine::DType>(m, "DType", pybind11::module_local()) \
@@ -41,23 +36,20 @@
         .value("QGELU", NVTE_Activation_Type::QGELU) \
         .value("QGEGLU", NVTE_Activation_Type::QGEGLU) \
         .value("SRELU", NVTE_Activation_Type::SRELU) \
-        .value("SREGLU", NVTE_Activation_Type::SREGLU) \
-        .export_values(); \
+        .value("SREGLU", NVTE_Activation_Type::SREGLU); \
     pybind11::enum_<NVTE_Bias_Type>(m, "NVTE_Bias_Type", pybind11::module_local()) \
         .value("NVTE_NO_BIAS", NVTE_Bias_Type::NVTE_NO_BIAS) \
         .value("NVTE_PRE_SCALE_BIAS", NVTE_Bias_Type::NVTE_PRE_SCALE_BIAS) \
         .value("NVTE_POST_SCALE_BIAS", NVTE_Bias_Type::NVTE_POST_SCALE_BIAS) \
-        .value("NVTE_ALIBI", NVTE_Bias_Type::NVTE_ALIBI) \
-        .export_values(); \
+        .value("NVTE_ALIBI", NVTE_Bias_Type::NVTE_ALIBI); \
     pybind11::enum_<NVTE_Mask_Type>(m, "NVTE_Mask_Type", pybind11::module_local()) \
         .value("NVTE_NO_MASK", NVTE_Mask_Type::NVTE_NO_MASK) \
         .value("NVTE_PADDING_MASK", NVTE_Mask_Type::NVTE_PADDING_MASK) \
         .value("NVTE_CAUSAL_MASK", NVTE_Mask_Type::NVTE_CAUSAL_MASK) \
         .value("NVTE_PADDING_CAUSAL_MASK", NVTE_Mask_Type::NVTE_PADDING_CAUSAL_MASK) \
         .value("NVTE_CAUSAL_BOTTOM_RIGHT_MASK", NVTE_Mask_Type::NVTE_CAUSAL_BOTTOM_RIGHT_MASK) \
         .value("NVTE_PADDING_CAUSAL_BOTTOM_RIGHT_MASK", \
-               NVTE_Mask_Type::NVTE_PADDING_CAUSAL_BOTTOM_RIGHT_MASK) \
-        .export_values(); \
+               NVTE_Mask_Type::NVTE_PADDING_CAUSAL_BOTTOM_RIGHT_MASK); \
     pybind11::enum_<NVTE_QKV_Layout>(m, "NVTE_QKV_Layout", pybind11::module_local()) \
         .value("NVTE_SB3HD", NVTE_QKV_Layout::NVTE_SB3HD) \
         .value("NVTE_SBH3D", NVTE_QKV_Layout::NVTE_SBH3D) \
@@ -73,19 +65,16 @@
         .value("NVTE_TH3D", NVTE_QKV_Layout::NVTE_TH3D) \
         .value("NVTE_THD_T2HD", NVTE_QKV_Layout::NVTE_THD_T2HD) \
         .value("NVTE_THD_TH2D", NVTE_QKV_Layout::NVTE_THD_TH2D) \
-        .value("NVTE_THD_THD_THD", NVTE_QKV_Layout::NVTE_THD_THD_THD) \
-        .export_values(); \
+        .value("NVTE_THD_THD_THD", NVTE_QKV_Layout::NVTE_THD_THD_THD); \
     pybind11::enum_<NVTE_Fused_Attn_Backend>(m, "NVTE_Fused_Attn_Backend", \
                                              pybind11::module_local()) \
         .value("NVTE_F16_max512_seqlen", NVTE_Fused_Attn_Backend::NVTE_F16_max512_seqlen) \
         .value("NVTE_F16_arbitrary_seqlen", NVTE_Fused_Attn_Backend::NVTE_F16_arbitrary_seqlen) \
         .value("NVTE_FP8", NVTE_Fused_Attn_Backend::NVTE_FP8) \
-        .value("NVTE_No_Backend", NVTE_Fused_Attn_Backend::NVTE_No_Backend) \
-        .export_values(); \
+        .value("NVTE_No_Backend", NVTE_Fused_Attn_Backend::NVTE_No_Backend); \
     pybind11::enum_<NVTE_Comm_Overlap_Type>(m, "NVTE_Comm_Overlap_Type", pybind11::module_local()) \
         .value("AG", NVTE_Comm_Overlap_Type::ALL_GATHER) \
-        .value("RS", NVTE_Comm_Overlap_Type::REDUCE_SCATTER) \
-        .export_values(); \
+        .value("RS", NVTE_Comm_Overlap_Type::REDUCE_SCATTER); \
     pybind11::enum_<NVTE_Comm_Overlap_Algo>(m, "NVTE_Comm_Overlap_Algo", pybind11::module_local()) \
         .value("BULK_OVERLAP_RS", NVTE_Comm_Overlap_Algo::BULK_OVERLAP_RS) \
         .value("BULK_OVERLAP_AG", NVTE_Comm_Overlap_Algo::BULK_OVERLAP_AG) \
@@ -94,8 +83,7 @@
         .value("SPLIT_PIPELINED_AG_P2P", NVTE_Comm_Overlap_Algo::SPLIT_PIPELINED_AG_P2P) \
         .value("ATOMIC_GEMM_RS", NVTE_Comm_Overlap_Algo::ATOMIC_GEMM_RS) \
         .value("ATOMIC_GEMM_RS_P2P", NVTE_Comm_Overlap_Algo::ATOMIC_GEMM_RS_P2P) \
-        .value("ATOMIC_GEMM_AG_P2P", NVTE_Comm_Overlap_Algo::ATOMIC_GEMM_AG_P2P) \
-        .export_values(); \
+        .value("ATOMIC_GEMM_AG_P2P", NVTE_Comm_Overlap_Algo::ATOMIC_GEMM_AG_P2P); \
     m.attr("NVTE_COMM_OVERLAP_MAX_STREAMS") = pybind11::int_(NVTE_COMM_OVERLAP_MAX_STREAMS); \
     m.def("device_supports_multicast", &transformer_engine::device_supports_multicast, \
           pybind11::call_guard<pybind11::gil_scoped_release>()); \
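The `.export_values()` calls removed above are what pybind11 uses to re-export an enum's members into the enclosing module scope, mimicking how a classic C enum leaks its enumerators; without the call, members are reachable only through the enum type. A minimal, self-contained sketch of the difference, using hypothetical module and enum names that are not part of Transformer Engine:

// demo.cpp -- standalone pybind11 module illustrating export_values().
// Build sketch (assumption): c++ -O2 -shared -fPIC $(python3 -m pybind11 --includes) demo.cpp -o demo.so
#include <pybind11/pybind11.h>

enum NVTE_Demo_Kind { DEMO_A, DEMO_B };        // classic C-style enum
enum class DemoDType { kFloat32, kBFloat16 };  // strongly typed enum class

PYBIND11_MODULE(demo, m) {
  // With export_values(): members are also injected into the module scope,
  // so both demo.NVTE_Demo_Kind.DEMO_A and demo.DEMO_A work from Python.
  pybind11::enum_<NVTE_Demo_Kind>(m, "NVTE_Demo_Kind")
      .value("DEMO_A", DEMO_A)
      .value("DEMO_B", DEMO_B)
      .export_values();

  // Without export_values(): members stay scoped, i.e. demo.DemoDType.kFloat32
  // exists but demo.kFloat32 does not.
  pybind11::enum_<DemoDType>(m, "DemoDType")
      .value("kFloat32", DemoDType::kFloat32)
      .value("kBFloat16", DemoDType::kBFloat16);
}

After this commit the NVTE_* bindings behave like the second case: Python callers spell out, for example, NVTE_Bias_Type.NVTE_NO_BIAS rather than a bare NVTE_NO_BIAS.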
transformer_engine/pytorch/csrc/extensions/comm_gemm_overlap.cpp (17 changes: 3 additions & 14 deletions)
@@ -399,20 +399,9 @@ CommOverlapP2P::CommOverlapP2P(torch::Tensor sample, CommOverlapHelper &helper,
   }
 
   void *ubuf_ptr;
-  if (te::getenv<bool>("UB_SKIPMC")) {
-    // Multicast is disabled so we have to pre-allocate the buffer here.
-    _ubuf = torch::empty({(sample.size(0) / _tp_size) * _num_ubuf_chunks, sample.size(1)},
-                         sample.options());
-    ubuf_ptr = _ubuf.data_ptr();
-    register_gpu_buffer(&ubuf_ptr, _ubuf_bytes, false);
-  } else {
-    // Multicast requires UB to allocate the buffer with specific memory options
-    // that PyTorch allocator does not support.
-    register_gpu_buffer(&ubuf_ptr, _ubuf_bytes, true);
-    _ubuf =
-        torch::from_blob(ubuf_ptr, {(sample.size(0) / _tp_size) * _num_ubuf_chunks, sample.size(1)},
-                         sample.options());
-  }
+  register_gpu_buffer(&ubuf_ptr, _ubuf_bytes, false);
+  _ubuf = torch::from_blob(
+      ubuf_ptr, {(sample.size(0) / _tp_size) * _num_ubuf_chunks, sample.size(1)}, sample.options());
 
   // Create tensor chunks for easy management
   char *ubuf_byte_ptr = reinterpret_cast<char *>(ubuf_ptr);
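In the deleted UB_SKIPMC branch, the buffer was pre-allocated with torch::empty through PyTorch's caching allocator whenever multicast was disabled; the surviving path always has the communication backend allocate via register_gpu_buffer and merely wraps the raw device pointer in an ATen tensor with torch::from_blob. A minimal sketch of that wrapping pattern, with cudaMalloc standing in for the userbuffers allocation (an assumption made for illustration only):

// Sketch: expose device memory that was NOT allocated by PyTorch's caching
// allocator as a torch::Tensor without copying. cudaMalloc stands in for the
// allocation performed by register_gpu_buffer(); the allocator keeps
// ownership, so the buffer must outlive every tensor view of it.
#include <torch/torch.h>
#include <cuda_runtime.h>
#include <stdexcept>

torch::Tensor wrap_device_buffer(void **raw, int64_t rows, int64_t cols) {
  size_t bytes = static_cast<size_t>(rows) * cols * sizeof(at::BFloat16);
  if (cudaMalloc(raw, bytes) != cudaSuccess) {
    throw std::runtime_error("cudaMalloc failed");
  }
  auto opts = torch::TensorOptions().dtype(torch::kBFloat16).device(torch::kCUDA);
  // No deleter is passed: torch::from_blob does not take ownership of the pointer.
  return torch::from_blob(*raw, {rows, cols}, opts);
}

The design choice mirrors the removed comment in the diff: multicast-capable buffers need memory options that the PyTorch allocator cannot provide, so ownership stays with the userbuffers allocator and PyTorch only sees a non-owning view.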
