diff --git a/paddle/fluid/platform/CMakeLists.txt b/paddle/fluid/platform/CMakeLists.txt index 8a84429987d90..517b4a28a690f 100644 --- a/paddle/fluid/platform/CMakeLists.txt +++ b/paddle/fluid/platform/CMakeLists.txt @@ -47,6 +47,7 @@ ENDIF() cc_library(cpu_info SRCS cpu_info.cc DEPS ${CPU_INFO_DEPS}) cc_test(cpu_info_test SRCS cpu_info_test.cc DEPS cpu_info) cc_library(os_info SRCS os_info.cc DEPS enforce) +cc_test(os_info_test SRCS os_info_test.cc DEPS os_info) IF(WITH_GPU) nv_library(cuda_graph_with_memory_pool SRCS cuda_graph_with_memory_pool.cc DEPS device_context allocator_facade cuda_graph) diff --git a/paddle/fluid/platform/os_info.cc b/paddle/fluid/platform/os_info.cc index 5ba7f1d144e12..07263153164e2 100644 --- a/paddle/fluid/platform/os_info.cc +++ b/paddle/fluid/platform/os_info.cc @@ -13,7 +13,11 @@ See the License for the specific language governing permissions and limitations under the License. */ #include "paddle/fluid/platform/os_info.h" +#include +#include #include +#include +#include #if defined(__linux__) #include #include @@ -21,32 +25,181 @@ limitations under the License. */ #elif defined(_MSC_VER) #include #endif +#include "paddle/fluid/platform/macros.h" // import DISABLE_COPY_AND_ASSIGN namespace paddle { namespace platform { +namespace internal { -ThreadId::ThreadId() { +static uint64_t main_tid = + std::hash()(std::this_thread::get_id()); + +template +class ThreadDataRegistry { + class ThreadDataHolder; + + public: + // Singleton + static ThreadDataRegistry& GetInstance() { + static ThreadDataRegistry instance; + return instance; + } + + const T& GetCurrentThreadData() { return CurrentThreadData(); } + + void SetCurrentThreadData(const T& val) { + std::lock_guard lock(lock_); + CurrentThreadData() = val; + } + + // Returns current snapshot of all threads. Make sure there is no thread + // create/destory when using it. + template ::value>> + std::unordered_map GetAllThreadDataByValue() { + std::unordered_map data_copy; + std::lock_guard lock(lock_); + data_copy.reserve(tid_map_.size()); + for (auto& kv : tid_map_) { + data_copy.emplace(kv.first, kv.second->GetData()); + } + return std::move(data_copy); + } + + void RegisterData(uint64_t tid, ThreadDataHolder* tls_obj) { + std::lock_guard lock(lock_); + tid_map_[tid] = tls_obj; + } + + void UnregisterData(uint64_t tid) { + if (tid == main_tid) { + return; + } + std::lock_guard lock(lock_); + tid_map_.erase(tid); + } + + private: + class ThreadDataHolder { + public: + ThreadDataHolder() { + tid_ = std::hash()(std::this_thread::get_id()); + ThreadDataRegistry::GetInstance().RegisterData(tid_, this); + } + + ~ThreadDataHolder() { + ThreadDataRegistry::GetInstance().UnregisterData(tid_); + } + + T& GetData() { return data_; } + + private: + uint64_t tid_; + T data_; + }; + + ThreadDataRegistry() = default; + + DISABLE_COPY_AND_ASSIGN(ThreadDataRegistry); + + T& CurrentThreadData() { + static thread_local ThreadDataHolder thread_data; + return thread_data.GetData(); + } + + std::mutex lock_; + std::unordered_map tid_map_; // not owned +}; + +class InternalThreadId { + public: + InternalThreadId(); + + const ThreadId& GetTid() const { return id_; } + + private: + ThreadId id_; +}; + +InternalThreadId::InternalThreadId() { // C++ std tid - std_tid_ = std::hash()(std::this_thread::get_id()); + id_.std_tid = std::hash()(std::this_thread::get_id()); // system tid #if defined(__linux__) - sys_tid_ = syscall(SYS_gettid); + id_.sys_tid = static_cast(syscall(SYS_gettid)); #elif defined(_MSC_VER) - sys_tid_ = GetCurrentThreadId(); -#else // unsupported platforms - sys_tid_ = 0; + id_.sys_tid = static_cast(::GetCurrentThreadId()); +#else // unsupported platforms, use std_tid + id_.sys_tid = id_.std_tid; #endif // cupti tid std::stringstream ss; ss << std::this_thread::get_id(); - cupti_tid_ = static_cast(std::stoull(ss.str())); + id_.cupti_tid = static_cast(std::stoull(ss.str())); +} + +} // namespace internal + +uint64_t GetCurrentThreadSysId() { + return internal::ThreadDataRegistry::GetInstance() + .GetCurrentThreadData() + .GetTid() + .sys_tid; } -ThreadIdRegistry::~ThreadIdRegistry() { - std::lock_guard lock(lock_); - for (auto id_pair : id_map_) { - delete id_pair.second; +uint64_t GetCurrentThreadStdId() { + return internal::ThreadDataRegistry::GetInstance() + .GetCurrentThreadData() + .GetTid() + .std_tid; +} + +ThreadId GetCurrentThreadId() { + return internal::ThreadDataRegistry::GetInstance() + .GetCurrentThreadData() + .GetTid(); +} + +std::unordered_map GetAllThreadIds() { + auto tids = + internal::ThreadDataRegistry::GetInstance() + .GetAllThreadDataByValue(); + std::unordered_map res; + for (const auto& kv : tids) { + res[kv.first] = kv.second.GetTid(); } + return res; +} + +static constexpr const char* kDefaultThreadName = "unset"; + +std::string GetCurrentThreadName() { + const auto& thread_name = + internal::ThreadDataRegistry::GetInstance() + .GetCurrentThreadData(); + return thread_name.empty() ? kDefaultThreadName : thread_name; +} + +std::unordered_map GetAllThreadNames() { + return internal::ThreadDataRegistry::GetInstance() + .GetAllThreadDataByValue(); +} + +bool SetCurrentThreadName(const std::string& name) { + auto& instance = internal::ThreadDataRegistry::GetInstance(); + const auto& cur_name = instance.GetCurrentThreadData(); + if (!cur_name.empty() || cur_name == kDefaultThreadName) { + return false; + } + instance.SetCurrentThreadData(name); + return true; +} + +uint32_t GetProcessId() { +#if defined(_MSC_VER) + return static_cast(GetCurrentProcessId()); +#else + return static_cast(getpid()); +#endif } } // namespace platform diff --git a/paddle/fluid/platform/os_info.h b/paddle/fluid/platform/os_info.h index c38198f91b36b..c84738247a46f 100644 --- a/paddle/fluid/platform/os_info.h +++ b/paddle/fluid/platform/os_info.h @@ -14,15 +14,12 @@ limitations under the License. */ #pragma once -#include -#include +#include #include -#include "paddle/fluid/platform/enforce.h" // import LIKELY -#include "paddle/fluid/platform/macros.h" // import DISABLE_COPY_AND_ASSIGN -#include "paddle/fluid/platform/port.h" #ifdef _POSIX_C_SOURCE #include #endif +#include "paddle/fluid/platform/port.h" namespace paddle { namespace platform { @@ -41,59 +38,38 @@ inline uint64_t PosixInNsec() { } // All kinds of Ids for OS thread -class ThreadId { - public: - ThreadId(); +struct ThreadId { + uint64_t std_tid = 0; // std::hash + uint64_t sys_tid = 0; // OS-specific, Linux: gettid + uint32_t cupti_tid = 0; // thread_id used by Nvidia CUPTI +}; - uint64_t MainTid() const { return SysTid(); } +// Better performance than GetCurrentThreadId +uint64_t GetCurrentThreadStdId(); - uint64_t StdTid() const { return std_tid_; } +// Better performance than GetCurrentThreadId +uint64_t GetCurrentThreadSysId(); - uint32_t CuptiTid() const { return cupti_tid_; } +ThreadId GetCurrentThreadId(); - uint64_t SysTid() const { return sys_tid_ != 0 ? sys_tid_ : std_tid_; } +// Return the map from StdTid to ThreadId +// Returns current snapshot of all threads. Make sure there is no thread +// create/destory when using it. +std::unordered_map GetAllThreadIds(); - private: - uint64_t std_tid_ = 0; // std::hash - uint32_t cupti_tid_ = 0; // thread_id used by Nvidia CUPTI - uint64_t sys_tid_ = 0; // OS-specific, Linux: gettid -}; +// Returns 'unset' if SetCurrentThreadName is never called. +std::string GetCurrentThreadName(); -class ThreadIdRegistry { - public: - // singleton - static ThreadIdRegistry& GetInstance() { - static ThreadIdRegistry instance; - return instance; - } - - const ThreadId* GetThreadId(uint64_t std_id) { - std::lock_guard lock(lock_); - if (LIKELY(id_map_.find(std_id) != id_map_.end())) { - return id_map_[std_id]; - } - return nullptr; - } - - const ThreadId& CurrentThreadId() { - static thread_local ThreadId* tid_ = nullptr; - if (LIKELY(tid_ != nullptr)) { - return *tid_; - } - tid_ = new ThreadId; - std::lock_guard lock(lock_); - id_map_[tid_->StdTid()] = tid_; - return *tid_; - } - - private: - ThreadIdRegistry() = default; - DISABLE_COPY_AND_ASSIGN(ThreadIdRegistry); - ~ThreadIdRegistry(); - - std::mutex lock_; - std::unordered_map id_map_; -}; +// Return the map from StdTid to ThreadName +// Returns current snapshot of all threads. Make sure there is no thread +// create/destory when using it. +std::unordered_map GetAllThreadNames(); + +// Thread name is immutable, only the first call will succeed. +// Returns false on failure. +bool SetCurrentThreadName(const std::string& name); + +uint32_t GetProcessId(); } // namespace platform } // namespace paddle diff --git a/paddle/fluid/platform/os_info_test.cc b/paddle/fluid/platform/os_info_test.cc new file mode 100644 index 0000000000000..b309bb985122d --- /dev/null +++ b/paddle/fluid/platform/os_info_test.cc @@ -0,0 +1,40 @@ +// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +#include "paddle/fluid/platform/os_info.h" +#include +#include "gtest/gtest.h" + +TEST(ThreadInfo, TestThreadIdUtils) { + using paddle::platform::GetCurrentThreadStdId; + using paddle::platform::GetCurrentThreadId; + using paddle::platform::GetAllThreadIds; + EXPECT_EQ(std::hash()(std::this_thread::get_id()), + GetCurrentThreadId().std_tid); + auto ids = GetAllThreadIds(); + EXPECT_TRUE(ids.find(GetCurrentThreadStdId()) != ids.end()); +} + +TEST(ThreadInfo, TestThreadNameUtils) { + using paddle::platform::GetCurrentThreadStdId; + using paddle::platform::GetCurrentThreadName; + using paddle::platform::SetCurrentThreadName; + using paddle::platform::GetAllThreadNames; + EXPECT_EQ("unset", GetCurrentThreadName()); + EXPECT_TRUE(SetCurrentThreadName("MainThread")); + EXPECT_FALSE(SetCurrentThreadName("MainThread")); + auto names = GetAllThreadNames(); + EXPECT_TRUE(names.find(GetCurrentThreadStdId()) != names.end()); + EXPECT_EQ("MainThread", names[GetCurrentThreadStdId()]); + EXPECT_EQ("MainThread", GetCurrentThreadName()); +} diff --git a/paddle/fluid/platform/profiler/host_event_recorder.cc b/paddle/fluid/platform/profiler/host_event_recorder.cc index 14054418c5d24..b8495ca45ca84 100644 --- a/paddle/fluid/platform/profiler/host_event_recorder.cc +++ b/paddle/fluid/platform/profiler/host_event_recorder.cc @@ -16,7 +16,7 @@ namespace paddle { namespace platform { ThreadEventRecorder::ThreadEventRecorder() { - thread_id_ = ThreadIdRegistry::GetInstance().CurrentThreadId().MainTid(); + thread_id_ = GetCurrentThreadSysId(); HostEventRecorder::GetInstance().RegisterThreadRecorder(thread_id_, this); }