[PTen] Compatible runtime performance optimization #36946

Merged
89 changes: 57 additions & 32 deletions paddle/fluid/framework/operator.cc
@@ -1131,7 +1131,7 @@ void OperatorWithKernel::RunImpl(const Scope& scope,
// phase
if (FLAGS_run_pten_kernel &&
pten::KernelFactory::Instance().HasCompatiblePtenKernel(type_)) {
if (pt_kernel_signature_.get() == nullptr || pt_kernel_.get() == nullptr) {
if (pt_kernel_signature_ == nullptr || pt_kernel_ == nullptr) {
ChoosePtenKernel(exe_ctx);
}
run_pten_kernel_ = pt_kernel_->IsValid();
@@ -1178,8 +1178,12 @@ void OperatorWithKernel::RunImpl(const Scope& scope,
platform::RecordEvent record_event("compute",
platform::EventRole::kInnerOp);
if (run_pten_kernel_) {
auto op_kernel_ctx = BuildPtenKernelContext(*runtime_ctx, *dev_ctx);
(*pt_kernel_)(&op_kernel_ctx);
if (pt_kernel_context_ == nullptr) {
pt_kernel_context_.reset(new pten::KernelContext());
}
BuildPtenKernelContext(*runtime_ctx, dev_ctx);
(*pt_kernel_)(pt_kernel_context_.get());
pt_kernel_context_->ClearData();
} else {
(*kernel_func_)(
ExecutionContext(*this, exec_scope, *dev_ctx, *runtime_ctx));
@@ -1765,16 +1769,16 @@ KernelSignature OperatorWithKernel::GetExpectedPtenKernelArgs(
return KernelSignatureMap::Instance().Get(Type());
}

pten::KernelContext OperatorWithKernel::BuildPtenKernelContext(
const RuntimeContext& ctx, const platform::DeviceContext& dev_ctx) const {
void OperatorWithKernel::BuildPtenKernelContext(
const RuntimeContext& ctx, platform::DeviceContext* dev_ctx) const {
// TODO(chenweihang): now only work for very simple case,
// many cases need to be dealt with later:
// 1. the input and output are not tensor
// 2. the dispensable, duplicable input and output
// 3. needless attributes remove
// 4. use pt Tensor directly
// 5. kernel input is not DenseTensor
pten::KernelContext op_kernel_ctx(dev_ctx);
pt_kernel_context_->SetDeviceContext(dev_ctx);

auto& input_names = std::get<0>(pt_kernel_signature_->args);
auto& attr_names = std::get<1>(pt_kernel_signature_->args);
@@ -1803,30 +1807,53 @@ pten::KernelContext OperatorWithKernel::BuildPtenKernelContext(
attr_names.size(), attr_defs.size()));

for (size_t i = 0; i < input_names.size(); ++i) {
auto in_def = input_defs.at(i);
VLOG(2) << "in_def: " << in_def.backend << ", " << in_def.dtype << ", "
<< in_def.layout;

auto ins_vector = ctx.inputs.at(input_names[i]);

paddle::SmallVector<std::shared_ptr<pten::TensorBase>> tmp_inputs;
for (auto var : ins_vector) {
tmp_inputs.emplace_back(
experimental::MakePtenTensorBaseFromVar(*var, in_def));
auto& in_def = input_defs.at(i);
auto& ins_vector = ctx.inputs.at(input_names[i]);
if (pt_kernel_context_->InputsSize() <= i) {
paddle::SmallVector<std::shared_ptr<pten::TensorBase>> tmp_inputs;
for (auto* var : ins_vector) {
tmp_inputs.emplace_back(
experimental::MakePtenTensorBaseFromVar(*var, in_def));
}
pt_kernel_context_->EmplaceBackInputs(std::move(tmp_inputs));
} else {
size_t input_size = pt_kernel_context_->InputsSize();
for (size_t j = 0; j < ins_vector.size(); ++j) {
if (input_size > i + j) {
experimental::ReMakePtenDenseTensorFromVar(
*ins_vector[j], in_def,
pt_kernel_context_->MutableInputAt<pten::DenseTensor>(i + j));
}
// TODO(chenweihang): adapt multi-input case later
}
pt_kernel_context_->MutableInputRangeAt(i) =
std::make_pair(i, i + ins_vector.size());
}
op_kernel_ctx.EmplaceBackInputs(std::move(tmp_inputs));
}

for (size_t i = 0; i < output_names.size(); ++i) {
auto out_def = output_defs.at(i);
auto outs_vector = ctx.outputs.at(output_names[i]);

paddle::SmallVector<std::shared_ptr<pten::TensorBase>> tmp_outputs;
for (auto var : outs_vector) {
tmp_outputs.emplace_back(
experimental::MakePtenTensorBaseFromVar(var, out_def));
auto& out_def = output_defs.at(i);
auto& outs_vector = ctx.outputs.at(output_names[i]);
if (pt_kernel_context_->OutputsSize() <= i) {
paddle::SmallVector<std::shared_ptr<pten::TensorBase>> tmp_outputs;
for (auto* var : outs_vector) {
tmp_outputs.emplace_back(
experimental::MakePtenTensorBaseFromVar(var, out_def));
}
pt_kernel_context_->EmplaceBackOutputs(std::move(tmp_outputs));
} else {
size_t output_size = pt_kernel_context_->OutputsSize();
for (size_t j = 0; j < outs_vector.size(); ++j) {
if (output_size > i + j) {
experimental::ReMakePtenDenseTensorFromVar(
outs_vector[j], out_def,
pt_kernel_context_->MutableOutputAt<pten::DenseTensor>(i + j));
}
// TODO(chenweihang): adapt multi-output case later
}
pt_kernel_context_->MutableOutputRangeAt(i) =
std::make_pair(i, i + outs_vector.size());
}
op_kernel_ctx.EmplaceBackOutputs(std::move(tmp_outputs));
}

for (size_t i = 0; i < attr_names.size(); ++i) {
@@ -1836,11 +1863,11 @@ pten::KernelContext OperatorWithKernel::BuildPtenKernelContext(
// TODO(zhangyunfei): Scalar should hold scalar type, and we should check
// attribute type by attr_defs
if (std::type_index(attr.type()) == std::type_index(typeid(float))) {
op_kernel_ctx.EmplaceBackAttr(
pt_kernel_context_->EmplaceBackAttr(
std::move(pten::Scalar(BOOST_GET_CONST(float, attr))));
} else if (std::type_index(attr.type()) ==
std::type_index(typeid(std::string))) {
op_kernel_ctx.EmplaceBackAttr(
pt_kernel_context_->EmplaceBackAttr(
std::move(pten::Scalar(BOOST_GET_CONST(std::string, attr))));
} else {
PADDLE_THROW(platform::errors::Unimplemented(
@@ -1851,11 +1878,11 @@ pten::KernelContext OperatorWithKernel::BuildPtenKernelContext(
} else {
// TODO(chenweihang): support other attrs later
if (attr_defs[i].type_index == std::type_index(typeid(int))) {
op_kernel_ctx.EmplaceBackAttr(BOOST_GET_CONST(int, attr));
pt_kernel_context_->EmplaceBackAttr(BOOST_GET_CONST(int, attr));
} else if (attr_defs[i].type_index == std::type_index(typeid(float))) {
op_kernel_ctx.EmplaceBackAttr(BOOST_GET_CONST(float, attr));
pt_kernel_context_->EmplaceBackAttr(BOOST_GET_CONST(float, attr));
} else if (attr_defs[i].type_index == std::type_index(typeid(bool))) {
op_kernel_ctx.EmplaceBackAttr(BOOST_GET_CONST(bool, attr));
pt_kernel_context_->EmplaceBackAttr(BOOST_GET_CONST(bool, attr));
} else {
PADDLE_THROW(platform::errors::Unimplemented(
"unsupported cast op attribute `%s` when construct "
@@ -1864,8 +1891,6 @@ pten::KernelContext OperatorWithKernel::BuildPtenKernelContext(
}
}
}

return op_kernel_ctx;
}

} // namespace framework
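The core of the operator.cc change is that `RunImpl` no longer builds a fresh `pten::KernelContext` on every call; it lazily creates one cached context, refills it each run via `BuildPtenKernelContext`, and clears only the per-run data afterwards. Below is a minimal sketch of that reuse pattern; `SimpleKernelContext`, `BuildContext`, and `Compute` are hypothetical stand-ins, not the real Paddle/PTen APIs.

```cpp
#include <memory>

// Sketch only: illustrates the cache-and-refill pattern from the diff above.
struct SimpleKernelContext {
  // Drops per-run tensor data but keeps the container shell alive so the
  // next run can reuse it without reallocating.
  void ClearData() {}
};

class CachedKernelOp {
 public:
  void Run() {
    // Create the context lazily, once, instead of on every invocation.
    if (ctx_ == nullptr) {
      ctx_.reset(new SimpleKernelContext());
    }
    BuildContext(ctx_.get());  // refill inputs/outputs/attrs in place
    Compute(ctx_.get());       // launch the kernel with the cached context
    ctx_->ClearData();         // release per-run data, keep the shell
  }

 private:
  void BuildContext(SimpleKernelContext*) const {}
  void Compute(SimpleKernelContext*) const {}
  // Cached across runs to avoid constructing a KernelContext per op run.
  std::unique_ptr<SimpleKernelContext> ctx_;
};
```

This is also why `BuildPtenKernelContext` now returns `void` and writes into the cached member instead of returning a context by value.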
7 changes: 5 additions & 2 deletions paddle/fluid/framework/operator.h
@@ -586,8 +586,8 @@ class OperatorWithKernel : public OperatorBase {
/* member functions for adapting to pten lib */
void ChoosePtenKernel(const ExecutionContext& ctx) const;

pten::KernelContext BuildPtenKernelContext(
const RuntimeContext& ctx, const platform::DeviceContext& dev_ctx) const;
void BuildPtenKernelContext(const RuntimeContext& ctx,
platform::DeviceContext* dev_ctx) const;

protected:
mutable std::unique_ptr<OpKernelType> kernel_type_;
@@ -605,6 +605,9 @@
mutable bool run_pten_kernel_ = false;
mutable std::unique_ptr<KernelSignature> pt_kernel_signature_;
mutable std::unique_ptr<pten::Kernel> pt_kernel_;
// In order to reduce the compatibility phase
// performance overhead, temporarily cache KernelContext
mutable std::unique_ptr<pten::KernelContext> pt_kernel_context_;
};

extern bool OpSupportGPU(const std::string& op_type);
7 changes: 3 additions & 4 deletions paddle/fluid/imperative/CMakeLists.txt
@@ -1,14 +1,13 @@
cc_library(imperative_flag SRCS flags.cc DEPS gflags flags)

IF(WITH_XPU)
cc_library(prepared_operator SRCS prepared_operator.cc DEPS xpu_op_list proto_desc operator device_context lod_tensor selected_rows var_type_traits op_kernel_type data_transform nan_inf_utils pten_utils)
cc_library(prepared_operator SRCS prepared_operator.cc DEPS xpu_op_list proto_desc operator device_context lod_tensor selected_rows var_type_traits op_kernel_type data_transform nan_inf_utils pten pten_utils)
ELSE()
cc_library(prepared_operator SRCS prepared_operator.cc DEPS proto_desc operator device_context lod_tensor selected_rows var_type_traits op_kernel_type data_transform nan_inf_utils pten_utils)
cc_library(prepared_operator SRCS prepared_operator.cc DEPS proto_desc operator device_context lod_tensor selected_rows var_type_traits op_kernel_type data_transform nan_inf_utils pten pten_utils)
ENDIF()
cc_library(layer SRCS layer.cc DEPS prepared_operator math_function imperative_flag variable_helper op_registry)
add_subdirectory(jit)
cc_library(amp SRCS amp_auto_cast.cc DEPS layer )
cc_library(tracer SRCS tracer.cc DEPS layer engine program_desc_tracer amp denormal)
cc_library(tracer SRCS tracer.cc DEPS layer engine program_desc_tracer amp denormal garbage_collector)
cc_library(basic_engine SRCS basic_engine.cc DEPS layer gradient_accumulator)
cc_library(engine SRCS basic_engine.cc partial_grad_engine.cc DEPS layer gradient_accumulator)
cc_library(imperative_profiler SRCS profiler.cc DEPS flags)
15 changes: 10 additions & 5 deletions paddle/fluid/imperative/layer.cc
@@ -356,6 +356,8 @@ void VarBase::BumpInplaceVersion() {
MutableVar()->BumpInplaceVersion();
}

pten::KernelContext OpBase::pt_kernel_context_;

void OpBase::SetType(const std::string& type) {
op_ = framework::OpRegistry::CreateOp(type, {}, {}, {}, false);
}
@@ -371,7 +373,8 @@ static void OpBaseRunImpl(const framework::OperatorBase& op,
const NameVarMap<VarType>& outs,
const framework::AttributeMap& attrs,
const framework::AttributeMap& default_attrs,
const platform::Place& place) {
const platform::Place& place,
pten::KernelContext* pt_kernel_context) {
auto* op_kernel = dynamic_cast<const framework::OperatorWithKernel*>(&op);
PADDLE_ENFORCE_NOT_NULL(
op_kernel, platform::errors::PermissionDenied(
@@ -412,8 +415,8 @@ static void OpBaseRunImpl(const framework::OperatorBase& op,
* after the execution of op, but the original input is directly
* overwritten in the previous dynamic graph implementation.
*/
auto prepared_op =
PreparedOp::Prepare(ins, outs, *op_kernel, place, attrs, default_attrs);
auto prepared_op = PreparedOp::Prepare(ins, outs, *op_kernel, place, attrs,
default_attrs, pt_kernel_context);
auto tmp_ins_ptr =
PrepareData<VarType>(*op_kernel, ins, prepared_op.kernel_type());
if (tmp_ins_ptr == nullptr) {
Expand Down Expand Up @@ -441,7 +444,8 @@ void OpBase::Run(const framework::OperatorBase& op,
const framework::AttributeMap& attrs,
const framework::AttributeMap& default_attrs,
const platform::Place& place) {
OpBaseRunImpl<VarBase>(op, ins, outs, attrs, default_attrs, place);
OpBaseRunImpl<VarBase>(op, ins, outs, attrs, default_attrs, place,
&pt_kernel_context_);
}

void OpBase::Run(const framework::OperatorBase& op,
@@ -450,7 +454,8 @@ void OpBase::Run(const framework::OperatorBase& op,
const framework::AttributeMap& attrs,
const framework::AttributeMap& default_attrs,
const platform::Place& place) {
OpBaseRunImpl<VariableWrapper>(op, ins, outs, attrs, default_attrs, place);
OpBaseRunImpl<VariableWrapper>(op, ins, outs, attrs, default_attrs, place,
&pt_kernel_context_);
}

void ClearNoNeedBufferInputs(OpBase* op) {
1 change: 1 addition & 0 deletions paddle/fluid/imperative/layer.h
@@ -36,6 +36,7 @@
#include "paddle/fluid/imperative/variable_wrapper.h"
#include "paddle/fluid/platform/enforce.h"
#include "paddle/fluid/platform/macros.h"
#include "paddle/pten/include/core.h"

namespace paddle {
namespace framework {
6 changes: 6 additions & 0 deletions paddle/fluid/imperative/op_base.h
@@ -25,6 +25,7 @@
#include "paddle/fluid/imperative/type_defs.h"
#include "paddle/fluid/imperative/variable_wrapper.h"
#include "paddle/fluid/platform/place.h"
#include "paddle/pten/include/core.h"

namespace paddle {
namespace imperative {
@@ -183,6 +184,8 @@ class OpBase {
const framework::AttributeMap& default_attrs,
const platform::Place& place);

static pten::KernelContext* GetKernelContext() { return &pt_kernel_context_; }

private:
static const std::string& UnknownOpType() {
static std::string kUnknownOpType{"unknown"};
Expand All @@ -197,6 +200,9 @@ class OpBase {
std::unique_ptr<framework::OperatorBase> op_;
platform::Place place_;
size_t id_{-1UL};
// In order to reduce the compatibility phase
// performance overhead, temporarily cache KernelContext
static pten::KernelContext pt_kernel_context_;
};

class GradOpNode {
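On the imperative (dynamic graph) side, the op_base.h and layer.cc changes above take a different route: all `OpBase` instances share one static `pten::KernelContext`, which `OpBaseRunImpl` passes down to `PreparedOp::Prepare`. A rough sketch of that design, with hypothetical stand-ins (`OpBaseLike`, its `Run`/`RunImpl` bodies, and the local `KernelContext` definition are illustrations only):

```cpp
// Sketch only: one shared context reused by every dynamic-graph op, so
// tracing an op does not allocate a new KernelContext each time.
namespace pten {
class KernelContext {};
}  // namespace pten

class OpBaseLike {
 public:
  static pten::KernelContext* GetKernelContext() { return &pt_kernel_context_; }

  void Run() {
    // Every traced op reuses the same cached context object.
    RunImpl(GetKernelContext());
  }

 private:
  void RunImpl(pten::KernelContext* /*ctx*/) {
    // Prepare(..., ctx) and the kernel launch would happen here.
  }
  // Shared across OpBase instances as a temporary compatibility-phase cache,
  // mirroring the comment added in op_base.h above.
  static pten::KernelContext pt_kernel_context_;
};

pten::KernelContext OpBaseLike::pt_kernel_context_;
```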