[PTen] Compatible runtime performance optimization (#36946)
* resolve conflict with develop

* cache kernel context in tracer for perf gains

* replace DenseTensor when building kernel context

* fix detailed compile errors

* add impl for static mode

* fix conflict error

* clear attrs after running kernel

* fix coverage failure

* fix cyclic compile error

* remove multi-in&out adapt code

* remove tensor meta utils

* clear data when throwing exception
chenwhql authored Nov 10, 2021
1 parent ad44a40 commit 76d2fd1
Showing 19 changed files with 420 additions and 114 deletions.
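
The essence of the optimization: rather than building a new pten::KernelContext for every operator run, each op caches one context and, on later runs, only rebinds the cached tensor slots to the current Variables, clearing per-run data afterwards. Below is a minimal, self-contained C++ sketch of that caching pattern; DenseTensor, KernelContext, and CachedOp are hypothetical stand-ins, not Paddle's actual types.

// Minimal sketch of the caching pattern this commit introduces
// (hypothetical stand-in types, not Paddle's real API): build the
// context lazily on the first run, rebind cached slots on later runs,
// and clear per-run data afterwards so old tensors are not kept alive.
#include <cstdio>
#include <memory>
#include <vector>

struct DenseTensor { const float* data = nullptr; };  // stand-in

struct KernelContext {
  std::vector<DenseTensor> inputs;
  size_t InputsSize() const { return inputs.size(); }
  void ClearData() {  // keep the slots, drop the per-run bindings
    for (auto& t : inputs) t.data = nullptr;
  }
};

struct CachedOp {
  std::unique_ptr<KernelContext> ctx_;  // plays the role of pt_kernel_context_

  void Run(const std::vector<const float*>& vars) {
    if (ctx_ == nullptr) ctx_.reset(new KernelContext());  // first run only
    for (size_t i = 0; i < vars.size(); ++i) {
      if (ctx_->InputsSize() <= i) {
        ctx_->inputs.push_back({vars[i]});  // first run: emplace a new slot
      } else {
        ctx_->inputs[i].data = vars[i];     // later runs: rebind in place
      }
    }
    std::printf("ran with %zu cached input slots\n", ctx_->InputsSize());
    ctx_->ClearData();  // mirrors pt_kernel_context_->ClearData()
  }
};

int main() {
  CachedOp op;
  float a = 1.f, b = 2.f;
  op.Run({&a, &b});  // builds the two input slots
  op.Run({&a, &b});  // reuses them without reallocation
  return 0;
}

The first run pays the same cost as the old per-run build; every later run of the same op instance skips the allocations.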
89 changes: 57 additions & 32 deletions paddle/fluid/framework/operator.cc
@@ -1131,7 +1131,7 @@ void OperatorWithKernel::RunImpl(const Scope& scope,
   // phase
   if (FLAGS_run_pten_kernel &&
       pten::KernelFactory::Instance().HasCompatiblePtenKernel(type_)) {
-    if (pt_kernel_signature_.get() == nullptr || pt_kernel_.get() == nullptr) {
+    if (pt_kernel_signature_ == nullptr || pt_kernel_ == nullptr) {
       ChoosePtenKernel(exe_ctx);
     }
     run_pten_kernel_ = pt_kernel_->IsValid();
@@ -1178,8 +1178,12 @@ void OperatorWithKernel::RunImpl(const Scope& scope,
     platform::RecordEvent record_event("compute",
                                        platform::EventRole::kInnerOp);
     if (run_pten_kernel_) {
-      auto op_kernel_ctx = BuildPtenKernelContext(*runtime_ctx, *dev_ctx);
-      (*pt_kernel_)(&op_kernel_ctx);
+      if (pt_kernel_context_ == nullptr) {
+        pt_kernel_context_.reset(new pten::KernelContext());
+      }
+      BuildPtenKernelContext(*runtime_ctx, dev_ctx);
+      (*pt_kernel_)(pt_kernel_context_.get());
+      pt_kernel_context_->ClearData();
     } else {
       (*kernel_func_)(
           ExecutionContext(*this, exec_scope, *dev_ctx, *runtime_ctx));
@@ -1765,16 +1769,16 @@ KernelSignature OperatorWithKernel::GetExpectedPtenKernelArgs(
   return KernelSignatureMap::Instance().Get(Type());
 }
 
-pten::KernelContext OperatorWithKernel::BuildPtenKernelContext(
-    const RuntimeContext& ctx, const platform::DeviceContext& dev_ctx) const {
+void OperatorWithKernel::BuildPtenKernelContext(
+    const RuntimeContext& ctx, platform::DeviceContext* dev_ctx) const {
   // TODO(chenweihang): now only works for very simple cases;
   // many cases need to be dealt with later:
   // 1. the input and output are not tensors
   // 2. the dispensable, duplicable input and output
   // 3. removal of needless attributes
   // 4. use pt Tensor directly
   // 5. kernel input is not DenseTensor
-  pten::KernelContext op_kernel_ctx(dev_ctx);
+  pt_kernel_context_->SetDeviceContext(dev_ctx);
 
   auto& input_names = std::get<0>(pt_kernel_signature_->args);
   auto& attr_names = std::get<1>(pt_kernel_signature_->args);
@@ -1803,30 +1807,53 @@ pten::KernelContext OperatorWithKernel::BuildPtenKernelContext(
                         attr_names.size(), attr_defs.size()));
 
   for (size_t i = 0; i < input_names.size(); ++i) {
-    auto in_def = input_defs.at(i);
-    VLOG(2) << "in_def: " << in_def.backend << ", " << in_def.dtype << ", "
-            << in_def.layout;
-
-    auto ins_vector = ctx.inputs.at(input_names[i]);
-
-    paddle::SmallVector<std::shared_ptr<pten::TensorBase>> tmp_inputs;
-    for (auto var : ins_vector) {
-      tmp_inputs.emplace_back(
-          experimental::MakePtenTensorBaseFromVar(*var, in_def));
+    auto& in_def = input_defs.at(i);
+    auto& ins_vector = ctx.inputs.at(input_names[i]);
+    if (pt_kernel_context_->InputsSize() <= i) {
+      paddle::SmallVector<std::shared_ptr<pten::TensorBase>> tmp_inputs;
+      for (auto* var : ins_vector) {
+        tmp_inputs.emplace_back(
+            experimental::MakePtenTensorBaseFromVar(*var, in_def));
+      }
+      pt_kernel_context_->EmplaceBackInputs(std::move(tmp_inputs));
+    } else {
+      size_t input_size = pt_kernel_context_->InputsSize();
+      for (size_t j = 0; j < ins_vector.size(); ++j) {
+        if (input_size > i + j) {
+          experimental::ReMakePtenDenseTensorFromVar(
+              *ins_vector[j], in_def,
+              pt_kernel_context_->MutableInputAt<pten::DenseTensor>(i + j));
+        }
+        // TODO(chenweihang): adapt multi-input case later
+      }
+      pt_kernel_context_->MutableInputRangeAt(i) =
+          std::make_pair(i, i + ins_vector.size());
     }
-    op_kernel_ctx.EmplaceBackInputs(std::move(tmp_inputs));
   }

   for (size_t i = 0; i < output_names.size(); ++i) {
-    auto out_def = output_defs.at(i);
-    auto outs_vector = ctx.outputs.at(output_names[i]);
-
-    paddle::SmallVector<std::shared_ptr<pten::TensorBase>> tmp_outputs;
-    for (auto var : outs_vector) {
-      tmp_outputs.emplace_back(
-          experimental::MakePtenTensorBaseFromVar(var, out_def));
+    auto& out_def = output_defs.at(i);
+    auto& outs_vector = ctx.outputs.at(output_names[i]);
+    if (pt_kernel_context_->OutputsSize() <= i) {
+      paddle::SmallVector<std::shared_ptr<pten::TensorBase>> tmp_outputs;
+      for (auto* var : outs_vector) {
+        tmp_outputs.emplace_back(
+            experimental::MakePtenTensorBaseFromVar(var, out_def));
+      }
+      pt_kernel_context_->EmplaceBackOutputs(std::move(tmp_outputs));
+    } else {
+      size_t output_size = pt_kernel_context_->OutputsSize();
+      for (size_t j = 0; j < outs_vector.size(); ++j) {
+        if (output_size > i + j) {
+          experimental::ReMakePtenDenseTensorFromVar(
+              outs_vector[j], out_def,
+              pt_kernel_context_->MutableOutputAt<pten::DenseTensor>(i + j));
+        }
+        // TODO(chenweihang): adapt multi-output case later
+      }
+      pt_kernel_context_->MutableOutputRangeAt(i) =
+          std::make_pair(i, i + outs_vector.size());
     }
-    op_kernel_ctx.EmplaceBackOutputs(std::move(tmp_outputs));
   }

   for (size_t i = 0; i < attr_names.size(); ++i) {
@@ -1836,11 +1863,11 @@ pten::KernelContext OperatorWithKernel::BuildPtenKernelContext(
       // TODO(zhangyunfei): Scalar should hold scalar type, and we should check
       // attribute type by attr_defs
       if (std::type_index(attr.type()) == std::type_index(typeid(float))) {
-        op_kernel_ctx.EmplaceBackAttr(
+        pt_kernel_context_->EmplaceBackAttr(
             std::move(pten::Scalar(BOOST_GET_CONST(float, attr))));
       } else if (std::type_index(attr.type()) ==
                  std::type_index(typeid(std::string))) {
-        op_kernel_ctx.EmplaceBackAttr(
+        pt_kernel_context_->EmplaceBackAttr(
             std::move(pten::Scalar(BOOST_GET_CONST(std::string, attr))));
       } else {
         PADDLE_THROW(platform::errors::Unimplemented(
Expand All @@ -1851,11 +1878,11 @@ pten::KernelContext OperatorWithKernel::BuildPtenKernelContext(
     } else {
       // TODO(chenweihang): support other attrs later
       if (attr_defs[i].type_index == std::type_index(typeid(int))) {
-        op_kernel_ctx.EmplaceBackAttr(BOOST_GET_CONST(int, attr));
+        pt_kernel_context_->EmplaceBackAttr(BOOST_GET_CONST(int, attr));
       } else if (attr_defs[i].type_index == std::type_index(typeid(float))) {
-        op_kernel_ctx.EmplaceBackAttr(BOOST_GET_CONST(float, attr));
+        pt_kernel_context_->EmplaceBackAttr(BOOST_GET_CONST(float, attr));
       } else if (attr_defs[i].type_index == std::type_index(typeid(bool))) {
-        op_kernel_ctx.EmplaceBackAttr(BOOST_GET_CONST(bool, attr));
+        pt_kernel_context_->EmplaceBackAttr(BOOST_GET_CONST(bool, attr));
       } else {
         PADDLE_THROW(platform::errors::Unimplemented(
             "unsupported cast op attribute `%s` when construct "
@@ -1864,8 +1891,6 @@ pten::KernelContext OperatorWithKernel::BuildPtenKernelContext(
       }
     }
   }
-
-  return op_kernel_ctx;
 }
 
 }  // namespace framework
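The attribute loop above dispatches on std::type_index to move type-erased op attributes into the cached context. A compilable approximation of that idiom, with std::variant standing in for fluid's boost-based Attribute and a toy Ctx standing in for pten::KernelContext:

// Compilable approximation of the type_index attribute dispatch above.
// std::variant stands in for fluid's boost-based Attribute, and Ctx
// stands in for pten::KernelContext; neither is Paddle's real type.
#include <iostream>
#include <typeindex>
#include <typeinfo>
#include <variant>

using Attribute = std::variant<int, float, bool>;

struct Ctx {
  template <typename T>
  void EmplaceBackAttr(T v) { std::cout << "attr: " << v << "\n"; }
};

// Compare the kernel's declared attribute type against the erased value
// and emplace the concrete payload, as BuildPtenKernelContext does.
void EmplaceAttr(Ctx* ctx, const Attribute& attr, std::type_index expected) {
  if (expected == std::type_index(typeid(int))) {
    ctx->EmplaceBackAttr(std::get<int>(attr));
  } else if (expected == std::type_index(typeid(float))) {
    ctx->EmplaceBackAttr(std::get<float>(attr));
  } else if (expected == std::type_index(typeid(bool))) {
    ctx->EmplaceBackAttr(std::get<bool>(attr));
  } else {
    // the real code raises platform::errors::Unimplemented here
  }
}

int main() {
  Ctx ctx;
  EmplaceAttr(&ctx, Attribute{3}, std::type_index(typeid(int)));
  EmplaceAttr(&ctx, Attribute{2.5f}, std::type_index(typeid(float)));
  return 0;
}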
7 changes: 5 additions & 2 deletions paddle/fluid/framework/operator.h
@@ -586,8 +586,8 @@ class OperatorWithKernel : public OperatorBase {
   /* member functions for adapting to pten lib */
   void ChoosePtenKernel(const ExecutionContext& ctx) const;
 
-  pten::KernelContext BuildPtenKernelContext(
-      const RuntimeContext& ctx, const platform::DeviceContext& dev_ctx) const;
+  void BuildPtenKernelContext(const RuntimeContext& ctx,
+                              platform::DeviceContext* dev_ctx) const;
 
  protected:
   mutable std::unique_ptr<OpKernelType> kernel_type_;
@@ -605,6 +605,9 @@ class OperatorWithKernel : public OperatorBase {
   mutable bool run_pten_kernel_ = false;
   mutable std::unique_ptr<KernelSignature> pt_kernel_signature_;
   mutable std::unique_ptr<pten::Kernel> pt_kernel_;
+  // To reduce the performance overhead of the compatibility phase,
+  // temporarily cache the KernelContext
+  mutable std::unique_ptr<pten::KernelContext> pt_kernel_context_;
 };
 
 extern bool OpSupportGPU(const std::string& op_type);
7 changes: 3 additions & 4 deletions paddle/fluid/imperative/CMakeLists.txt
@@ -1,14 +1,13 @@
 cc_library(imperative_flag SRCS flags.cc DEPS gflags flags)
-
 IF(WITH_XPU)
-cc_library(prepared_operator SRCS prepared_operator.cc DEPS xpu_op_list proto_desc operator device_context lod_tensor selected_rows var_type_traits op_kernel_type data_transform nan_inf_utils pten_utils)
+cc_library(prepared_operator SRCS prepared_operator.cc DEPS xpu_op_list proto_desc operator device_context lod_tensor selected_rows var_type_traits op_kernel_type data_transform nan_inf_utils pten pten_utils)
 ELSE()
-cc_library(prepared_operator SRCS prepared_operator.cc DEPS proto_desc operator device_context lod_tensor selected_rows var_type_traits op_kernel_type data_transform nan_inf_utils pten_utils)
+cc_library(prepared_operator SRCS prepared_operator.cc DEPS proto_desc operator device_context lod_tensor selected_rows var_type_traits op_kernel_type data_transform nan_inf_utils pten pten_utils)
 ENDIF()
 cc_library(layer SRCS layer.cc DEPS prepared_operator math_function imperative_flag variable_helper op_registry)
 add_subdirectory(jit)
 cc_library(amp SRCS amp_auto_cast.cc DEPS layer )
-cc_library(tracer SRCS tracer.cc DEPS layer engine program_desc_tracer amp denormal)
+cc_library(tracer SRCS tracer.cc DEPS layer engine program_desc_tracer amp denormal garbage_collector)
 cc_library(basic_engine SRCS basic_engine.cc DEPS layer gradient_accumulator)
 cc_library(engine SRCS basic_engine.cc partial_grad_engine.cc DEPS layer gradient_accumulator)
 cc_library(imperative_profiler SRCS profiler.cc DEPS flags)
15 changes: 10 additions & 5 deletions paddle/fluid/imperative/layer.cc
@@ -356,6 +356,8 @@ void VarBase::BumpInplaceVersion() {
   MutableVar()->BumpInplaceVersion();
 }
 
+pten::KernelContext OpBase::pt_kernel_context_;
+
 void OpBase::SetType(const std::string& type) {
   op_ = framework::OpRegistry::CreateOp(type, {}, {}, {}, false);
 }
@@ -371,7 +373,8 @@ static void OpBaseRunImpl(const framework::OperatorBase& op,
                           const NameVarMap<VarType>& outs,
                           const framework::AttributeMap& attrs,
                           const framework::AttributeMap& default_attrs,
-                          const platform::Place& place) {
+                          const platform::Place& place,
+                          pten::KernelContext* pt_kernel_context) {
   auto* op_kernel = dynamic_cast<const framework::OperatorWithKernel*>(&op);
   PADDLE_ENFORCE_NOT_NULL(
       op_kernel, platform::errors::PermissionDenied(
@@ -412,8 +415,8 @@ static void OpBaseRunImpl(const framework::OperatorBase& op,
    * after the execution of op, but the original input is directly
    * overwritten in the previous dynamic graph implementation.
    */
-  auto prepared_op =
-      PreparedOp::Prepare(ins, outs, *op_kernel, place, attrs, default_attrs);
+  auto prepared_op = PreparedOp::Prepare(ins, outs, *op_kernel, place, attrs,
+                                         default_attrs, pt_kernel_context);
   auto tmp_ins_ptr =
       PrepareData<VarType>(*op_kernel, ins, prepared_op.kernel_type());
   if (tmp_ins_ptr == nullptr) {
@@ -441,7 +444,8 @@ void OpBase::Run(const framework::OperatorBase& op,
                  const framework::AttributeMap& attrs,
                  const framework::AttributeMap& default_attrs,
                  const platform::Place& place) {
-  OpBaseRunImpl<VarBase>(op, ins, outs, attrs, default_attrs, place);
+  OpBaseRunImpl<VarBase>(op, ins, outs, attrs, default_attrs, place,
+                         &pt_kernel_context_);
 }
 
 void OpBase::Run(const framework::OperatorBase& op,
@@ -450,7 +454,8 @@ void OpBase::Run(const framework::OperatorBase& op,
                  const framework::AttributeMap& attrs,
                  const framework::AttributeMap& default_attrs,
                  const platform::Place& place) {
-  OpBaseRunImpl<VariableWrapper>(op, ins, outs, attrs, default_attrs, place);
+  OpBaseRunImpl<VariableWrapper>(op, ins, outs, attrs, default_attrs, place,
+                                 &pt_kernel_context_);
 }
 
 void ClearNoNeedBufferInputs(OpBase* op) {
1 change: 1 addition & 0 deletions paddle/fluid/imperative/layer.h
@@ -36,6 +36,7 @@
#include "paddle/fluid/imperative/variable_wrapper.h"
#include "paddle/fluid/platform/enforce.h"
#include "paddle/fluid/platform/macros.h"
#include "paddle/pten/include/core.h"

namespace paddle {
namespace framework {
6 changes: 6 additions & 0 deletions paddle/fluid/imperative/op_base.h
@@ -25,6 +25,7 @@
#include "paddle/fluid/imperative/type_defs.h"
#include "paddle/fluid/imperative/variable_wrapper.h"
#include "paddle/fluid/platform/place.h"
#include "paddle/pten/include/core.h"

namespace paddle {
namespace imperative {
@@ -183,6 +184,8 @@ class OpBase {
                   const framework::AttributeMap& default_attrs,
                   const platform::Place& place);
 
+  static pten::KernelContext* GetKernelContext() { return &pt_kernel_context_; }
+
  private:
   static const std::string& UnknownOpType() {
     static std::string kUnknownOpType{"unknown"};
@@ -197,6 +200,9 @@ class OpBase {
   std::unique_ptr<framework::OperatorBase> op_;
   platform::Place place_;
   size_t id_{-1UL};
+  // To reduce the performance overhead of the compatibility phase,
+  // temporarily cache the KernelContext
+  static pten::KernelContext pt_kernel_context_;
 };
 
 class GradOpNode {
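On the imperative (dynamic graph) side, the cached context lives in a single static OpBase member rather than one per operator instance, which is safe only while the tracer runs ops serially. A simplified, self-contained sketch of that wiring, with stand-in types and abbreviated signatures:

// Simplified sketch of the dynamic-graph wiring (stand-in types,
// abbreviated signatures): every OpBase::Run reuses one process-wide
// KernelContext through a static member instead of building a fresh
// context for each traced op.
#include <iostream>
#include <string>

struct KernelContext { /* cached inputs/outputs/attrs */ };

class OpBase {
 public:
  static KernelContext* GetKernelContext() { return &pt_kernel_context_; }

  void Run(const std::string& op_type) {
    // the real code forwards &pt_kernel_context_ into OpBaseRunImpl,
    // which threads it through PreparedOp::Prepare
    KernelContext* ctx = GetKernelContext();
    std::cout << "running " << op_type << " with context " << ctx << "\n";
  }

 private:
  static KernelContext pt_kernel_context_;  // shared by all ops
};

// out-of-line definition, as layer.cc adds for the real member
KernelContext OpBase::pt_kernel_context_;

int main() {
  OpBase a, b;
  a.Run("matmul");
  b.Run("relu");  // same cached context as a's run
  return 0;
}

Sharing one static context trades isolation for rebuild cost, which matches the "temporarily cache" caveat in the commit's comments.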