Skip to content

Commit

Permalink
[PluggableDevice] Add custom runtime support (#38740)
Browse files Browse the repository at this point in the history
* [CustomRuntime] Add DeviceManager

* [CustomRuntime] Add DeviceInterface

* [CustomRuntime] Add Stream, Event, DeviceGuard, CallbackManager

* [CustomRuntime] Add plug-in device

* [CustomRuntime] Memory module support PluggableDevice

* [CustomRuntime] Add WITH_PLUGGABLE_DEVICE cmake option

* update

* [API] update API doc based on comments, test=develop

Co-authored-by: qili93 <qili93@qq.com>
  • Loading branch information
ronny1996 and qili93 authored Feb 15, 2022
1 parent 0d46a10 commit 3e7825f
Show file tree
Hide file tree
Showing 66 changed files with 5,056 additions and 138 deletions.
5 changes: 5 additions & 0 deletions CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -243,6 +243,7 @@ option(NEW_RELEASE_JIT "PaddlePaddle next-level release strategy for backup ji
option(WITH_ASCEND_INT64 "Compile with int64 kernel for ascend NPU" OFF)
option(WITH_POCKETFFT "Compile with pocketfft support" ON)
option(WITH_RECORD_BUILDTIME "Compile PaddlePaddle with record all targets build time" OFF)
option(WITH_CUSTOM_DEVICE "Compile with custom device support" OFF)

if(WITH_RECORD_BUILDTIME)
set_property(GLOBAL PROPERTY RULE_LAUNCH_COMPILE "${CMAKE_CURRENT_SOURCE_DIR}/tools/get_build_time.sh")
Expand All @@ -265,6 +266,10 @@ if(SANITIZER_TYPE AND NOT "${SANITIZER_TYPE}" MATCHES "^(Address|Leak|Memory|Thr
return()
endif()

# Turn WITH_CUSTOM_DEVICE on whenever LINUX is set and ON_INFER is not,
# even if the option is currently OFF (OFF is the option's declared default).
if (LINUX AND NOT WITH_CUSTOM_DEVICE AND NOT ON_INFER)
set(WITH_CUSTOM_DEVICE ON)
endif()

if(WIN32)
if(WITH_DISTRIBUTE)
MESSAGE(WARNING
Expand Down
4 changes: 4 additions & 0 deletions cmake/configure.cmake
Original file line number Diff line number Diff line change
Expand Up @@ -219,3 +219,7 @@ endif(ON_INFER)
if(WITH_CRYPTO)
add_definitions(-DPADDLE_WITH_CRYPTO)
endif(WITH_CRYPTO)

# Expose custom device support to the compiled sources through the
# PADDLE_WITH_CUSTOM_DEVICE macro. The macro is never defined on WIN32,
# regardless of the WITH_CUSTOM_DEVICE setting.
if(WITH_CUSTOM_DEVICE AND NOT WIN32)
add_definitions(-DPADDLE_WITH_CUSTOM_DEVICE)
endif()
5 changes: 5 additions & 0 deletions paddle/fluid/framework/dlpack_tensor.cc
Original file line number Diff line number Diff line change
Expand Up @@ -100,6 +100,11 @@ struct DLDeviceVisitor : public boost::static_visitor<::DLDevice> {
platform::errors::Unimplemented("platform::MLUPlace is not supported"));
}

// Visitor overload for platform::CustomPlace. No ::DLDevice is built for
// this place kind; the overload reports an Unimplemented error through
// PADDLE_THROW instead of returning a value.
inline ::DLDevice operator()(const platform::CustomPlace &place) const {
PADDLE_THROW(platform::errors::Unimplemented(
"platform::CustomPlace is not supported"));
}

inline ::DLDevice operator()(const platform::CUDAPlace &place) const {
#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
::DLDevice device;
Expand Down
14 changes: 14 additions & 0 deletions paddle/fluid/framework/executor.cc
Original file line number Diff line number Diff line change
Expand Up @@ -494,6 +494,20 @@ void Executor::RunPartialPreparedContext(ExecutorPrepareContext* ctx,
#else
PADDLE_THROW(
platform::errors::Unimplemented("No MLU gc found in CPU/MLU paddle"));
#endif
} else if (platform::is_custom_place(place_)) {
#ifdef PADDLE_WITH_CUSTOM_DEVICE
if (IsFastEagerDeletionModeEnabled()) {
VLOG(4) << "Use unsafe fast gc for " << place_ << ".";
gc.reset(new CustomDeviceUnsafeFastGarbageCollector(place_,
max_memory_size));
} else {
VLOG(4) << "Use default stream gc for " << place_ << ".";
gc.reset(
new CustomDefaultStreamGarbageCollector(place_, max_memory_size));
}
#else
PADDLE_THROW(platform::errors::Unimplemented("No CustomDevice gc found"));
#endif
}
}
Expand Down
53 changes: 53 additions & 0 deletions paddle/fluid/framework/garbage_collector.cc
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
#endif
#include "gflags/gflags.h"
#include "paddle/fluid/framework/garbage_collector.h"
#include "paddle/fluid/platform/device/device_wrapper.h"

DECLARE_double(eager_delete_tensor_gb);
DECLARE_double(memory_fraction_of_eager_deletion);
Expand Down Expand Up @@ -202,6 +203,58 @@ void MLUStreamGarbageCollector::ClearCallback(
}
#endif

#ifdef PADDLE_WITH_CUSTOM_DEVICE
// Constructs the collector for `place`. Both arguments are forwarded to the
// GarbageCollector base constructor; no additional state is set up here.
CustomDefaultStreamGarbageCollector::CustomDefaultStreamGarbageCollector(
const platform::CustomPlace &place, size_t max_memory_size)
: GarbageCollector(place, max_memory_size) {}

// Waits on the stream callbacks of this collector's device context. The
// context is accessed as a platform::CustomDeviceContext.
void CustomDefaultStreamGarbageCollector::Wait() const {
  auto *custom_ctx =
      static_cast<platform::CustomDeviceContext *>(this->dev_ctx_);
  custom_ctx->WaitStreamCallback();
}

// Hands `callback` to this collector's device context as a stream callback.
// The context is accessed as a platform::CustomDeviceContext.
void CustomDefaultStreamGarbageCollector::ClearCallback(
    const std::function<void()> &callback) {
  auto *custom_ctx =
      static_cast<platform::CustomDeviceContext *>(this->dev_ctx_);
  custom_ctx->AddStreamCallback(callback);
}

// Constructs the collector for `place`. Both arguments are forwarded to the
// GarbageCollector base constructor; no additional state is set up here.
CustomDeviceUnsafeFastGarbageCollector::CustomDeviceUnsafeFastGarbageCollector(
const platform::CustomPlace &place, size_t max_memory_size)
: GarbageCollector(place, max_memory_size) {}

// Invokes `callback` directly, inside this call, rather than handing it to
// a stream or callback manager.
void CustomDeviceUnsafeFastGarbageCollector::ClearCallback(
const std::function<void()> &callback) {
callback();
}

// Constructs the collector for `place` and gives it its own stream plus a
// callback manager attached to that stream. A platform::DeviceGuard for
// `place` stays in scope for the whole setup.
CustomStreamGarbageCollector::CustomStreamGarbageCollector(
    const platform::CustomPlace &place, size_t max_memory_size)
    : GarbageCollector(place, max_memory_size) {
  platform::DeviceGuard guard(place);

  // The stream is allocated first and only then initialized for `place`.
  stream_.reset(new platform::stream::Stream);
  stream_->Init(place);

  // The callback manager only borrows the stream; stream_ keeps ownership.
  platform::stream::Stream *borrowed_stream = stream_.get();
  callback_manager_.reset(new platform::CallbackManager(borrowed_stream));
}

// Synchronizes the collector's stream and then destroys it, with a
// platform::DeviceGuard for the device context's place in scope.
CustomStreamGarbageCollector::~CustomStreamGarbageCollector() {
  const auto &ctx_place = dev_ctx_->GetPlace();
  platform::DeviceGuard guard(ctx_place);

  // Synchronize before Destroy, in this order.
  stream_->Synchronize();
  stream_->Destroy();
}

// Returns a non-owning pointer to the stream held by this collector; the
// collector's unique_ptr keeps ownership.
platform::stream::Stream *CustomStreamGarbageCollector::stream() const {
return stream_.get();
}

// Forwards to the callback manager's Wait().
void CustomStreamGarbageCollector::Wait() const { callback_manager_->Wait(); }

// Registers `callback` with the callback manager that is attached to this
// collector's stream.
void CustomStreamGarbageCollector::ClearCallback(
const std::function<void()> &callback) {
callback_manager_->AddCallback(callback);
}
#endif

int64_t GetEagerDeletionThreshold() {
return FLAGS_eager_delete_tensor_gb < 0
? -1
Expand Down
41 changes: 41 additions & 0 deletions paddle/fluid/framework/garbage_collector.h
Original file line number Diff line number Diff line change
Expand Up @@ -200,6 +200,47 @@ class MLUStreamGarbageCollector : public GarbageCollector {
};
#endif

#ifdef PADDLE_WITH_CUSTOM_DEVICE
// Garbage collector for platform::CustomPlace that routes its work through
// the stream callbacks of the collector's device context.
class CustomDefaultStreamGarbageCollector : public GarbageCollector {
public:
// Forwards `place` and `max_memory_size` to the GarbageCollector base.
CustomDefaultStreamGarbageCollector(const platform::CustomPlace &place,
size_t max_memory_size);

// Waits on the device context's stream callbacks.
void Wait() const override;

protected:
// Adds `callback` as a stream callback on the device context.
void ClearCallback(const std::function<void()> &callback) override;
};

// Garbage collector for platform::CustomPlace whose ClearCallback() runs the
// callback directly instead of deferring it. It does not override Wait().
class CustomDeviceUnsafeFastGarbageCollector : public GarbageCollector {
public:
// Forwards `place` and `max_memory_size` to the GarbageCollector base.
CustomDeviceUnsafeFastGarbageCollector(const platform::CustomPlace &place,
size_t max_memory_size);

protected:
// Invokes `callback` immediately.
void ClearCallback(const std::function<void()> &callback) override;
};

// Garbage collector for platform::CustomPlace that owns its own stream and a
// callback manager attached to that stream.
class CustomStreamGarbageCollector : public GarbageCollector {
public:
// Creates and initializes the stream for `place`, then the callback manager.
CustomStreamGarbageCollector(const platform::CustomPlace &place,
size_t max_memory_size);

// Synchronizes and destroys the owned stream.
~CustomStreamGarbageCollector();

// Forwards to the callback manager's Wait().
void Wait() const override;

// Returns a non-owning pointer to the owned stream.
platform::stream::Stream *stream() const;

protected:
// Registers `callback` with the callback manager.
void ClearCallback(const std::function<void()> &callback) override;

private:
// Declaration order matters: members are destroyed in reverse order, so
// callback_manager_ is destroyed before stream_.
std::unique_ptr<platform::stream::Stream> stream_;
std::unique_ptr<platform::CallbackManager> callback_manager_;
};
#endif

template <typename Container>
void GarbageCollector::Add(Container &&objs) {
Add(std::forward<Container>(objs), []() {});
Expand Down
12 changes: 11 additions & 1 deletion paddle/fluid/framework/op_kernel_type.cc
Original file line number Diff line number Diff line change
Expand Up @@ -47,10 +47,20 @@ size_t OpKernelType::Hash::operator()(const OpKernelType& key) const {
"Too many OpKernel attribute values, expected maximum "
"value is 64, received value is %d.",
cur_loc));

#ifdef PADDLE_WITH_CUSTOM_DEVICE
std::hash<int> hasher;
size_t seed =
hasher(place + data_type + data_layout + library_type + customized_value);
if (platform::is_custom_place(key.place_)) {
seed ^= std::hash<std::string>{}(key.place_.GetDeviceType()) + 0x9e3779b9 +
(seed << 6) + (seed >> 2) + 4;
}
return seed;
#else
std::hash<int> hasher;
return hasher(place + data_type + data_layout + library_type +
customized_value);
#endif
}

bool OpKernelType::operator==(const OpKernelType& o) const {
Expand Down
10 changes: 10 additions & 0 deletions paddle/fluid/framework/operator.cc
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@ limitations under the License. */
#include "paddle/fluid/framework/transfer_scope_cache.h"
#include "paddle/fluid/framework/unused_var_check.h"
#include "paddle/fluid/framework/var_type.h"
#include "paddle/fluid/platform/device/device_wrapper.h"
#include "paddle/fluid/platform/enforce.h"
#include "paddle/fluid/platform/profiler.h"
#include "paddle/pten/common/scalar.h"
Expand Down Expand Up @@ -244,6 +245,15 @@ void OperatorBase::Run(const Scope& scope, const platform::Place& place) {
#else
auto dev_id = place.device;
platform::SetMLUDeviceId(dev_id);
#endif
} else if (platform::is_custom_place(place)) {
#ifndef PADDLE_WITH_CUSTOM_DEVICE
PADDLE_THROW(platform::errors::Unavailable(
"Cannot run operator on place %s, please recompile paddle or "
"reinstall Paddle with CustomDevice support.",
place));
#else
platform::DeviceManager::SetDevice(place);
#endif
}

Expand Down
15 changes: 15 additions & 0 deletions paddle/fluid/framework/parallel_executor.cc
Original file line number Diff line number Diff line change
Expand Up @@ -532,6 +532,21 @@ ir::Graph *ParallelExecutorPrivate::ApplyMemoryOptimizePass(ir::Graph *graph) {
PADDLE_THROW(platform::errors::PermissionDenied(
"Paddle can't use XPU device since it's not compiled with XPU,"
"Please recompile or reinstall Paddle with XPU support."));
#endif
} else if (platform::is_custom_place(place)) {
#if defined(PADDLE_WITH_CUSTOM_DEVICE)
if (IsFastEagerDeletionModeEnabled()) {
gc.reset(
new CustomDeviceUnsafeFastGarbageCollector(place, max_memory_size));
} else {
gc.reset(new CustomStreamGarbageCollector(place, max_memory_size));
}
VLOG(10) << "Created " << i << "-th GarbageCollector at " << place;
#else
PADDLE_THROW(platform::errors::PermissionDenied(
"Paddle can't use custom device since it's not compiled with "
"CustomDevice,"
"Please recompile or reinstall Paddle with CustomDevice support."));
#endif
} else if (platform::is_cpu_place(place)) {
gc.reset(new CPUGarbageCollector(place, max_memory_size));
Expand Down
Loading

0 comments on commit 3e7825f

Please sign in to comment.