-
Notifications
You must be signed in to change notification settings - Fork 5.7k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
[NewIR]Change feed list to variable list && support GPU #55401
Changes from all commits
e762fbd
f316d37
a0a8860
3b0bd59
407fa01
def81b7
f0c12cd
69ebc87
aa40caf
27ffdc0
8220386
bf97e13
2b114ba
7da618d
46a30b8
f2a74ef
8dd9292
5fc76f1
0e9f279
c6eaf29
8248344
a84895a
4eb6b15
4d8707e
4feb7fc
d8be6ec
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -53,7 +53,7 @@ phi::KernelKey GetKernelKey( | |
ir::Operation* op, | ||
const phi::Place& place, | ||
const std::unordered_map<ir::Value, ir::OpResult>& map_value_pair, | ||
const dialect::OpYamlInfoParser* op_info_parser = nullptr) { | ||
std::unique_ptr<dialect::OpYamlInfoParser> op_info_parser = nullptr) { | ||
if (op->name() == "pd.feed") { | ||
// NOTE, for now feed op don't need a kernel, so the data type from Op | ||
// Result the next op use base program datatype | ||
|
@@ -223,11 +223,11 @@ phi::KernelKey GetKernelKey( | |
return res; | ||
} | ||
|
||
std::unique_ptr<ir::Program> PdOpLowerToKernelPass(ir::Program* prog) { | ||
std::unique_ptr<ir::Program> PdOpLowerToKernelPass(ir::Program* prog, | ||
phi::Place place) { | ||
auto program = std::make_unique<ir::Program>(ir::IrContext::Instance()); | ||
|
||
auto block = prog->block(); | ||
phi::Place cpu_place(phi::AllocationType::CPU); | ||
|
||
ir::IrContext* ctx = ir::IrContext::Instance(); | ||
ctx->GetOrRegisterDialect<paddle::dialect::PaddleDialect>(); | ||
|
@@ -244,14 +244,19 @@ std::unique_ptr<ir::Program> PdOpLowerToKernelPass(ir::Program* prog) { | |
VLOG(6) << "op name " << (*it)->name(); | ||
paddle::dialect::OpYamlInfoInterface op_info_interface = | ||
(*it)->dyn_cast<paddle::dialect::OpYamlInfoInterface>(); | ||
OpYamlInfoParser* op_info_parser = nullptr; | ||
std::unique_ptr<OpYamlInfoParser> op_info_parser; | ||
if (op_info_interface) { | ||
op_info_parser = new OpYamlInfoParser(op_info_interface.GetOpInfo()); | ||
op_info_parser.reset(new OpYamlInfoParser(op_info_interface.GetOpInfo())); | ||
} | ||
|
||
std::string kernel_fn_str; | ||
if (op_info_parser != nullptr) { | ||
kernel_fn_str = op_info_parser->OpRuntimeInfo().kernel_func[0]; | ||
} | ||
|
||
auto kernel_key = | ||
GetKernelKey(*it, cpu_place, map_value_pair, op_info_parser); | ||
GetKernelKey(*it, place, map_value_pair, std::move(op_info_parser)); | ||
VLOG(6) << "kernel type " << kernel_key; | ||
// create new Op | ||
|
||
// only for single output | ||
// need update new kernel key layout and data tyep | ||
|
@@ -305,11 +310,6 @@ std::unique_ptr<ir::Program> PdOpLowerToKernelPass(ir::Program* prog) { | |
// constuct input | ||
std::vector<ir::OpResult> vec_inputs; | ||
|
||
std::string kernel_fn_str; | ||
if (op_info_parser != nullptr) { | ||
kernel_fn_str = op_info_parser->OpRuntimeInfo().kernel_func[0]; | ||
} | ||
|
||
if ((*it)->num_operands() > 0) { | ||
for (size_t i = 0; i < (*it)->num_operands(); ++i) { | ||
auto cur_in = (*it)->operand(i); | ||
|
@@ -404,6 +404,35 @@ std::unique_ptr<ir::Program> PdOpLowerToKernelPass(ir::Program* prog) { | |
} | ||
|
||
program->block()->push_back(op); | ||
|
||
if ((*it)->name() == "pd.feed" && platform::is_gpu_place(place)) { | ||
// add shaddow feed op | ||
phi::KernelKey shaddow_key{ | ||
phi::Backend::GPU, | ||
phi::DataLayout::ANY, | ||
TransToPhiDataType( | ||
(*it)->result(0).type().dyn_cast<DenseTensorType>().dtype())}; | ||
std::unordered_map<std::string, ir::Attribute> attr_map{ | ||
{"op_name", ir::StrAttribute::get(ctx, "pd.shaddow_feed")}, | ||
{"kernel_name", ir::StrAttribute::get(ctx, "shaddow_feed")}, | ||
{"kernel_key", dialect::KernelAttribute::get(ctx, shaddow_key)}}; | ||
|
||
auto out_type = paddle::dialect::AllocatedDenseTensorType::get( | ||
ctx, | ||
phi::TransToPhiPlace(shaddow_key.backend()), | ||
(*it)->result(0).type().dyn_cast<dialect::DenseTensorType>()); | ||
|
||
ir::Operation* shaddow_op = | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. 从这里代码看来,我们还是希望有类似 ir::shadow(xxx) C++端API的,这样能简洁很多代码 There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. 是的 |
||
ir::Operation::Create({op->result(0)}, attr_map, {out_type}, op_info); | ||
|
||
map_op_pair[*it] = shaddow_op; | ||
program->block()->push_back(shaddow_op); | ||
if ((*it)->num_results() > 0) { | ||
for (size_t i = 0; i < shaddow_op->num_results(); ++i) { | ||
map_value_pair[(*it)->result(i)] = shaddow_op->result(i); | ||
} | ||
} | ||
} | ||
} | ||
|
||
return program; | ||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -14,11 +14,13 @@ | |
#pragma once | ||
|
||
#include "paddle/ir/core/program.h" | ||
#include "paddle/phi/common/place.h" | ||
|
||
namespace paddle { | ||
namespace dialect { | ||
|
||
std::unique_ptr<ir::Program> PdOpLowerToKernelPass(ir::Program* prog); | ||
std::unique_ptr<ir::Program> PdOpLowerToKernelPass( | ||
ir::Program* prog, phi::Place place = phi::CPUPlace()); | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. 这里是不是可以不加默认的Place值? There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. 是需要的,我后续调整下 |
||
|
||
} // namespace dialect | ||
} // namespace paddle |
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -16,6 +16,7 @@ | |
|
||
#include "paddle/phi/backends/cpu/cpu_context.h" | ||
#include "paddle/phi/core/kernel_registry.h" | ||
#include "paddle/phi/kernels/impl/feed_with_place_impl.h" | ||
|
||
namespace phi { | ||
|
||
|
@@ -26,11 +27,20 @@ void FeedWithPlaceKernel(const Context& ctx, | |
DenseTensor* out) {} | ||
|
||
} // namespace phi | ||
PD_REGISTER_KERNEL(feed_with_place, | ||
|
||
PD_REGISTER_KERNEL( | ||
feed_with_place, CPU, ALL_LAYOUT, phi::FeedWithPlaceKernel, float) {} | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. 原来feed_with_place 注册了4种类型,这里为什么只保留了1种? There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. 当前的验证,只用到了float,是需要补全的 |
||
|
||
PD_REGISTER_KERNEL(shaddow_feed, | ||
CPU, | ||
ALL_LAYOUT, | ||
phi::FeedWithPlaceKernel, | ||
phi::ShaddowFeedKernel, | ||
bool, | ||
float, | ||
int32_t, | ||
int64_t, | ||
double) {} | ||
double, | ||
phi::float16, | ||
phi::bfloat16, | ||
phi::complex64, | ||
phi::complex128) {} |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,33 @@ | ||
// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. | ||
// | ||
// Licensed under the Apache License, Version 2.0 (the "License"); | ||
// you may not use this file except in compliance with the License. | ||
// You may obtain a copy of the License at | ||
// | ||
// http://www.apache.org/licenses/LICENSE-2.0 | ||
// | ||
// Unless required by applicable law or agreed to in writing, software | ||
// distributed under the License is distributed on an "AS IS" BASIS, | ||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
// See the License for the specific language governing permissions and | ||
// limitations under the License. | ||
|
||
#include "paddle/phi/kernels/feed_with_place_kernel.h" | ||
|
||
#include "paddle/phi/backends/gpu/gpu_context.h" | ||
#include "paddle/phi/core/kernel_registry.h" | ||
#include "paddle/phi/kernels/impl/feed_with_place_impl.h" | ||
|
||
PD_REGISTER_KERNEL(shaddow_feed, | ||
GPU, | ||
ALL_LAYOUT, | ||
phi::ShaddowFeedKernel, | ||
bool, | ||
float, | ||
int32_t, | ||
int64_t, | ||
double, | ||
phi::float16, | ||
phi::bfloat16, | ||
phi::complex64, | ||
phi::complex128) {} |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,37 @@ | ||
// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. | ||
// | ||
// Licensed under the Apache License, Version 2.0 (the "License"); | ||
// you may not use this file except in compliance with the License. | ||
// You may obtain a copy of the License at | ||
// | ||
// http://www.apache.org/licenses/LICENSE-2.0 | ||
// | ||
// Unless required by applicable law or agreed to in writing, software | ||
// distributed under the License is distributed on an "AS IS" BASIS, | ||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
// See the License for the specific language governing permissions and | ||
// limitations under the License. | ||
|
||
#include "paddle/phi/kernels/fetch_kernel.h" | ||
|
||
#include "paddle/phi/kernels/impl/fetch_impl.h" | ||
|
||
#include "paddle/phi/backends/gpu/gpu_context.h" | ||
#include "paddle/phi/core/kernel_registry.h" | ||
|
||
PD_REGISTER_KERNEL(fetch, | ||
GPU, | ||
ALL_LAYOUT, | ||
phi::FetchKernel, | ||
float, | ||
double, | ||
int, | ||
int64_t, | ||
uint8_t, | ||
int8_t, | ||
int16_t, | ||
phi::float16, | ||
phi::bfloat16, | ||
phi::dtype::complex<float>, | ||
phi::dtype::complex<double>, | ||
bool) {} |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
我查了下,应该是shadow?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
我改一下