[PIR] Fix bug of invalid enum backend (#59070)
* struct_kernel_has_no_arg_def

* refine code

* refine code
kangguangli authored Nov 17, 2023
1 parent fa0c858 commit 71ffae7
Showing 1 changed file with 25 additions and 21 deletions.
46 changes: 25 additions & 21 deletions paddle/fluid/pir/transforms/pd_op_to_kernel_pass.cc
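
The full diff follows. As a reading aid, here is a minimal, self-contained C++ sketch of the failure mode the change guards against, judging from the commit title and the first commit-message line (struct_kernel_has_no_arg_def): struct kernels register no per-input arg defs, so reading kernel.InputAt(tensor_param_index).backend on them yields an invalid phi::Backend enum value. The new ChooseInputBackend helper consults InputAt only for function kernels and otherwise falls back to kernel_key.backend(). Everything in the sketch besides the names taken from the diff (ChooseInputBackend, Backend, FUNCTION) is an illustrative stand-in, not Paddle code.

// Toy model, not Paddle code: all types and values here are illustrative.
// It mirrors the pattern this commit introduces with ChooseInputBackend:
// only trust per-input metadata for kernels that actually register it
// (function kernels), otherwise fall back to a default backend, instead of
// indexing an arg-def list that struct kernels leave empty.
#include <cstddef>
#include <cstdio>
#include <vector>

enum class Backend { UNDEFINED, CPU, GPU };
enum class KernelKind { FUNCTION, STRUCTURE };

struct InputDef {
  Backend backend;
};

struct Kernel {
  KernelKind kind;
  std::vector<InputDef> input_defs;  // empty for STRUCTURE kernels
};

// Analogue of the new helper: read the per-input backend only when it exists.
Backend ChooseInputBackend(const Kernel& kernel,
                           std::size_t input_index,
                           Backend default_backend) {
  if (kernel.kind == KernelKind::FUNCTION &&
      input_index < kernel.input_defs.size()) {
    return kernel.input_defs[input_index].backend;
  }
  return default_backend;
}

int main() {
  Kernel function_kernel{KernelKind::FUNCTION, {{Backend::GPU}}};
  Kernel struct_kernel{KernelKind::STRUCTURE, {}};  // no input defs registered
  // Function kernel: the registered per-input backend (GPU) is used.
  std::printf("%d\n",
              static_cast<int>(ChooseInputBackend(function_kernel, 0, Backend::CPU)));
  // Struct kernel: falls back to the default instead of an unregistered value.
  std::printf("%d\n",
              static_cast<int>(ChooseInputBackend(struct_kernel, 0, Backend::CPU)));
  return 0;
}

In the toy, the struct kernel's empty input_defs plays the role of the missing arg def; the guard makes the lookup fall back to the default backend rather than return a value that was never registered, which is the design choice the call sites below rely on.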
@@ -140,6 +140,15 @@ static phi::Backend DeriveBackend(const std::string& op,
   return kernel_backend;
 }
 
+static phi::Backend ChooseInputBackend(const phi::Kernel& kernel,
+                                       size_t input_index,
+                                       phi::Backend default_backend) {
+  if (kernel.GetKernelRegisteredType() == phi::KernelRegisteredType::FUNCTION) {
+    return kernel.InputAt(input_index).backend;
+  }
+  return default_backend;
+}
+
 static std::set<std::string> GetInputsByDataOp(pir::Block* block) {
   std::set<std::string> data_op_names;
   for (auto op_item : *block) {
@@ -1236,12 +1245,10 @@ std::vector<pir::Value> BuildInputs(
           auto args_def = kernel.args_def();
           auto input_defs = args_def.input_defs();
 
-          auto dst_backend =
-              DeriveBackend(op_item->name(),
-                            place,
-                            op_info_parser,
-                            kernel.InputAt(tensor_param_index).backend,
-                            i);
+          auto input_backend = ChooseInputBackend(
+              kernel, tensor_param_index, kernel_key.backend());
+          auto dst_backend = DeriveBackend(
+              op_item->name(), place, op_info_parser, input_backend, i);
           VLOG(6) << "Infer kernel backend from input " << i << " of op "
                   << op_item->name();
 
@@ -1295,18 +1302,19 @@ std::vector<pir::Value> BuildInputs(
           auto args_def = kernel.args_def();
           auto input_defs = args_def.input_defs();
 
+          auto input_backend = ChooseInputBackend(
+              kernel, tensor_param_index, kernel_key.backend());
           bool need_trans =
               (place.GetType() != phi::AllocationType::UNDEFINED) &&
               (op_info_parser != nullptr &&
                !op_info_parser->IsTensorAttribute(i)) &&
               (paddle::experimental::NeedTransformPlace(
-                  place, kernel.InputAt(tensor_param_index).backend, {}));
+                  place, input_backend, {}));
           if (need_trans) {
             VLOG(6) << "need trans from " << place << " to "
                     << kernel_key.backend();
             // build memcopy op
-            auto out_place = phi::TransToPhiPlace(
-                kernel.InputAt(tensor_param_index).backend);
+            auto out_place = phi::TransToPhiPlace(input_backend);
             pir::Type out_type;
             if (in_i_type.isa<AllocatedDenseTensorType>()) {
               out_type = AllocatedDenseTensorType::get(
@@ -1359,12 +1367,10 @@ std::vector<pir::Value> BuildInputs(
           auto args_def = kernel.args_def();
           auto input_defs = args_def.input_defs();
 
-          auto dst_backend =
-              DeriveBackend(op_item->name(),
-                            place,
-                            op_info_parser,
-                            kernel.InputAt(tensor_param_index).backend,
-                            i);
+          auto input_backend = ChooseInputBackend(
+              kernel, tensor_param_index, kernel_key.backend());
+          auto dst_backend = DeriveBackend(
+              op_item->name(), place, op_info_parser, input_backend, i);
           VLOG(6) << "Infer kernel backend from input " << i << " of op ";
           bool need_trans =
               (in_place.GetType() != phi::AllocationType::UNDEFINED) &&
@@ -1397,12 +1403,10 @@ std::vector<pir::Value> BuildInputs(
           auto args_def = kernel.args_def();
           auto input_defs = args_def.input_defs();
 
-          auto dst_backend =
-              DeriveBackend(op_item->name(),
-                            place,
-                            op_info_parser,
-                            kernel.InputAt(tensor_param_index).backend,
-                            i);
+          auto input_backend = ChooseInputBackend(
+              kernel, tensor_param_index, kernel_key.backend());
+          auto dst_backend = DeriveBackend(
+              op_item->name(), place, op_info_parser, input_backend, i);
           VLOG(6) << "Infer kernel backend from input " << i << " of op ";
           bool need_trans =
               (in_place.GetType() != phi::AllocationType::UNDEFINED) &&
