Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[PIR] standardize the use of value[-1]. #57349

Merged
merged 1 commit into from
Sep 15, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
24 changes: 12 additions & 12 deletions paddle/fluid/ir_adaptor/translator/op_translator.cc
Original file line number Diff line number Diff line change
Expand Up @@ -134,7 +134,7 @@ inline pir::Operation* InsertCombineOperationForTarget(
std::string combine_op_name(pir::CombineOp::name());
pir::OpInfo op_info = ctx->GetRegisteredOpInfo(combine_op_name);

std::vector<pir::OpResult> src_values;
std::vector<pir::Value> src_values;
std::vector<pir::Type> types_in_vec;
for (const auto& arg_name : args) {
auto defining_info = param_map->at(arg_name);
Expand Down Expand Up @@ -299,7 +299,7 @@ pir::OpResult OpTranscriber::GetAttributeAsInput(
return defining_op->result(0);
}

std::vector<pir::OpResult> OpTranscriber::GenerateOperationInput(
std::vector<pir::Value> OpTranscriber::GenerateOperationInput(
pir::IrContext* ctx,
TranslationContext* param_map,
const OpDesc& op_desc,
Expand All @@ -314,7 +314,7 @@ std::vector<pir::OpResult> OpTranscriber::GenerateOperationInput(

VLOG(10) << "[op:" << op_desc.Type() << "][input] start";

std::vector<pir::OpResult> op_inputs;
std::vector<pir::Value> op_inputs;

for (const auto& info : input_infos) {
if (auto special_handler = this->GetSpecialInputHandlers(info.name)) {
Expand Down Expand Up @@ -779,7 +779,7 @@ struct AssignValueOpTranscriber : public OpTranscriber {

VLOG(10) << "[op assign_value] attribute translation done";

std::vector<pir::OpResult> op_inputs = {};
std::vector<pir::Value> op_inputs = {};

OpOutputMapping arg_to_idx;
OpOutputTypeList op_output_types;
Expand Down Expand Up @@ -904,7 +904,7 @@ struct FeedOpTranscriber : public OpTranscriber {
return attribute_map;
}

std::vector<pir::OpResult> GenerateOperationInput(
std::vector<pir::Value> GenerateOperationInput(
pir::IrContext* ctx,
TranslationContext* param_map,
const OpDesc& op_desc,
Expand Down Expand Up @@ -942,7 +942,7 @@ struct DataOpTranscriber : public FeedOpTranscriber {
};

struct SplitOpTranscriber : public OpTranscriber {
std::vector<pir::OpResult> GenerateOperationInput(
std::vector<pir::Value> GenerateOperationInput(
pir::IrContext* ctx,
TranslationContext* param_map,
const OpDesc& op_desc,
Expand All @@ -953,7 +953,7 @@ struct SplitOpTranscriber : public OpTranscriber {

VLOG(10) << "[op:split][input] start";

std::vector<pir::OpResult> op_inputs;
std::vector<pir::Value> op_inputs;
// process first input
auto x_input_vars = op_desc.Input("X");
IR_ENFORCE(x_input_vars.size() == 1, "x input of split MUST be a tensor");
Expand Down Expand Up @@ -1085,7 +1085,7 @@ struct ShadowOutputOpTranscriber : public OpTranscriber {
pir::Program* program) override {
auto op_info = ctx->GetRegisteredOpInfo(pir::SetParameterOp::name());

std::vector<pir::OpResult> op_inputs;
std::vector<pir::Value> op_inputs;
auto legacy_input_vars = op_desc.Input("x", true);

auto defining_info = (*param_map)[legacy_input_vars[0]];
Expand Down Expand Up @@ -1163,7 +1163,7 @@ struct FillConstant2FullTranscriber : public OpTranscriber {
return op_info;
}

std::vector<pir::OpResult> GenerateOperationInput(
std::vector<pir::Value> GenerateOperationInput(
pir::IrContext* ctx,
TranslationContext* param_map,
const OpDesc& op_desc,
Expand Down Expand Up @@ -1245,14 +1245,14 @@ struct FillConstant2FullWithTensorTranscriber : public OpTranscriber {
return op_info;
}

std::vector<pir::OpResult> GenerateOperationInput(
std::vector<pir::Value> GenerateOperationInput(
pir::IrContext* ctx,
TranslationContext* param_map,
const OpDesc& op_desc,
const std::string& normalized_op_name,
const OpInputInfoList& input_infos,
pir::Program* program) override {
std::vector<pir::OpResult> op_inputs;
std::vector<pir::Value> op_inputs;
if (op_desc.HasInput("ShapeTensor", true) &&
op_desc.Input("ShapeTensor", true).size() > 0) {
auto shape_tensor_vars = op_desc.Input("ShapeTensor", true);
Expand Down Expand Up @@ -1409,7 +1409,7 @@ struct ReduceOpTranscriber : public OpTranscriber {
};

struct ElementwiseTranscriber : public OpTranscriber {
std::vector<pir::OpResult> GenerateOperationInput(
std::vector<pir::Value> GenerateOperationInput(
pir::IrContext* ctx,
TranslationContext* param_map,
const OpDesc& op_desc,
Expand Down
2 changes: 1 addition & 1 deletion paddle/fluid/ir_adaptor/translator/op_translator.h
Original file line number Diff line number Diff line change
Expand Up @@ -66,7 +66,7 @@ struct OpTranscriber {

public:
virtual pir::OpInfo LoopkUpOpInfo(pir::IrContext* ctx, const OpDesc& op_desc);
virtual std::vector<pir::OpResult> GenerateOperationInput(
virtual std::vector<pir::Value> GenerateOperationInput(
pir::IrContext* ctx,
TranslationContext* param_map,
const OpDesc& op_desc,
Expand Down
13 changes: 6 additions & 7 deletions paddle/fluid/pir/transforms/pd_op_to_kernel_pass.cc
Original file line number Diff line number Diff line change
Expand Up @@ -718,7 +718,7 @@ void HandleForSpecialOp(
HandleForIfOp(place, op_item, block, ctx, map_op_pair, map_value_pair);
return;
}
std::vector<pir::OpResult> vec_inputs;
std::vector<pir::Value> vec_inputs;
std::vector<pir::Type> op_output_types;
if (op_item->name() == "builtin.combine") {
// Copy op inputs
Expand Down Expand Up @@ -754,8 +754,7 @@ void HandleForSpecialOp(

if (new_in.type().isa<pir::VectorType>()) {
auto vec_types = new_in.type().dyn_cast<pir::VectorType>().data();
auto index = op_item->attributes()
.at("index")
auto index = op_item->attribute("index")
.dyn_cast<pir::Int32Attribute>()
.data();
op_output_types.push_back(vec_types[index]);
Expand Down Expand Up @@ -899,7 +898,7 @@ std::vector<pir::Type> BuildOpOutputType(pir::Operation* op_item,
return op_output_types;
}

std::vector<pir::OpResult> BuildOpInputList(
std::vector<pir::Value> BuildOpInputList(
pir::Operation* op_item,
const std::string& kernel_fn_str,
const phi::KernelKey& kernel_key,
Expand All @@ -913,7 +912,7 @@ std::vector<pir::OpResult> BuildOpInputList(
return {};
}

std::vector<pir::OpResult> vec_inputs;
std::vector<pir::Value> vec_inputs;

for (size_t i = 0; i < op_item->num_operands(); ++i) {
auto cur_in = op_item->operand_source(i);
Expand Down Expand Up @@ -981,7 +980,7 @@ std::vector<pir::OpResult> BuildOpInputList(
auto pre_define_op = cur_in.GetDefiningOp();

if (pre_define_op->name() == "builtin.combine") {
std::vector<pir::OpResult> inner_inputs;
std::vector<pir::Value> inner_inputs;
std::vector<pir::Type> types_in_vec;
bool is_trans = false;
for (size_t j = 0; j < pre_define_op->num_operands(); ++j) {
Expand Down Expand Up @@ -1155,7 +1154,7 @@ std::string GetKernelFnStr(const OpYamlInfoParser* op_info_parser,
pir::Operation* BuildPhiKernelOp(
const std::string& kernel_fn_str,
const phi::KernelKey& kernel_key,
const std::vector<pir::OpResult>& vec_inputs,
const std::vector<pir::Value>& vec_inputs,
const std::vector<pir::Type>& op_output_types,
pir::Operation* op_item,
pir::Block* block,
Expand Down
8 changes: 1 addition & 7 deletions paddle/pir/core/op_result.cc
Original file line number Diff line number Diff line change
Expand Up @@ -47,12 +47,6 @@ bool OpResult::operator==(const OpResult &other) const {
return impl_ == other.impl_;
}

// OpResult::OpResult(const detail::OpResultImpl *impl) : Value(impl) {}

uint32_t OpResult::GetValidInlineIndex(uint32_t index) {
uint32_t max_inline_index =
pir::detail::OpResultImpl::GetMaxInlineResultIndex();
return index <= max_inline_index ? index : max_inline_index;
}
OpResult::OpResult(const detail::OpResultImpl *impl) : Value(impl) {}

} // namespace pir
8 changes: 2 additions & 6 deletions paddle/pir/core/op_result.h
Original file line number Diff line number Diff line change
Expand Up @@ -28,18 +28,14 @@ class OpResultImpl;
///
class IR_API OpResult : public Value {
public:
OpResult() = default;
OpResult(std::nullptr_t ptr = nullptr) : Value(ptr){}; // NOLINT
Operation *owner() const;
uint32_t GetResultIndex() const;
bool operator==(const OpResult &other) const;
// OpResult(const detail::OpResultImpl *impl); // NOLINT

// This func will remove in next pr.
OpResult(const detail::ValueImpl *impl) : Value(impl) {} // NOLINT

private:
friend Operation;
static uint32_t GetValidInlineIndex(uint32_t index);
OpResult(const detail::OpResultImpl *impl); // NOLINT
// Access classof annd dyn_cast_from.
friend Value;
static bool classof(Value value);
Expand Down
10 changes: 7 additions & 3 deletions paddle/pir/core/operation.cc
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,11 @@

namespace pir {
Operation *Operation::Create(OperationArgument &&argument) {
return Create(argument.inputs,
std::vector<Value> inputs;
for (auto op_result : argument.inputs) {
inputs.emplace_back(op_result);
}
return Create(inputs,
argument.attributes,
argument.output_types,
argument.info,
Expand All @@ -38,7 +42,7 @@ Operation *Operation::Create(OperationArgument &&argument) {
// Allocate the required memory based on the size and number of inputs, outputs,
// and operators, and construct it in the order of: OpOutlineResult,
// OpInlineResult, Operation, operand.
Operation *Operation::Create(const std::vector<pir::OpResult> &inputs,
Operation *Operation::Create(const std::vector<Value> &inputs,
const AttributeMap &attributes,
const std::vector<Type> &output_types,
pir::OpInfo op_info,
Expand Down Expand Up @@ -89,7 +93,7 @@ Operation *Operation::Create(const std::vector<pir::OpResult> &inputs,
IR_THROW("The address of OpOperandImpl must be divisible by 8.");
}
for (size_t idx = 0; idx < num_operands; idx++) {
new (base_ptr) detail::OpOperandImpl(inputs[idx].impl_, op);
new (base_ptr) detail::OpOperandImpl(inputs[idx], op);
base_ptr += sizeof(detail::OpOperandImpl);
}
// 3.4. Construct BlockOperands.
Expand Down
2 changes: 1 addition & 1 deletion paddle/pir/core/operation.h
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,7 @@ class IR_API alignas(8) Operation final {
/// NOTE: Similar to new and delete, the destroy() and the create() need to be
/// used in conjunction.
///
static Operation *Create(const std::vector<pir::OpResult> &inputs,
static Operation *Create(const std::vector<pir::Value> &inputs,
const AttributeMap &attributes,
const std::vector<pir::Type> &output_types,
pir::OpInfo op_info,
Expand Down
28 changes: 27 additions & 1 deletion paddle/pir/core/operation_utils.h
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@

#pragma once

#include <initializer_list>
#include <memory>
#include "paddle/pir/core/attribute.h"
#include "paddle/pir/core/op_info.h"
Expand Down Expand Up @@ -56,12 +57,29 @@ struct OperationArgument {
num_regions(num_regions),
successors(successors) {}

/// Add Operand.
// Will be deleted in the next pr.
void AddOperand(OpResult operand) { inputs.emplace_back(operand); }

void AddInput(Value input) {
inputs.emplace_back(input.dyn_cast<OpResult>());
}

// Will be deleted in the next pr.
template <class InputIt>
void AddOperands(InputIt first, InputIt last);

template <class InputIt>
void AddInputs(InputIt first, InputIt last);

void AddInputs(std::initializer_list<Value> value_list) {
AddInputs(std::begin(value_list), std::end(value_list));
}

template <class ValueContainer>
void AddInputs(const ValueContainer& value_container) {
AddInputs(std::begin(value_container), std::end(value_container));
}

/// Add Output.
void AddOutput(Type type) { output_types.emplace_back(type); }

Expand All @@ -87,6 +105,14 @@ void OperationArgument::AddOperands(InputIt first, InputIt last) {
inputs.emplace_back(*first++);
}
}

template <class InputIt>
void OperationArgument::AddInputs(InputIt first, InputIt last) {
while (first != last) {
AddInput(*first++);
}
}

template <class InputIt>
void OperationArgument::AddOutputs(InputIt first, InputIt last) {
while (first != last) {
Expand Down
14 changes: 4 additions & 10 deletions paddle/pir/core/parser/ir_parser.cc
Original file line number Diff line number Diff line change
Expand Up @@ -216,7 +216,7 @@ Operation* IrParser::ParseOperation() {

OpInfo opinfo = ParseOpInfo();

std::vector<OpResult> inputs = ParseOprandList();
std::vector<Value> inputs = ParseOprandList();

pir::AttributeMap attributeMap = ParseAttributeMap();

Expand Down Expand Up @@ -269,14 +269,14 @@ OpInfo IrParser::ParseOpInfo() {

// OprandList := ValueList
// ValueList := ValueId(,ValueId)*
std::vector<OpResult> IrParser::ParseOprandList() {
std::vector<Value> IrParser::ParseOprandList() {
ConsumeAToken("(");
std::vector<OpResult> inputs{};
std::vector<Value> inputs{};
Token ind_token = ConsumeToken();
while (ind_token.val_ != ")") {
std::string t = "";
if (ind_token.token_type_ == NULL_) {
inputs.push_back(GetNullValue());
inputs.emplace_back();
} else {
t = ind_token.val_;
inputs.push_back(opresultmap[t]);
Expand Down Expand Up @@ -327,12 +327,6 @@ std::vector<Type> IrParser::ParseTypeList() {
return type_vector;
}

OpResult IrParser::GetNullValue() {
Value* v = new Value{nullptr};
OpResult* opresult = static_cast<OpResult*>(v);
return *opresult;
}

Attribute Attribute::Parse(std::istream& is, IrContext* ctx) {
IrParser parser(ctx, is);
return parser.ParseAttribute();
Expand Down
4 changes: 1 addition & 3 deletions paddle/pir/core/parser/ir_parser.h
Original file line number Diff line number Diff line change
Expand Up @@ -51,14 +51,12 @@ class IrParser {

std::vector<std::string> ParseOpResultList();

std::vector<OpResult> ParseOprandList();
std::vector<Value> ParseOprandList();

AttributeMap ParseAttributeMap();

std::vector<Type> ParseTypeList();

OpResult GetNullValue();

Type ParseType();

Attribute ParseAttribute();
Expand Down
Loading