【PTen】Add dot and matmul grad kernel in pten #38713

Merged · 26 commits · Jan 11, 2022

Changes from 25 commits

Commits
76608b5
refactor matmul directory in pten
zyfncg Dec 16, 2021
beecda8
Merge branch 'develop' of https://github.com/PaddlePaddle/Paddle into…
zyfncg Dec 20, 2021
43b3267
Merge branch 'develop' of https://github.com/PaddlePaddle/Paddle into…
zyfncg Dec 21, 2021
63d1681
Merge branch 'develop' of https://github.com/PaddlePaddle/Paddle into…
zyfncg Dec 22, 2021
8cfb873
fix merge conflict
zyfncg Dec 22, 2021
2dffa46
Merge branch 'develop' of https://github.com/PaddlePaddle/Paddle into…
zyfncg Dec 23, 2021
0f3c6d1
Merge branch 'develop' of https://github.com/PaddlePaddle/Paddle into…
zyfncg Dec 23, 2021
2f077b4
Merge branch 'develop' of https://github.com/PaddlePaddle/Paddle into…
zyfncg Dec 23, 2021
d6217d2
add dot_grad kernel
zyfncg Dec 23, 2021
dc01e07
Merge branch 'develop' of https://github.com/PaddlePaddle/Paddle into…
zyfncg Dec 23, 2021
a3658f5
add dot_grad kernel in pten
zyfncg Dec 24, 2021
41c0b85
Merge branch 'develop' of https://github.com/PaddlePaddle/Paddle into…
zyfncg Dec 24, 2021
682059d
add matmul_grad kernel
zyfncg Dec 24, 2021
6a5c6b9
Merge commit 'refs/pull/38441/head' of https://github.com/PaddlePaddl…
zyfncg Dec 24, 2021
a8830fc
Merge commit 'refs/pull/38227/head' of https://github.com/PaddlePaddl…
zyfncg Dec 24, 2021
0846c48
update the code
zyfncg Dec 29, 2021
83d7207
Merge branch 'develop' of https://github.com/PaddlePaddle/Paddle into…
zyfncg Dec 29, 2021
0d0f654
delete useless code in fluid
zyfncg Jan 5, 2022
e2e4af0
Merge branch 'develop' of https://github.com/PaddlePaddle/Paddle into…
zyfncg Jan 5, 2022
854c6fb
fix some bug of running matmul grad kernel
zyfncg Jan 5, 2022
8b60eaa
Merge branch 'develop' of https://github.com/PaddlePaddle/Paddle into…
zyfncg Jan 5, 2022
3c955bd
fix merge conflict
zyfncg Jan 5, 2022
ab5c095
refactor some code
zyfncg Jan 6, 2022
bc8a80b
Merge branch 'develop' of https://github.com/PaddlePaddle/Paddle into…
zyfncg Jan 6, 2022
c113ab5
refactor code
zyfncg Jan 6, 2022
fc13a47
Merge branch 'develop' of https://github.com/PaddlePaddle/Paddle into…
zyfncg Jan 8, 2022
3 changes: 3 additions & 0 deletions cmake/pten_kernel.cmake
@@ -79,6 +79,9 @@ function(kernel_library TARGET)
endif()

list(APPEND all_srcs ${CMAKE_CURRENT_SOURCE_DIR}/${TARGET}.h)
if (EXISTS ${CMAKE_CURRENT_SOURCE_DIR}/impl/${TARGET}_impl.h)
list(APPEND all_srcs ${CMAKE_CURRENT_SOURCE_DIR}/impl/${TARGET}_impl.h)
endif()
list(APPEND all_srcs ${common_srcs})
list(APPEND all_srcs ${cpu_srcs})
list(APPEND all_srcs ${gpu_srcs})
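For context, a rough sketch of the file layout this CMake rule is meant to pick up (the file and function names below are illustrative assumptions, not quoted from the PR): a kernel declares its interface in <kernel>.h and keeps the device-independent template body in impl/<kernel>_impl.h, which kernel_library now also tracks as a source.

// Sketch only — assumed layout, not the exact contents of this PR.
// paddle/pten/kernels/dot_grad_kernel.h
#pragma once
#include "paddle/pten/core/dense_tensor.h"

namespace pten {

// Public declaration shared by the CPU and GPU registrations.
template <typename T, typename Context>
void DotGradKernel(const Context& dev_ctx,
                   const DenseTensor& x,
                   const DenseTensor& y,
                   const DenseTensor& dout,
                   DenseTensor* dx,
                   DenseTensor* dy);

}  // namespace pten

// paddle/pten/kernels/impl/dot_grad_kernel_impl.h would then hold the shared
// template definition of DotGradKernel so both backends can include it; the
// CMake change above makes sure that impl header is scanned as a dependency.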
26 changes: 22 additions & 4 deletions paddle/fluid/framework/operator.cc
@@ -1875,16 +1875,32 @@ void OperatorWithKernel::BuildPtenKernelContext(
// Otherwise,we will create new storage.
for (size_t offset = 0; offset < outs_vector.size(); ++offset) {
if (current_vector_size > start_idx + offset) {
-        experimental::ReMakePtenDenseTensorFromVar(
-            outs_vector[offset], out_def,
-            pt_kernel_context_->MutableOutputAt<pten::DenseTensor>(start_idx +
-                                                                   offset));
        auto* buffer_tensor =
            pt_kernel_context_->MutableOutputAt<pten::DenseTensor>(start_idx +
                                                                   offset);
if (buffer_tensor) {
experimental::ReMakePtenDenseTensorFromVar(outs_vector[offset],
out_def, buffer_tensor);
}
} else {
pt_kernel_context_->EmplaceBackOutputWithoutSetRange(
experimental::MakePtenTensorBaseFromVar(outs_vector[offset],
out_def));
}
}

// Deal with the case that some outputs are NULL when run the kernel.
// For example : the outputs of matmul_grad are dx and dy,
// sometimes dx or dy may be NULL.
if (outs_vector.empty()) {
if (current_vector_size > start_idx) {
pt_kernel_context_->SetOutputWithoutSetRange(start_idx, {nullptr});
} else {
pt_kernel_context_->EmplaceBackOutputWithoutSetRange({nullptr});
}
end_idx = start_idx + 1;
Comment on lines +1899 to +1904

Contributor: Please add some comments here.

Contributor Author: Done
}

pt_kernel_context_->AssignOutputRange(std::make_pair(start_idx, end_idx),
i);
}
@@ -1997,7 +2013,9 @@ void OperatorWithKernel::WriteBackToOutputs(RuntimeContext* ctx) const {
range_pair.first, range_pair.second);

for (size_t j = 0; j < pten_outs.size(); ++j) {
-      experimental::MakeVariableFromPtenTensor(pten_outs[j], outs_vector[j]);
if (pten_outs[j]) {
experimental::MakeVariableFromPtenTensor(pten_outs[j], outs_vector[j]);
}
}
}
}
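Because the kernel context can now carry nullptr output slots (see the matmul_grad comment above), the pten grad kernels themselves have to tolerate missing outputs. A minimal sketch of that pattern, assuming a MatmulGradKernel-style signature (names and parameters are illustrative; the exact signature in this PR may differ):

// Sketch, not the actual Paddle implementation: skip work for outputs the
// caller did not request, mirroring the nullptr slots set up above.
template <typename T, typename Context>
void MatmulGradKernelSketch(const Context& dev_ctx,
                            const pten::DenseTensor& x,
                            const pten::DenseTensor& y,
                            const pten::DenseTensor& dout,
                            bool transpose_x,
                            bool transpose_y,
                            pten::DenseTensor* dx,
                            pten::DenseTensor* dy) {
  if (dx != nullptr) {
    // compute and write d(x) only when it was requested
  }
  if (dy != nullptr) {
    // compute and write d(y) only when it was requested
  }
}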
3 changes: 2 additions & 1 deletion paddle/fluid/framework/pten_utils.cc
@@ -98,7 +98,8 @@ KernelSignatureMap& KernelSignatureMap::Instance() {
for (const auto& pair : OpInfoMap::Instance().map()) {
const auto& op_type = pair.first;
const auto* op_proto = pair.second.proto_;
-    if (pten::KernelFactory::Instance().HasCompatiblePtenKernel(op_type)) {
if (pten::KernelFactory::Instance().HasCompatiblePtenKernel(op_type) &&
op_proto) {
KernelArgsNameMakerByOpProto maker(op_proto);
VLOG(10) << "Register kernel signature for " << op_type;
auto success = kernel_signature_map_->map_
66 changes: 47 additions & 19 deletions paddle/fluid/imperative/prepared_operator.cc
@@ -338,19 +338,41 @@ static void BuildDygraphPtenKernelContext(

for (size_t i = 0; i < output_names.size(); ++i) {
auto& out_def = output_defs.at(i);
-    auto& outs_vector = outs.at(output_names[i]);

size_t start_idx = (i == 0 ? 0 : kernel_ctx->OutputRangeAt(i - 1).second);
-    size_t end_idx = start_idx + outs_vector.size();
auto current_vector_size = kernel_ctx->OutputsSize();

auto iter = outs.find(output_names[i]);
if (iter == outs.end()) {
if (current_vector_size > start_idx) {
kernel_ctx->SetOutputWithoutSetRange(start_idx, {nullptr});
} else {
kernel_ctx->EmplaceBackOutputWithoutSetRange({nullptr});
}
kernel_ctx->AssignOutputRange(std::make_pair(start_idx, start_idx + 1),
i);
continue;
}

auto& outs_vector = iter->second;
size_t end_idx = start_idx + outs_vector.size();

// If the memory needed is less than the current memory allocated, we will
// reuse the current memory by using ReMakePtenDenseTensorFromVar.
// Otherwise,we will create new storage.
for (size_t offset = 0; offset < outs_vector.size(); ++offset) {
if (current_vector_size > start_idx + offset) {
-        experimental::ReMakePtenDenseTensorFromVar(
-            outs_vector[offset]->MutableVar(), out_def,
-            kernel_ctx->MutableOutputAt<pten::DenseTensor>(start_idx + offset));
auto* buffer_tensor =
kernel_ctx->MutableOutputAt<pten::DenseTensor>(start_idx + offset);
if (buffer_tensor) {
experimental::ReMakePtenDenseTensorFromVar(
outs_vector[offset]->MutableVar(), out_def, buffer_tensor);
} else {
kernel_ctx->SetOutputWithoutSetRange(
start_idx + offset,
experimental::MakePtenTensorBaseFromVar(
outs_vector[offset]->MutableVar(), out_def));
}
} else {
kernel_ctx->EmplaceBackOutputWithoutSetRange(
experimental::MakePtenTensorBaseFromVar(
@@ -465,15 +487,18 @@ static void WriteBackToOutputs(
auto& output_names = std::get<2>(pt_kernel_signature.args);

for (size_t i = 0; i < output_names.size(); ++i) {
-    auto& outs_vector = outs.at(output_names[i]);
auto iter = outs.find(output_names[i]);
if (iter != outs.end()) {
auto& outs_vector = iter->second;

-    auto& range_pair = kernel_ctx->OutputRangeAt(i);
-    auto pten_outs = kernel_ctx->MutableOutputBetween<pten::DenseTensor>(
-        range_pair.first, range_pair.second);
auto& range_pair = kernel_ctx->OutputRangeAt(i);
auto pten_outs = kernel_ctx->MutableOutputBetween<pten::DenseTensor>(
range_pair.first, range_pair.second);

-    for (size_t j = 0; j < pten_outs.size(); ++j) {
-      experimental::MakeVariableFromPtenTensor(pten_outs[j],
-                                               outs_vector[j]->MutableVar());
for (size_t j = 0; j < pten_outs.size(); ++j) {
experimental::MakeVariableFromPtenTensor(pten_outs[j],
outs_vector[j]->MutableVar());
}
}
}
}
@@ -530,6 +555,7 @@
template <typename VarType>
static void PreparedOpRunPtImpl(
const framework::OperatorBase& op,
const framework::OpKernelType& kernel_type,
const framework::KernelSignature& pt_kernel_signature,
const pten::Kernel& pt_kernel, pten::KernelContext* pt_kernel_context,
platform::DeviceContext* dev_ctx, const NameVarMap<VarType>& ins,
@@ -560,17 +586,19 @@ static void PreparedOpRunPtImpl(
pt_kernel_context->ClearData();

// TODO(chenweihang): add debug flags later
-  // TODO(chenweihang): deal with complex cases later
if (framework::IsComplexType(kernel_type.data_type_)) {
Contributor: Could the data type of the pten kernel be used here?

Contributor Author: Neither the KernelSignature nor the Kernel structure passed in carries data_type information, so the data type from kernel_type has to be used.
HandleComplexGradToRealGrad<VarType>(outs);
}
}

void PreparedOp::Run(const NameVarMap<VarBase>& ins,
const NameVarMap<VarBase>& outs,
const framework::AttributeMap& attrs,
const framework::AttributeMap& default_attrs) {
if (run_pten_kernel_) {
-    PreparedOpRunPtImpl<VarBase>(op_, pt_kernel_signature_, pt_kernel_,
-                                 pt_kernel_context_, dev_ctx_, ins, outs, attrs,
-                                 default_attrs);
PreparedOpRunPtImpl<VarBase>(op_, kernel_type_, pt_kernel_signature_,
pt_kernel_, pt_kernel_context_, dev_ctx_, ins,
outs, attrs, default_attrs);
} else {
PreparedOpRunImpl<VarBase>(op_, ctx_, kernel_type_, func_, dev_ctx_, ins,
outs, attrs, default_attrs);
@@ -582,9 +610,9 @@ void PreparedOp::Run(const NameVarMap<VariableWrapper>& ins,
const framework::AttributeMap& attrs,
const framework::AttributeMap& default_attrs) {
if (run_pten_kernel_) {
-    PreparedOpRunPtImpl<VariableWrapper>(op_, pt_kernel_signature_, pt_kernel_,
-                                         pt_kernel_context_, dev_ctx_, ins,
-                                         outs, attrs, default_attrs);
PreparedOpRunPtImpl<VariableWrapper>(
op_, kernel_type_, pt_kernel_signature_, pt_kernel_, pt_kernel_context_,
dev_ctx_, ins, outs, attrs, default_attrs);
} else {
PreparedOpRunImpl<VariableWrapper>(op_, ctx_, kernel_type_, func_, dev_ctx_,
ins, outs, attrs, default_attrs);
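The dygraph path above does the same bookkeeping as the static-graph path: an output name that is absent from outs still occupies exactly one nullptr slot with a (start_idx, start_idx + 1) range, so later outputs keep their expected offsets. A self-contained toy sketch of that idea (plain C++ with stand-in types, not Paddle code):

#include <cstddef>
#include <iostream>
#include <map>
#include <string>
#include <utility>
#include <vector>

int main() {
  // Two grad outputs; only Y@GRAD was requested by the caller.
  std::vector<std::string> output_names = {"X@GRAD", "Y@GRAD"};
  std::map<std::string, int> outs = {{"Y@GRAD", 42}};

  std::vector<const int*> slots;                  // stand-in for the context's output list
  std::vector<std::pair<size_t, size_t>> ranges;  // stand-in for AssignOutputRange

  for (size_t i = 0; i < output_names.size(); ++i) {
    size_t start_idx = (i == 0 ? 0 : ranges[i - 1].second);
    auto iter = outs.find(output_names[i]);
    if (iter == outs.end()) {
      slots.push_back(nullptr);  // placeholder slot keeps later offsets stable
    } else {
      slots.push_back(&iter->second);
    }
    // One tensor per name in this toy, so every range spans a single slot.
    ranges.emplace_back(start_idx, start_idx + 1);
  }

  for (size_t i = 0; i < slots.size(); ++i) {
    std::cout << output_names[i] << " -> " << (slots[i] ? "tensor" : "nullptr")
              << " range [" << ranges[i].first << ", " << ranges[i].second
              << ")\n";
  }
  return 0;
}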
2 changes: 1 addition & 1 deletion paddle/fluid/operators/conj_op.h
@@ -39,7 +39,7 @@ class ConjKernel : public framework::OpKernel<T> {
auto pt_out = paddle::experimental::MakePtenDenseTensor(*out);

// call new kernel
-    pten::Conj<T>(dev_ctx, *pt_x.get(), pt_out.get());
pten::ConjKernel<T>(dev_ctx, *pt_x.get(), pt_out.get());
}
};

7 changes: 7 additions & 0 deletions paddle/fluid/operators/dot_op.cc
@@ -117,6 +117,13 @@ class DotGradOp : public framework::OperatorWithKernel {
ctx, framework::GradVarName("Out")),
ctx.GetPlace());
}

framework::KernelSignature GetExpectedPtenKernelArgs(
const framework::ExecutionContext& ctx) const override {
return framework::KernelSignature(
"dot_grad", {"X", "Y", framework::GradVarName("Out")}, {},
{framework::GradVarName("X"), framework::GradVarName("Y")});
}
};

template <typename T>
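A short illustration of what the signature above encodes (a sketch; the pten-side kernel it resolves to is assumed rather than quoted here): the three name lists give the fluid variables that are matched, in order, to the pten kernel's inputs, attributes, and outputs.

// Sketch of how the mapping reads; GradVarName("Out") expands to "Out@GRAD",
// and the attribute list is empty because dot_grad takes no attributes.
framework::KernelSignature sig(
    "dot_grad",
    {"X", "Y", framework::GradVarName("Out")},   // kernel inputs: x, y, dout
    {},                                          // kernel attributes: none
    {framework::GradVarName("X"),                // kernel outputs: dx
     framework::GradVarName("Y")});              //                 dy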