Skip to content

Commit

Permalink
[CustomDevice] fix resource_pool release bug (#55229)
Browse files Browse the repository at this point in the history
  • Loading branch information
ronny1996 authored Jul 7, 2023
1 parent b8f265d commit 6af85a8
Show file tree
Hide file tree
Showing 6 changed files with 97 additions and 63 deletions.
103 changes: 58 additions & 45 deletions paddle/fluid/platform/device/custom/custom_device_resource_pool.cc
Original file line number Diff line number Diff line change
Expand Up @@ -30,39 +30,47 @@ CustomDeviceStreamResourcePool::CustomDeviceStreamResourcePool(
int dev_cnt = phi::DeviceManager::GetDeviceCount(place.GetDeviceType());
pool_.reserve(dev_cnt);
for (int dev_idx = 0; dev_idx < dev_cnt; ++dev_idx) {
auto creator = [place, dev_idx] {
auto creator = [place, dev_idx, this] {
auto place_ = phi::CustomPlace(place.GetDeviceType(), dev_idx);
phi::DeviceManager::SetDevice(place_);

phi::stream::Stream* stream = new phi::stream::Stream(place_, nullptr);
phi::DeviceManager::GetDeviceWithPlace(place_)->CreateStream(stream);
phi::stream::Stream* stream = new phi::stream::Stream;
stream->Init(place_);
this->streams_.push_back(stream);
return stream;
};

auto deleter = [place, dev_idx](phi::stream::Stream* stream) {
auto place_ = phi::CustomPlace(place.GetDeviceType(), dev_idx);
phi::DeviceManager::SetDevice(place_);

phi::DeviceManager::GetDeviceWithPlace(place_)->DestroyStream(stream);
delete stream;
};

pool_.emplace_back(
ResourcePool<CustomDeviceStreamObject>::Create(creator, deleter));
pool_.emplace_back(ResourcePool<CustomDeviceStreamObject>::Create(
creator, [](phi::stream::Stream* stream) {}));
}
}

// Accessor for the function-local static registry of stream resource pools.
// Instance() looks entries up and inserts them under place.GetDeviceType().
// NOTE(review): this span carries two variants of the same signature and of
// the static map (elements held by std::shared_ptr vs. by raw pointer) — it
// looks like both sides of a diff; confirm which variant applies before
// compiling.
std::unordered_map<
std::string,
std::vector<std::shared_ptr<CustomDeviceStreamResourcePool>>>&
std::unordered_map<std::string, std::vector<CustomDeviceStreamResourcePool*>>&
CustomDeviceStreamResourcePool::GetMap() {
static std::unordered_map<
std::string,
std::vector<std::shared_ptr<CustomDeviceStreamResourcePool>>>
static std::unordered_map<std::string,
std::vector<CustomDeviceStreamResourcePool*>>
pool;
return pool;
}

// Deletes every stream pointer stored in streams_, then empties pool_.
CustomDeviceStreamResourcePool::~CustomDeviceStreamResourcePool() {
  for (auto it = streams_.begin(); it != streams_.end(); ++it) {
    delete *it;
  }
  pool_.clear();
}

// Deletes every stream resource pool stored in the map returned by GetMap(),
// emptying each per-key vector and finally the map itself.
void CustomDeviceStreamResourcePool::Release() {
  auto& registry = GetMap();
  for (auto it = registry.begin(); it != registry.end(); ++it) {
    auto& pools = it->second;
    for (auto& entry : pools) {
      delete entry;
    }
    pools.clear();
  }
  registry.clear();
}

CustomDeviceStreamResourcePool& CustomDeviceStreamResourcePool::Instance(
const paddle::Place& place) {
auto& pool = GetMap();
Expand All @@ -72,9 +80,8 @@ CustomDeviceStreamResourcePool& CustomDeviceStreamResourcePool::Instance(
platform::errors::PreconditionNotMet(
"Required device shall be CustomPlace, but received %d. ", place));
if (pool.find(place.GetDeviceType()) == pool.end()) {
pool.insert(
{place.GetDeviceType(),
std::vector<std::shared_ptr<CustomDeviceStreamResourcePool>>()});
pool.insert({place.GetDeviceType(),
std::vector<CustomDeviceStreamResourcePool*>()});
for (size_t i = 0;
i < phi::DeviceManager::GetDeviceCount(place.GetDeviceType());
++i) {
Expand Down Expand Up @@ -121,52 +128,58 @@ CustomDeviceEventResourcePool::CustomDeviceEventResourcePool(
int dev_cnt = phi::DeviceManager::GetDeviceCount(place.GetDeviceType());
pool_.reserve(dev_cnt);
for (int dev_idx = 0; dev_idx < dev_cnt; ++dev_idx) {
auto creator = [place, dev_idx] {
auto creator = [place, dev_idx, this] {
auto place_ = phi::CustomPlace(place.GetDeviceType(), dev_idx);
phi::DeviceManager::SetDevice(place_);

phi::event::Event* event = new phi::event::Event(place_, nullptr);
phi::DeviceManager::GetDeviceWithPlace(place_)->CreateEvent(event);
phi::event::Event* event = new phi::event::Event;
event->Init(place_);
this->events_.push_back(event);
return event;
};

auto deleter = [place, dev_idx](phi::event::Event* event) {
auto place_ = phi::CustomPlace(place.GetDeviceType(), dev_idx);
phi::DeviceManager::SetDevice(place_);

phi::DeviceManager::GetDeviceWithPlace(place_)->DestroyEvent(event);
};

pool_.emplace_back(
ResourcePool<CustomDeviceEventObject>::Create(creator, deleter));
pool_.emplace_back(ResourcePool<CustomDeviceEventObject>::Create(
creator, [](phi::event::Event* event) {}));
}
}

// Accessor for the function-local static registry of event resource pools.
// Instance() looks entries up and inserts them under place.GetDeviceType().
// NOTE(review): this span carries two variants of the same signature and of
// the static map (elements held by std::shared_ptr vs. by raw pointer) — it
// looks like both sides of a diff; confirm which variant applies before
// compiling.
std::unordered_map<std::string,
std::vector<std::shared_ptr<CustomDeviceEventResourcePool>>>&
std::unordered_map<std::string, std::vector<CustomDeviceEventResourcePool*>>&
CustomDeviceEventResourcePool::GetMap() {
static std::unordered_map<
std::string,
std::vector<std::shared_ptr<CustomDeviceEventResourcePool>>>
static std::unordered_map<std::string,
std::vector<CustomDeviceEventResourcePool*>>
pool;
return pool;
}

// Deletes every event pointer stored in events_, then empties pool_.
CustomDeviceEventResourcePool::~CustomDeviceEventResourcePool() {
  for (auto it = events_.begin(); it != events_.end(); ++it) {
    delete *it;
  }
  pool_.clear();
}

// Deletes every event resource pool stored in the map returned by GetMap(),
// emptying each per-key vector and finally the map itself.
void CustomDeviceEventResourcePool::Release() {
  auto& registry = GetMap();
  for (auto& entry : registry) {
    auto& pools = entry.second;
    for (auto it = pools.begin(); it != pools.end(); ++it) {
      delete *it;
    }
    pools.clear();
  }
  registry.clear();
}

CustomDeviceEventResourcePool& CustomDeviceEventResourcePool::Instance(
const phi::Place& place) {
static std::unordered_map<
std::string,
std::vector<std::shared_ptr<CustomDeviceEventResourcePool>>>
pool;
auto& pool = GetMap();
PADDLE_ENFORCE_EQ(
platform::is_custom_place(place),
true,
platform::errors::PreconditionNotMet(
"Required device shall be CustomPlace, but received %d. ", place));
if (pool.find(place.GetDeviceType()) == pool.end()) {
pool.insert(
{place.GetDeviceType(),
std::vector<std::shared_ptr<CustomDeviceEventResourcePool>>()});
{place.GetDeviceType(), std::vector<CustomDeviceEventResourcePool*>()});
for (size_t i = 0;
i < phi::DeviceManager::GetDeviceCount(place.GetDeviceType());
++i) {
Expand Down
20 changes: 14 additions & 6 deletions paddle/fluid/platform/device/custom/custom_device_resource_pool.h
Original file line number Diff line number Diff line change
Expand Up @@ -31,42 +31,50 @@ using CustomDeviceEventObject = phi::event::Event;

// Pool of CustomDeviceStreamObject resources; instances are registered in the
// static map returned by GetMap().
// NOTE(review): GetMap() is declared twice below with different element types
// (std::shared_ptr vs. raw pointer) — looks like both sides of a diff; confirm
// which declaration applies.
class CustomDeviceStreamResourcePool {
public:
// Static registry of pool instances, keyed by a string (Instance() uses
// place.GetDeviceType() as the key).
static std::unordered_map<
std::string,
std::vector<std::shared_ptr<CustomDeviceStreamResourcePool>>>&
static std::unordered_map<std::string,
std::vector<CustomDeviceStreamResourcePool*>>&
GetMap();

// Deletes every pool stored in GetMap() and clears the map.
static void Release();

// Returns a shared stream object for dev_idx. Definition not shown here —
// presumably served from pool_[dev_idx]; TODO confirm.
std::shared_ptr<CustomDeviceStreamObject> New(int dev_idx);

// Returns a pool for `place`; enforces that `place` is a custom place.
static CustomDeviceStreamResourcePool& Instance(const paddle::Place& place);

// Deletes the pointers held in streams_ and clears pool_.
~CustomDeviceStreamResourcePool();

private:
// Private: instances cannot be constructed from outside the class.
explicit CustomDeviceStreamResourcePool(const paddle::Place& place);

DISABLE_COPY_AND_ASSIGN(CustomDeviceStreamResourcePool);

private:
// One ResourcePool per device index; filled by the constructor.
std::vector<std::shared_ptr<ResourcePool<CustomDeviceStreamObject>>> pool_;
// Raw stream pointers recorded by the creator callback and deleted in the
// destructor.
std::vector<phi::stream::Stream*> streams_;
};

// Pool of CustomDeviceEventObject resources; instances are registered in the
// static map returned by GetMap().
// NOTE(review): GetMap() is declared twice below with different element types
// (std::shared_ptr vs. raw pointer) — looks like both sides of a diff; confirm
// which declaration applies.
class CustomDeviceEventResourcePool {
public:
// Returns a shared event object for dev_idx. Definition not shown here —
// presumably served from pool_[dev_idx]; TODO confirm.
std::shared_ptr<CustomDeviceEventObject> New(int dev_idx);

// Static registry of pool instances, keyed by a string (Instance() uses
// place.GetDeviceType() as the key).
static std::unordered_map<
std::string,
std::vector<std::shared_ptr<CustomDeviceEventResourcePool>>>&
static std::unordered_map<std::string,
std::vector<CustomDeviceEventResourcePool*>>&
GetMap();

// Deletes every pool stored in GetMap() and clears the map.
static void Release();

// Returns a pool for `place`; enforces that `place` is a custom place.
static CustomDeviceEventResourcePool& Instance(const paddle::Place& place);

// Deletes the pointers held in events_ and clears pool_.
~CustomDeviceEventResourcePool();

private:
// Private: instances cannot be constructed from outside the class.
explicit CustomDeviceEventResourcePool(const paddle::Place& place);

DISABLE_COPY_AND_ASSIGN(CustomDeviceEventResourcePool);

private:
// One ResourcePool per device index; filled by the constructor.
std::vector<std::shared_ptr<ResourcePool<CustomDeviceEventObject>>> pool_;
// Raw event pointers recorded by the creator callback and deleted in the
// destructor.
std::vector<phi::event::Event*> events_;
};

} // namespace platform
Expand Down
15 changes: 15 additions & 0 deletions paddle/fluid/platform/profiler/custom_device/custom_tracer.cc
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,21 @@ CustomTracer::~CustomTracer() {
#endif
}

std::unordered_map<std::string, std::unique_ptr<CustomTracer>>&
CustomTracer::GetMap() {
static std::unordered_map<std::string, std::unique_ptr<CustomTracer>>
instance;
return instance;
}

// Destroys every tracer held in GetMap() and then empties the map.
void CustomTracer::Release() {
  auto& tracers = GetMap();
  // Reset each owner first so the entries are already null when the map is
  // cleared.
  for (auto it = tracers.begin(); it != tracers.end(); ++it) {
    it->second.reset();
  }
  tracers.clear();
}

void CustomTracer::PrepareTracing() {
PADDLE_ENFORCE_EQ(
state_ == TracerState::UNINITED || state_ == TracerState::STOPED,
Expand Down
12 changes: 5 additions & 7 deletions paddle/fluid/platform/profiler/custom_device/custom_tracer.h
Original file line number Diff line number Diff line change
Expand Up @@ -26,18 +26,16 @@ namespace platform {

class CustomTracer : public TracerBase {
public:
// Accessor for the static map of tracers keyed by string.
// NOTE(review): an inline definition (std::shared_ptr elements) is followed
// by a bare declaration (std::unique_ptr elements) of the same function —
// looks like both sides of a diff; confirm which one applies.
static std::unordered_map<std::string, std::shared_ptr<CustomTracer>>&
GetMap() {
static std::unordered_map<std::string, std::shared_ptr<CustomTracer>>
instance;
return instance;
}
static std::unordered_map<std::string, std::unique_ptr<CustomTracer>>&
GetMap();

static void Release();

// Returns the tracer stored in GetMap() under device_type, inserting a newly
// constructed CustomTracer(device_type) first when the key is absent.
// NOTE(review): two insert argument lists follow (std::make_shared and
// std::make_unique) — looks like both sides of a diff; confirm which applies.
static CustomTracer& GetInstance(const std::string& device_type) {
auto& instance = GetMap();
if (instance.find(device_type) == instance.cend()) {
instance.insert(
{device_type, std::make_shared<CustomTracer>(device_type)});
{device_type, std::make_unique<CustomTracer>(device_type)});
}
// The key is present at this point, so operator[] does not default-insert.
return *instance[device_type];
}
Expand Down
6 changes: 3 additions & 3 deletions paddle/fluid/pybind/pybind.cc
Original file line number Diff line number Diff line change
Expand Up @@ -1009,9 +1009,9 @@ PYBIND11_MODULE(libpaddle, m) {
// Registers "clear_device_manager" on the module. When built with
// PADDLE_WITH_CUSTOM_DEVICE the lambda releases the XCCL comm context, the
// custom tracers and the event/stream resource pools, then calls
// phi::DeviceManager::Clear(); otherwise its body is empty.
// NOTE(review): both GetMap().clear() calls and Release() calls appear for
// the tracer and the two pools — looks like both sides of a diff; confirm
// which set applies.
m.def("clear_device_manager", []() {
#ifdef PADDLE_WITH_CUSTOM_DEVICE
platform::XCCLCommContext::Release();
platform::CustomTracer::GetMap().clear();
platform::CustomDeviceEventResourcePool::GetMap().clear();
platform::CustomDeviceStreamResourcePool::GetMap().clear();
platform::CustomTracer::Release();
platform::CustomDeviceEventResourcePool::Release();
platform::CustomDeviceStreamResourcePool::Release();
phi::DeviceManager::Clear();
#endif
});
Expand Down
4 changes: 2 additions & 2 deletions paddle/phi/backends/device_manager.cc
Original file line number Diff line number Diff line change
Expand Up @@ -671,8 +671,8 @@ DeviceManager& DeviceManager::Instance() {
}

// Clears device_map_ and device_impl_map_ on the object returned by
// Instance().
void DeviceManager::Clear() {
  auto& manager = Instance();
  manager.device_map_.clear();
  manager.device_impl_map_.clear();
}

std::vector<std::string> ListAllLibraries(const std::string& library_dir) {
Expand Down

0 comments on commit 6af85a8

Please sign in to comment.