From 37379bf1d33f386235827944c8bee03b67f68f7f Mon Sep 17 00:00:00 2001 From: jiajuku Date: Mon, 3 Jul 2023 10:32:24 +0800 Subject: [PATCH] Fix build error on rocm4.5/rocm5 on ubuntu18.04 * declear new rccl api in paddle/fluid * Fix build error:free(): invalid pointer by pick the fix: https://github.com/skarupke/flat_hash_map/pull/26 * Filter to check codestyle for flash_hash_map.h Signed-off-by: jiajuku --- .pre-commit-config.yaml | 3 ++- paddle/fluid/platform/dynload/rccl.cc | 8 ++++++++ paddle/fluid/platform/dynload/rccl.h | 12 ++++++++++++ paddle/utils/flat_hash_map.h | 26 +++++++++++++++----------- 4 files changed, 37 insertions(+), 12 deletions(-) diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 481f05c46e8e5..1978b0a00c7f1 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -102,7 +102,8 @@ repos: exclude: | (?x)^( paddle/cinn/.+| - test/cpp/cinn/.+ + test/cpp/cinn/.+| + paddle/utils/flat_hash_map.h+ )$ # For CMake files - repo: local diff --git a/paddle/fluid/platform/dynload/rccl.cc b/paddle/fluid/platform/dynload/rccl.cc index 82838da685bf2..62bb6a88af7c0 100644 --- a/paddle/fluid/platform/dynload/rccl.cc +++ b/paddle/fluid/platform/dynload/rccl.cc @@ -26,10 +26,18 @@ RCCL_RAND_ROUTINE_EACH(DEFINE_WRAP); RCCL_RAND_ROUTINE_EACH_AFTER_2212(DEFINE_WRAP) #endif +#if NCCL_VERSION_CODE >= 2304 +RCCL_RAND_ROUTINE_EACH_AFTER_2304(DEFINE_WRAP) +#endif + #if NCCL_VERSION_CODE >= 2703 RCCL_RAND_ROUTINE_EACH_AFTER_2703(DEFINE_WRAP) #endif +#if NCCL_VERSION_CODE >= 21100 +RCCL_RAND_ROUTINE_EACH_AFTER_21100(DEFINE_WRAP) +#endif + } // namespace dynload } // namespace platform } // namespace paddle diff --git a/paddle/fluid/platform/dynload/rccl.h b/paddle/fluid/platform/dynload/rccl.h index 2f874bb59f593..4d988e4fb08a0 100644 --- a/paddle/fluid/platform/dynload/rccl.h +++ b/paddle/fluid/platform/dynload/rccl.h @@ -51,6 +51,11 @@ RCCL_RAND_ROUTINE_EACH(PLATFORM_DECLARE_DYNAMIC_LOAD_RCCL_WRAP) RCCL_RAND_ROUTINE_EACH_AFTER_2212(PLATFORM_DECLARE_DYNAMIC_LOAD_RCCL_WRAP) #endif +#if NCCL_VERSION_CODE >= 2304 +#define RCCL_RAND_ROUTINE_EACH_AFTER_2304(__macro) __macro(ncclGetVersion); +RCCL_RAND_ROUTINE_EACH_AFTER_2304(PLATFORM_DECLARE_DYNAMIC_LOAD_RCCL_WRAP) +#endif + #if NCCL_VERSION_CODE >= 2703 #define RCCL_RAND_ROUTINE_EACH_AFTER_2703(__macro) \ __macro(ncclSend); \ @@ -58,6 +63,13 @@ RCCL_RAND_ROUTINE_EACH_AFTER_2212(PLATFORM_DECLARE_DYNAMIC_LOAD_RCCL_WRAP) RCCL_RAND_ROUTINE_EACH_AFTER_2703(PLATFORM_DECLARE_DYNAMIC_LOAD_RCCL_WRAP) #endif +#if NCCL_VERSION_CODE >= 21100 +#define RCCL_RAND_ROUTINE_EACH_AFTER_21100(__macro) \ + __macro(ncclRedOpCreatePreMulSum); \ + __macro(ncclRedOpDestroy); +RCCL_RAND_ROUTINE_EACH_AFTER_21100(PLATFORM_DECLARE_DYNAMIC_LOAD_RCCL_WRAP) +#endif + } // namespace dynload } // namespace platform } // namespace paddle diff --git a/paddle/utils/flat_hash_map.h b/paddle/utils/flat_hash_map.h index 56318ab90e6c8..b643fc1a574b2 100644 --- a/paddle/utils/flat_hash_map.h +++ b/paddle/utils/flat_hash_map.h @@ -126,11 +126,6 @@ struct sherwood_v3_entry { sherwood_v3_entry(int8_t distance_from_desired) : distance_from_desired(distance_from_desired) {} ~sherwood_v3_entry() {} - static sherwood_v3_entry *empty_default_table() { - static sherwood_v3_entry result[min_lookups] = { - {}, {}, {}, {special_end_value}}; - return result; - } bool has_value() const { return distance_from_desired >= 0; } bool is_empty() const { return distance_from_desired < 0; } @@ -664,13 +659,24 @@ class sherwood_v3_table : private EntryAlloc, private Hasher, private Equal { bool empty() const { return num_elements == 0; } private: - EntryPointer entries = Entry::empty_default_table(); + EntryPointer entries = empty_default_table(); size_t num_slots_minus_one = 0; typename HashPolicySelector::type hash_policy; int8_t max_lookups = detailv3::min_lookups - 1; float _max_load_factor = 0.5f; size_t num_elements = 0; + EntryPointer empty_default_table() { + EntryPointer result = + AllocatorTraits::allocate(*this, detailv3::min_lookups); + EntryPointer special_end_item = + result + static_cast(detailv3::min_lookups - 1); + for (EntryPointer it = result; it != special_end_item; ++it) + it->distance_from_desired = -1; + special_end_item->distance_from_desired = Entry::special_end_value; + return result; + } + static int8_t compute_max_lookups(size_t num_buckets) { int8_t desired = detailv3::log2(num_buckets); return (std::max)(detailv3::min_lookups, desired); @@ -743,15 +749,13 @@ class sherwood_v3_table : private EntryAlloc, private Hasher, private Equal { void deallocate_data(EntryPointer begin, size_t num_slots_minus_one, int8_t max_lookups) { - if (begin != Entry::empty_default_table()) { - AllocatorTraits::deallocate( - *this, begin, num_slots_minus_one + max_lookups + 1); - } + AllocatorTraits::deallocate( + *this, begin, num_slots_minus_one + max_lookups + 1); } void reset_to_empty_state() { deallocate_data(entries, num_slots_minus_one, max_lookups); - entries = Entry::empty_default_table(); + entries = empty_default_table(); num_slots_minus_one = 0; hash_policy.reset(); max_lookups = detailv3::min_lookups - 1;