From 856e46722ccb0243819a79633f6be89cf083fd7d Mon Sep 17 00:00:00 2001 From: liujinnan <1823192871@qq.com> Date: Thu, 7 Sep 2023 07:45:08 +0000 Subject: [PATCH 1/8] add constraint_pass for symbolicDim construction. --- .../transforms/shape_optimization_pass.cc | 48 +++++++++++++++++++ .../transforms/shape_optimization_pass.h | 26 ++++++++++ test/cpp/ir/shape_dialect/CMakeLists.txt | 12 +++++ .../ir/shape_dialect/constraint_pass_test.cc | 47 ++++++++++++++++++ 4 files changed, 133 insertions(+) create mode 100644 paddle/ir/dialect/shape/transforms/shape_optimization_pass.cc create mode 100644 paddle/ir/dialect/shape/transforms/shape_optimization_pass.h create mode 100644 test/cpp/ir/shape_dialect/constraint_pass_test.cc diff --git a/paddle/ir/dialect/shape/transforms/shape_optimization_pass.cc b/paddle/ir/dialect/shape/transforms/shape_optimization_pass.cc new file mode 100644 index 0000000000000..7e709be871258 --- /dev/null +++ b/paddle/ir/dialect/shape/transforms/shape_optimization_pass.cc @@ -0,0 +1,48 @@ +// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "paddle/ir/dialect/shape/transforms/shape_optimization_pass.h" + +#include "paddle/ir/core/builtin_op.h" +#include "paddle/ir/core/program.h" +#include "paddle/ir/pass/pass.h" +#include "paddle/ir/pass/pass_registry.h" + +namespace { + +class ShapeOptimizationPass : public ir::Pass { + public: + ShapeOptimizationPass() : ir::Pass("shape_optimization", 0) {} + + void Run(ir::Operation *op) override { + auto module_op = op->dyn_cast(); + IR_ENFORCE(module_op, "ShapeOptimizationPass should run on module op."); + } + + bool CanApplyOn(ir::Operation *op) const override { + return op->name() == "builtin.module" && op->num_regions() > 0; + } +}; + +} // namespace + +namespace ir { + +std::unique_ptr CreateShapeOptimizationPass() { + return std::make_unique(); +} + +} // namespace ir + +REGISTER_PASS(shape_optimization, ShapeOptimizationPass); diff --git a/paddle/ir/dialect/shape/transforms/shape_optimization_pass.h b/paddle/ir/dialect/shape/transforms/shape_optimization_pass.h new file mode 100644 index 0000000000000..cc29531b24107 --- /dev/null +++ b/paddle/ir/dialect/shape/transforms/shape_optimization_pass.h @@ -0,0 +1,26 @@ +// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once + +#include +#include "paddle/ir/core/dll_decl.h" + +namespace ir { + +class Pass; + +IR_API std::unique_ptr CreateShapeOptimizationPass(); + +} // namespace ir diff --git a/test/cpp/ir/shape_dialect/CMakeLists.txt b/test/cpp/ir/shape_dialect/CMakeLists.txt index ae3e3d63d52bd..b4d770fe105fb 100644 --- a/test/cpp/ir/shape_dialect/CMakeLists.txt +++ b/test/cpp/ir/shape_dialect/CMakeLists.txt @@ -6,3 +6,15 @@ cc_test_old( pd_dialect ir gtest) + +set(TEST_DEPS gtest pd_dialect ir) + +if(WITH_DISTRIBUTE) + set(TEST_DEPS ${TEST_DEPS} fleet_executor) +endif() + +cc_test_old(constraint_pass_test SRCS constraint_pass_test.cc DEPS ${TEST_DEPS}) + +set_tests_properties( + constraint_pass_test PROPERTIES ENVIRONMENT + "FLAGS_enable_new_ir_in_executor=true") diff --git a/test/cpp/ir/shape_dialect/constraint_pass_test.cc b/test/cpp/ir/shape_dialect/constraint_pass_test.cc new file mode 100644 index 0000000000000..822e69257f5f3 --- /dev/null +++ b/test/cpp/ir/shape_dialect/constraint_pass_test.cc @@ -0,0 +1,47 @@ +// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include +#include +#include +#include +#include +#include +#include + +#include "paddle/ir/core/builder.h" +#include "paddle/ir/core/builtin_attribute.h" +#include "paddle/ir/core/builtin_dialect.h" +#include "paddle/ir/core/builtin_op.h" +#include "paddle/ir/core/cast_utils.h" +#include "paddle/ir/core/dialect.h" +#include "paddle/ir/core/enforce.h" +#include "paddle/ir/core/ir_context.h" +#include "paddle/ir/core/op_info.h" +#include "paddle/ir/core/parameter.h" +#include "paddle/ir/core/program.h" +#include "paddle/ir/core/value.h" +#include "paddle/ir/dialect/shape/transforms/shape_optimization_pass.h" +#include "paddle/ir/pass/pass.h" +#include "paddle/ir/pass/pass_manager.h" +#include "paddle/phi/core/kernel_registry.h" + +TEST(pattern_rewrite, Patterns) { + ir::IrContext *ctx = ir::IrContext::Instance(); + ir::Program program(ctx); + + ir::PassManager pm(ctx); + pm.AddPass(ir::CreateShapeOptimizationPass()); + CHECK_EQ(pm.Run(&program), true); +} From be338f5f6b0207052572bbe91ab09e5b828d8a60 Mon Sep 17 00:00:00 2001 From: liujinnan <1823192871@qq.com> Date: Thu, 7 Sep 2023 11:19:00 +0000 Subject: [PATCH 2/8] fix error --- paddle/ir/dialect/shape/transforms/shape_optimization_pass.cc | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/paddle/ir/dialect/shape/transforms/shape_optimization_pass.cc b/paddle/ir/dialect/shape/transforms/shape_optimization_pass.cc index 7e709be871258..cd73c75ae630c 100644 --- a/paddle/ir/dialect/shape/transforms/shape_optimization_pass.cc +++ b/paddle/ir/dialect/shape/transforms/shape_optimization_pass.cc @@ -45,4 +45,4 @@ std::unique_ptr CreateShapeOptimizationPass() { } // namespace ir -REGISTER_PASS(shape_optimization, ShapeOptimizationPass); +REGISTER_IR_PASS(shape_optimization, ShapeOptimizationPass); From b5fe5dad64678fa3e52b672f4e41953da5d1e060 Mon Sep 17 00:00:00 2001 From: liujinnan <1823192871@qq.com> Date: Mon, 11 Sep 2023 05:42:35 +0000 Subject: [PATCH 3/8] address build package too large. --- .../transforms/shape_optimization_pass.cc | 24 ++++++------ .../transforms/shape_optimization_pass.h | 6 +-- test/cpp/pir/shape_dialect/CMakeLists.txt | 23 ++++++++++- .../shape_dialect/constraint_pass_test.cc | 38 +++++++++---------- 4 files changed, 55 insertions(+), 36 deletions(-) rename paddle/{ir => pir}/dialect/shape/transforms/shape_optimization_pass.cc (65%) rename paddle/{ir => pir}/dialect/shape/transforms/shape_optimization_pass.h (90%) rename test/cpp/{ir => pir}/shape_dialect/constraint_pass_test.cc (53%) diff --git a/paddle/ir/dialect/shape/transforms/shape_optimization_pass.cc b/paddle/pir/dialect/shape/transforms/shape_optimization_pass.cc similarity index 65% rename from paddle/ir/dialect/shape/transforms/shape_optimization_pass.cc rename to paddle/pir/dialect/shape/transforms/shape_optimization_pass.cc index cd73c75ae630c..0ff6b0c331f4b 100644 --- a/paddle/ir/dialect/shape/transforms/shape_optimization_pass.cc +++ b/paddle/pir/dialect/shape/transforms/shape_optimization_pass.cc @@ -12,37 +12,37 @@ // See the License for the specific language governing permissions and // limitations under the License. -#include "paddle/ir/dialect/shape/transforms/shape_optimization_pass.h" +#include "paddle/pir/dialect/shape/transforms/shape_optimization_pass.h" -#include "paddle/ir/core/builtin_op.h" -#include "paddle/ir/core/program.h" -#include "paddle/ir/pass/pass.h" -#include "paddle/ir/pass/pass_registry.h" +#include "paddle/pir/core/builtin_op.h" +#include "paddle/pir/core/program.h" +#include "paddle/pir/pass/pass.h" +#include "paddle/pir/pass/pass_registry.h" namespace { -class ShapeOptimizationPass : public ir::Pass { +class ShapeOptimizationPass : public pir::Pass { public: - ShapeOptimizationPass() : ir::Pass("shape_optimization", 0) {} + ShapeOptimizationPass() : pir::Pass("shape_optimization", 0) {} - void Run(ir::Operation *op) override { - auto module_op = op->dyn_cast(); + void Run(pir::Operation *op) override { + auto module_op = op->dyn_cast(); IR_ENFORCE(module_op, "ShapeOptimizationPass should run on module op."); } - bool CanApplyOn(ir::Operation *op) const override { + bool CanApplyOn(pir::Operation *op) const override { return op->name() == "builtin.module" && op->num_regions() > 0; } }; } // namespace -namespace ir { +namespace pir { std::unique_ptr CreateShapeOptimizationPass() { return std::make_unique(); } -} // namespace ir +} // namespace pir REGISTER_IR_PASS(shape_optimization, ShapeOptimizationPass); diff --git a/paddle/ir/dialect/shape/transforms/shape_optimization_pass.h b/paddle/pir/dialect/shape/transforms/shape_optimization_pass.h similarity index 90% rename from paddle/ir/dialect/shape/transforms/shape_optimization_pass.h rename to paddle/pir/dialect/shape/transforms/shape_optimization_pass.h index cc29531b24107..43bad532c920d 100644 --- a/paddle/ir/dialect/shape/transforms/shape_optimization_pass.h +++ b/paddle/pir/dialect/shape/transforms/shape_optimization_pass.h @@ -15,12 +15,12 @@ #pragma once #include -#include "paddle/ir/core/dll_decl.h" +#include "paddle/pir/core/dll_decl.h" -namespace ir { +namespace pir { class Pass; IR_API std::unique_ptr CreateShapeOptimizationPass(); -} // namespace ir +} // namespace pir diff --git a/test/cpp/pir/shape_dialect/CMakeLists.txt b/test/cpp/pir/shape_dialect/CMakeLists.txt index 71fbbc3823d22..0dfee7a9151d3 100644 --- a/test/cpp/pir/shape_dialect/CMakeLists.txt +++ b/test/cpp/pir/shape_dialect/CMakeLists.txt @@ -7,13 +7,32 @@ cc_test_old( pir gtest) -set(TEST_DEPS gtest pd_dialect ir) +set(TEST_DEPS gtest pd_op_dialect pir) if(WITH_DISTRIBUTE) set(TEST_DEPS ${TEST_DEPS} fleet_executor) endif() -cc_test_old(constraint_pass_test SRCS constraint_pass_test.cc DEPS ${TEST_DEPS}) +if(WIN32 AND WITH_TESTING) + cc_test_old( + constraint_pass_test + SRCS + constraint_pass_test.cc + DEPS + gtest + ${TEST_DEPS} + ${BRPC_DEPS}) +else() + cc_test_old( + constraint_pass_test + SRCS + constraint_pass_test.cc + DEPS + gtest + ${TEST_DEPS} + ${paddle_lib} + python) +endif() set_tests_properties( constraint_pass_test PROPERTIES ENVIRONMENT diff --git a/test/cpp/ir/shape_dialect/constraint_pass_test.cc b/test/cpp/pir/shape_dialect/constraint_pass_test.cc similarity index 53% rename from test/cpp/ir/shape_dialect/constraint_pass_test.cc rename to test/cpp/pir/shape_dialect/constraint_pass_test.cc index 822e69257f5f3..2c07e110ce98d 100644 --- a/test/cpp/ir/shape_dialect/constraint_pass_test.cc +++ b/test/cpp/pir/shape_dialect/constraint_pass_test.cc @@ -20,28 +20,28 @@ #include #include -#include "paddle/ir/core/builder.h" -#include "paddle/ir/core/builtin_attribute.h" -#include "paddle/ir/core/builtin_dialect.h" -#include "paddle/ir/core/builtin_op.h" -#include "paddle/ir/core/cast_utils.h" -#include "paddle/ir/core/dialect.h" -#include "paddle/ir/core/enforce.h" -#include "paddle/ir/core/ir_context.h" -#include "paddle/ir/core/op_info.h" -#include "paddle/ir/core/parameter.h" -#include "paddle/ir/core/program.h" -#include "paddle/ir/core/value.h" -#include "paddle/ir/dialect/shape/transforms/shape_optimization_pass.h" -#include "paddle/ir/pass/pass.h" -#include "paddle/ir/pass/pass_manager.h" #include "paddle/phi/core/kernel_registry.h" +#include "paddle/pir/core/builder.h" +#include "paddle/pir/core/builtin_attribute.h" +#include "paddle/pir/core/builtin_dialect.h" +#include "paddle/pir/core/builtin_op.h" +#include "paddle/pir/core/cast_utils.h" +#include "paddle/pir/core/dialect.h" +#include "paddle/pir/core/enforce.h" +#include "paddle/pir/core/ir_context.h" +#include "paddle/pir/core/op_info.h" +#include "paddle/pir/core/parameter.h" +#include "paddle/pir/core/program.h" +#include "paddle/pir/core/value.h" +#include "paddle/pir/dialect/shape/transforms/shape_optimization_pass.h" +#include "paddle/pir/pass/pass.h" +#include "paddle/pir/pass/pass_manager.h" TEST(pattern_rewrite, Patterns) { - ir::IrContext *ctx = ir::IrContext::Instance(); - ir::Program program(ctx); + pir::IrContext *ctx = pir::IrContext::Instance(); + pir::Program program(ctx); - ir::PassManager pm(ctx); - pm.AddPass(ir::CreateShapeOptimizationPass()); + pir::PassManager pm(ctx); + pm.AddPass(pir::CreateShapeOptimizationPass()); CHECK_EQ(pm.Run(&program), true); } From 99a381fc31939071c7cbdff34f359c18894fe374 Mon Sep 17 00:00:00 2001 From: liujinnan <1823192871@qq.com> Date: Tue, 12 Sep 2023 13:26:03 +0000 Subject: [PATCH 4/8] temporary commit. --- paddle/pir/dialect/shape/ir/shape_dialect.cc | 7 +- paddle/pir/dialect/shape/ir/shape_op.cc | 69 +++++++++++++++++++ paddle/pir/dialect/shape/ir/shape_op.h | 65 +++++++++++++++++ .../transforms/shape_optimization_pass.cc | 66 +++++++++++++++++- test/cpp/pir/shape_dialect/CMakeLists.txt | 34 +++------ .../pir/shape_dialect/constraint_pass_test.cc | 55 ++++++++++++++- .../cpp/pir/shape_dialect/symbolic_op_test.cc | 69 ++++++++++++++++++- 7 files changed, 331 insertions(+), 34 deletions(-) diff --git a/paddle/pir/dialect/shape/ir/shape_dialect.cc b/paddle/pir/dialect/shape/ir/shape_dialect.cc index 7638e635be631..611d2d95c4810 100644 --- a/paddle/pir/dialect/shape/ir/shape_dialect.cc +++ b/paddle/pir/dialect/shape/ir/shape_dialect.cc @@ -23,7 +23,12 @@ ShapeDialect::ShapeDialect(IrContext *context) } void ShapeDialect::initialize() { - RegisterOps(); + RegisterOps(); } } // namespace dialect diff --git a/paddle/pir/dialect/shape/ir/shape_op.cc b/paddle/pir/dialect/shape/ir/shape_op.cc index be7d378c7fe8a..c40780bfbc1a5 100644 --- a/paddle/pir/dialect/shape/ir/shape_op.cc +++ b/paddle/pir/dialect/shape/ir/shape_op.cc @@ -14,6 +14,7 @@ #include "paddle/pir/dialect/shape/ir/shape_op.h" #include "paddle/pir/core/builtin_attribute.h" +#include "paddle/pir/core/builtin_op.h" #include "paddle/pir/core/builtin_type.h" namespace pir { @@ -191,9 +192,77 @@ std::vector TieProductEqualOp::getRhs() { return res; } +const char *TieShapeOp::attributes_name[attributes_num] = { + SymbolicDim::getSymbolicDimAttrName().c_str()}; // NOLINT + +void TieShapeOp::Build(Builder &builder, + OperationArgument &argument, + const pir::OpResult &input) { + argument.inputs = {input}; +} +void TieShapeOp::Build(Builder &builder, // NOLINT + OperationArgument &argument, // NOLINT + const pir::OpResult &input, + const std::vector &dims) { + argument.inputs = {input}; + for (auto &dim : dims) { + argument.inputs.push_back(dim); + } +} + +pir::Value TieShapeOp::getValue() { return operand_source(0); } + +std::vector TieShapeOp::getShapeDimIndexes() { + std::vector res; + for (uint32_t i = 1; i < num_operands(); i++) { + res.push_back(operand_source(i)); + } + return res; +} + +void FuncOp::Build(Builder &builder, OperationArgument &argument) { + argument.num_regions = 1; +} + +pir::Block *FuncOp::block() { + pir::Region ®ion = (*this)->region(0); + if (region.empty()) region.emplace_back(); + return region.front(); +} + +void TensorDimOp::Build(Builder &builder, + OperationArgument &argument, + const pir::OpResult &source, + const pir::OpResult &index) { + argument.inputs = {source, index}; + argument.output_types.emplace_back( + pir::IndexType::get(pir::IrContext::Instance())); +} + +void TensorDimOp::Build(Builder &builder, + OperationArgument &argument, + const pir::OpResult &source, + int64_t index) { + pir::OpResult indexValue = + builder + .Build( + pir::Int64Attribute::get(pir::IrContext::Instance(), 2), + pir::IndexType::get(pir::IrContext::Instance())) + ->result(0); + argument.inputs = {source, indexValue}; + argument.output_types.emplace_back( + pir::IndexType::get(pir::IrContext::Instance())); +} + +pir::Value TensorDimOp::getSource() { return operand_source(0); } + +pir::Value TensorDimOp::getIndex() { return operand_source(1); } } // namespace dialect } // namespace pir IR_DEFINE_EXPLICIT_TYPE_ID(pir::dialect::SymbolicDim) IR_DEFINE_EXPLICIT_TYPE_ID(pir::dialect::DimOp) IR_DEFINE_EXPLICIT_TYPE_ID(pir::dialect::TieProductEqualOp) +IR_DEFINE_EXPLICIT_TYPE_ID(pir::dialect::TieShapeOp) +IR_DEFINE_EXPLICIT_TYPE_ID(pir::dialect::FuncOp) +IR_DEFINE_EXPLICIT_TYPE_ID(pir::dialect::TensorDimOp) diff --git a/paddle/pir/dialect/shape/ir/shape_op.h b/paddle/pir/dialect/shape/ir/shape_op.h index 4df90213cd616..24f8c165abe31 100644 --- a/paddle/pir/dialect/shape/ir/shape_op.h +++ b/paddle/pir/dialect/shape/ir/shape_op.h @@ -54,6 +54,10 @@ class IR_API SymbolicDim : public Op { bool isDynamic(); bool merge(SymbolicDim other); + static const std::string getSymbolicDimAttrName() { + return "SymbolicDimAttr"; + } + void Verify() {} }; @@ -93,9 +97,70 @@ class IR_API TieProductEqualOp : public Op { void Verify() {} }; +class IR_API TieShapeOp : public Op { + public: + using Op::Op; + static const char *name() { return "shape.tie_shape"; } + + static constexpr uint32_t attributes_num = 1; + static const char *attributes_name[attributes_num]; + + static void Build(Builder &builder, // NOLINT + OperationArgument &argument, // NOLINT + const pir::OpResult &input); + + static void Build(Builder &builder, // NOLINT + OperationArgument &argument, // NOLINT + const pir::OpResult &input, + const std::vector &dims); + + pir::Value getValue(); + std::vector getShapeDimIndexes(); + void Verify() {} +}; + +class IR_API FuncOp : public Op { + public: + using Op::Op; + static const char *name() { return "shape.func"; } + + static constexpr const char **attributes_name = nullptr; + static constexpr uint32_t attributes_num = 0; + + static void Build(Builder &builder, // NOLINT + OperationArgument &argument); // NOLINT + pir::Block *block(); + void Verify() {} +}; + +class IR_API TensorDimOp : public Op { + public: + using Op::Op; + static const char *name() { return "shape.tensor_dim"; } + + static constexpr const char **attributes_name = nullptr; + static constexpr uint32_t attributes_num = 0; + + static void Build(Builder &builder, // NOLINT + OperationArgument &argument, // NOLINT + const pir::OpResult &source, + const pir::OpResult &index); + static void Build(Builder &builder, // NOLINT + OperationArgument &argument, // NOLINT + const pir::OpResult &source, + int64_t index); + pir::Value getIndex(); + pir::Value getSource(); + pir::OpResult out() { return result(0); } + void Verify() {} +}; + } // namespace dialect } // namespace pir IR_EXPORT_DECLARE_EXPLICIT_TYPE_ID(pir::dialect::SymbolicDim); IR_EXPORT_DECLARE_EXPLICIT_TYPE_ID(pir::dialect::DimOp); IR_EXPORT_DECLARE_EXPLICIT_TYPE_ID(pir::dialect::TieProductEqualOp); +IR_EXPORT_DECLARE_EXPLICIT_TYPE_ID(pir::dialect::TieShapeOp); +IR_EXPORT_DECLARE_EXPLICIT_TYPE_ID(pir::dialect::FuncOp); +IR_EXPORT_DECLARE_EXPLICIT_TYPE_ID(pir::dialect::TensorDimOp); diff --git a/paddle/pir/dialect/shape/transforms/shape_optimization_pass.cc b/paddle/pir/dialect/shape/transforms/shape_optimization_pass.cc index 0ff6b0c331f4b..7787666c52147 100644 --- a/paddle/pir/dialect/shape/transforms/shape_optimization_pass.cc +++ b/paddle/pir/dialect/shape/transforms/shape_optimization_pass.cc @@ -13,6 +13,8 @@ // limitations under the License. #include "paddle/pir/dialect/shape/transforms/shape_optimization_pass.h" +#include "paddle/fluid/pir/dialect/operator/ir/op_type.h" +#include "paddle/pir/dialect/shape/ir/shape_op.h" #include "paddle/pir/core/builtin_op.h" #include "paddle/pir/core/program.h" @@ -21,16 +23,76 @@ namespace { +bool insertTieShapeOnValue(pir::OpResult value, + pir::Builder& builder) { // NOLINT + // Only insert tie_shape ops for non-zero ranked tensor type + auto ty = value.type().dyn_cast(); + + if (!ty || ty.dims().size() == 0) return true; + std::vector dimSizes; + for (int64_t dim = 0, rank = ty.dims().size(); dim < rank; ++dim) { + auto dimOp = builder.Build(value, dim); + dimSizes.push_back(dimOp.out()); + } + builder.Build(value, dimSizes); + return true; +} + +bool insertTieShapeOnRegion(pir::Region* region); + +bool insertTieShapeOnOperation(pir::Operation* op, + pir::Builder& builder) { // NOLINT + if (op->isa()) return true; + // TODO(liujinnan): skip the specialized Ops. + + for (size_t i = 0; i < op->num_regions(); ++i) { + if (!insertTieShapeOnRegion(&(op->region(i)))) return false; + } + builder.SetInsertionPointAfter(op); + for (pir::OpResult v : op->results()) { + if (!insertTieShapeOnValue(v, builder)) return false; + } + + return true; +} + +bool insertTieShapeOnBlock(pir::Block* block) { + pir::Builder builder = + pir::Builder(pir::IrContext::Instance(), block, block->begin()); + // TODO(liujinnan): mapping block arguments + + std::vector op_list; + for (pir::Operation* op : *block) op_list.push_back(op); + for (pir::Operation* op : op_list) { + if (!insertTieShapeOnOperation(op, builder)) return false; + } + return true; +} + +bool insertTieShapeOnRegion(pir::Region* region) { + for (pir::Block* block : *region) { + if (!insertTieShapeOnBlock(block)) return false; + } + return true; +} + +bool materializeShapeComputation(pir::ModuleOp m) { + if (!insertTieShapeOnRegion(&(m->region(0)))) return false; + // TODO(liujinnan): add rewitter pattern for reifyInferShape. + return true; +} + class ShapeOptimizationPass : public pir::Pass { public: ShapeOptimizationPass() : pir::Pass("shape_optimization", 0) {} - void Run(pir::Operation *op) override { + void Run(pir::Operation* op) override { auto module_op = op->dyn_cast(); IR_ENFORCE(module_op, "ShapeOptimizationPass should run on module op."); + materializeShapeComputation(module_op); } - bool CanApplyOn(pir::Operation *op) const override { + bool CanApplyOn(pir::Operation* op) const override { return op->name() == "builtin.module" && op->num_regions() > 0; } }; diff --git a/test/cpp/pir/shape_dialect/CMakeLists.txt b/test/cpp/pir/shape_dialect/CMakeLists.txt index 0dfee7a9151d3..d5fe787de4a80 100644 --- a/test/cpp/pir/shape_dialect/CMakeLists.txt +++ b/test/cpp/pir/shape_dialect/CMakeLists.txt @@ -7,32 +7,14 @@ cc_test_old( pir gtest) -set(TEST_DEPS gtest pd_op_dialect pir) - -if(WITH_DISTRIBUTE) - set(TEST_DEPS ${TEST_DEPS} fleet_executor) -endif() - -if(WIN32 AND WITH_TESTING) - cc_test_old( - constraint_pass_test - SRCS - constraint_pass_test.cc - DEPS - gtest - ${TEST_DEPS} - ${BRPC_DEPS}) -else() - cc_test_old( - constraint_pass_test - SRCS - constraint_pass_test.cc - DEPS - gtest - ${TEST_DEPS} - ${paddle_lib} - python) -endif() +cc_test_old( + constraint_pass_test + SRCS + constraint_pass_test.cc + DEPS + gtest + pd_op_dialect + pir) set_tests_properties( constraint_pass_test PROPERTIES ENVIRONMENT diff --git a/test/cpp/pir/shape_dialect/constraint_pass_test.cc b/test/cpp/pir/shape_dialect/constraint_pass_test.cc index 2c07e110ce98d..c99d7493af09a 100644 --- a/test/cpp/pir/shape_dialect/constraint_pass_test.cc +++ b/test/cpp/pir/shape_dialect/constraint_pass_test.cc @@ -20,11 +20,14 @@ #include #include +#include "paddle/fluid/pir/dialect/operator/ir/op_dialect.h" +#include "paddle/fluid/pir/dialect/operator/ir/op_type.h" #include "paddle/phi/core/kernel_registry.h" #include "paddle/pir/core/builder.h" #include "paddle/pir/core/builtin_attribute.h" #include "paddle/pir/core/builtin_dialect.h" #include "paddle/pir/core/builtin_op.h" +#include "paddle/pir/core/builtin_type.h" #include "paddle/pir/core/cast_utils.h" #include "paddle/pir/core/dialect.h" #include "paddle/pir/core/enforce.h" @@ -33,15 +36,61 @@ #include "paddle/pir/core/parameter.h" #include "paddle/pir/core/program.h" #include "paddle/pir/core/value.h" +#include "paddle/pir/dialect/shape/ir/shape_dialect.h" #include "paddle/pir/dialect/shape/transforms/shape_optimization_pass.h" #include "paddle/pir/pass/pass.h" #include "paddle/pir/pass/pass_manager.h" -TEST(pattern_rewrite, Patterns) { +pir::AttributeMap CreateAttributeMap( + const std::vector &attribute_names, + const std::vector &attributes) { pir::IrContext *ctx = pir::IrContext::Instance(); - pir::Program program(ctx); + pir::AttributeMap attr_map; + for (size_t i = 0; i < attribute_names.size(); i++) { + pir::Attribute attr_value = pir::StrAttribute::get(ctx, attributes[i]); + attr_map.insert( + std::pair(attribute_names[i], attr_value)); + } + return attr_map; +} + +pir::Operation *CreateDenseTensorOp( + pir::IrContext *ctx, + const phi::DDim &dims, + const std::vector &attribute_names, + const std::vector &attributes) { + std::vector op_inputs = {}; + pir::Type fp32_dtype = pir::Float32Type::get(ctx); + phi::DataLayout data_layout = phi::DataLayout::NCHW; + phi::LoD lod = {{0, 1, 2}}; + size_t offset = 0; + std::vector op_output_types = { + paddle::dialect::DenseTensorType::get( + ctx, fp32_dtype, dims, data_layout, lod, offset)}; + pir::Operation *op = + pir::Operation::Create(op_inputs, + CreateAttributeMap(attribute_names, attributes), + op_output_types, + pir::OpInfo()); + return op; +} +TEST(constraint_pass, materialize_shape) { + pir::IrContext *ctx = pir::IrContext::Instance(); + pir::Program program(ctx); pir::PassManager pm(ctx); + ctx->GetOrRegisterDialect(); + ctx->GetOrRegisterDialect(); + pir::Operation *op0 = + CreateDenseTensorOp(ctx, {-100000, 2}, {"op0_attr"}, {"op0_name"}); + program.block()->push_back(op0); + pir::Operation *op1 = + CreateDenseTensorOp(ctx, {-100000, 2, 2}, {"op1_attr"}, {"op1_name"}); + program.block()->push_back(op1); + + EXPECT_EQ(program.block()->size(), static_cast(2)); pm.AddPass(pir::CreateShapeOptimizationPass()); - CHECK_EQ(pm.Run(&program), true); + EXPECT_TRUE(pm.Run(&program)); + // 5 ConstantOp + 5 TensorDim + 2 TieShape + op0 + op1 == 14 Ops. + EXPECT_EQ(program.block()->size(), static_cast(14)); } diff --git a/test/cpp/pir/shape_dialect/symbolic_op_test.cc b/test/cpp/pir/shape_dialect/symbolic_op_test.cc index f916650376fbe..2c793c6e57c38 100644 --- a/test/cpp/pir/shape_dialect/symbolic_op_test.cc +++ b/test/cpp/pir/shape_dialect/symbolic_op_test.cc @@ -26,6 +26,40 @@ #include "paddle/pir/dialect/shape/ir/shape_op.h" #include "paddle/pir/dialect/shape/utils/shape_utils.h" +pir::AttributeMap CreateAttributeMap( + const std::vector &attribute_names, + const std::vector &attributes) { + pir::IrContext *ctx = pir::IrContext::Instance(); + pir::AttributeMap attr_map; + for (size_t i = 0; i < attribute_names.size(); i++) { + pir::Attribute attr_value = pir::StrAttribute::get(ctx, attributes[i]); + attr_map.insert( + std::pair(attribute_names[i], attr_value)); + } + return attr_map; +} + +pir::Operation *CreateDenseTensorOp( + pir::IrContext *ctx, + const phi::DDim &dims, + const std::vector &attribute_names, + const std::vector &attributes) { + std::vector op_inputs = {}; + pir::Type fp32_dtype = pir::Float32Type::get(ctx); + phi::DataLayout data_layout = phi::DataLayout::NCHW; + phi::LoD lod = {{0, 1, 2}}; + size_t offset = 0; + std::vector op_output_types = { + paddle::dialect::DenseTensorType::get( + ctx, fp32_dtype, dims, data_layout, lod, offset)}; + pir::Operation *op = + pir::Operation::Create(op_inputs, + CreateAttributeMap(attribute_names, attributes), + op_output_types, + pir::OpInfo()); + return op; +} + TEST(assist_struct_test, symbolic_dim) { pir::IrContext *ctx = pir::IrContext::Instance(); pir::Program program(ctx); @@ -318,7 +352,7 @@ TEST(assist_struct_test, symbolic_dim_mgr_complex) { symDimProductRhs_)); } -TEST(assist_struct_test, dim) { +TEST(shape_op, dim) { pir::IrContext *ctx = pir::IrContext::Instance(); pir::Program program(ctx); ctx->GetOrRegisterDialect(); @@ -333,7 +367,7 @@ TEST(assist_struct_test, dim) { EXPECT_EQ(res.type(), pir::IndexType::get(ctx)); } -TEST(assist_struct_test, tie_product_equal) { +TEST(shape_op, tie_product_equal) { pir::IrContext *ctx = pir::IrContext::Instance(); pir::Program program(ctx); ctx->GetOrRegisterDialect(); @@ -369,3 +403,34 @@ TEST(assist_struct_test, tie_product_equal) { EXPECT_EQ(lhs, lhs_ref); EXPECT_EQ(rhs, rhs_ref); } + +TEST(shape_op, tensor_dim) { + pir::IrContext *ctx = pir::IrContext::Instance(); + pir::Program program(ctx); + ctx->GetOrRegisterDialect(); + pir::Builder builder = pir::Builder(ctx, program.block()); + + pir::Operation *op = + CreateDenseTensorOp(ctx, {-100000, 2}, {"op_attr"}, {"op_name"}); + pir::OpResult resDenseTensorValue = op->result(0); + + pir::dialect::TensorDimOp tensorDimOp0 = + builder.Build(resDenseTensorValue, 0); + pir::OpResult res0 = tensorDimOp0.out(); + + pir::OpResult indexValue = + builder + .Build( + pir::Int64Attribute::get(pir::IrContext::Instance(), 1), + pir::IndexType::get(pir::IrContext::Instance())) + ->result(0); + pir::dialect::TensorDimOp tensorDimOp1 = + builder.Build(resDenseTensorValue, indexValue); + pir::OpResult res1 = tensorDimOp1.out(); + + EXPECT_EQ(res0.type(), pir::IndexType::get(ctx)); + EXPECT_EQ(res1.type(), pir::IndexType::get(ctx)); + EXPECT_EQ(tensorDimOp0.getSource(), resDenseTensorValue); + EXPECT_EQ(tensorDimOp1.getSource(), resDenseTensorValue); + EXPECT_EQ(tensorDimOp1.getIndex(), indexValue); +} From 67a9399086eea4dbfbda327a15d019115231292b Mon Sep 17 00:00:00 2001 From: liujinnan <1823192871@qq.com> Date: Thu, 14 Sep 2023 05:19:44 +0000 Subject: [PATCH 5/8] add UT. --- .../transforms/shape_optimization_pass.cc | 1 - .../cpp/pir/shape_dialect/symbolic_op_test.cc | 31 +++++++++++++++++++ 2 files changed, 31 insertions(+), 1 deletion(-) diff --git a/paddle/pir/dialect/shape/transforms/shape_optimization_pass.cc b/paddle/pir/dialect/shape/transforms/shape_optimization_pass.cc index 7787666c52147..53920cf237acb 100644 --- a/paddle/pir/dialect/shape/transforms/shape_optimization_pass.cc +++ b/paddle/pir/dialect/shape/transforms/shape_optimization_pass.cc @@ -25,7 +25,6 @@ namespace { bool insertTieShapeOnValue(pir::OpResult value, pir::Builder& builder) { // NOLINT - // Only insert tie_shape ops for non-zero ranked tensor type auto ty = value.type().dyn_cast(); if (!ty || ty.dims().size() == 0) return true; diff --git a/test/cpp/pir/shape_dialect/symbolic_op_test.cc b/test/cpp/pir/shape_dialect/symbolic_op_test.cc index 87f4623f811ce..87bc34d4657f1 100644 --- a/test/cpp/pir/shape_dialect/symbolic_op_test.cc +++ b/test/cpp/pir/shape_dialect/symbolic_op_test.cc @@ -586,3 +586,34 @@ TEST(assist_struct_test, shape_analysis) { EXPECT_TRUE(shapeAnalysis.isShapeEqual(value1, value2)); EXPECT_FALSE(shapeAnalysis.isShapeEqual(value1, value5)); } + +TEST(shape_op, tensor_dim) { + pir::IrContext *ctx = pir::IrContext::Instance(); + pir::Program program(ctx); + ctx->GetOrRegisterDialect(); + pir::Builder builder = pir::Builder(ctx, program.block()); + + pir::Operation *op = + CreateDenseTensorOp(ctx, {-100000, 2}, {"op_attr"}, {"op_name"}); + pir::OpResult resDenseTensorValue = op->result(0); + + pir::dialect::TensorDimOp tensorDimOp0 = + builder.Build(resDenseTensorValue, 0); + pir::OpResult res0 = tensorDimOp0.out(); + + pir::OpResult indexValue = + builder + .Build( + pir::Int64Attribute::get(pir::IrContext::Instance(), 1), + pir::IndexType::get(pir::IrContext::Instance())) + ->result(0); + pir::dialect::TensorDimOp tensorDimOp1 = + builder.Build(resDenseTensorValue, indexValue); + pir::OpResult res1 = tensorDimOp1.out(); + + EXPECT_EQ(res0.type(), pir::IndexType::get(ctx)); + EXPECT_EQ(res1.type(), pir::IndexType::get(ctx)); + EXPECT_EQ(tensorDimOp0.getSource(), resDenseTensorValue); + EXPECT_EQ(tensorDimOp1.getSource(), resDenseTensorValue); + EXPECT_EQ(tensorDimOp1.getIndex(), indexValue); +} From 458a1e4df60b80ac21c8128e359dd218b1becfd2 Mon Sep 17 00:00:00 2001 From: liujinnan <1823192871@qq.com> Date: Thu, 14 Sep 2023 06:15:31 +0000 Subject: [PATCH 6/8] change format. --- paddle/pir/dialect/shape/ir/shape_op.cc | 10 +- paddle/pir/dialect/shape/ir/shape_op.h | 6 +- .../transforms/shape_optimization_pass.cc | 20 +-- paddle/pir/dialect/shape/utils/shape_utils.cc | 148 +++++++++--------- paddle/pir/dialect/shape/utils/shape_utils.h | 52 +++--- .../cpp/pir/shape_dialect/symbolic_op_test.cc | 122 +++++++-------- 6 files changed, 179 insertions(+), 179 deletions(-) diff --git a/paddle/pir/dialect/shape/ir/shape_op.cc b/paddle/pir/dialect/shape/ir/shape_op.cc index 85da70160d6f1..0f560c916e0b6 100644 --- a/paddle/pir/dialect/shape/ir/shape_op.cc +++ b/paddle/pir/dialect/shape/ir/shape_op.cc @@ -105,15 +105,15 @@ void SymbolicDim::updateKnownNonSizeZero(bool attrValue) { pir::BoolAttribute::get(pir::IrContext::Instance(), attrValue)); } -bool SymbolicDim::isDynamic() { +bool SymbolicDim::IsDynamic() { return getValue() == ShapedTypeInterface::kDynamic; } -bool SymbolicDim::merge(SymbolicDim other) { - if (!isDynamic() && !other.isDynamic() && getValue() != other.getValue()) +bool SymbolicDim::Merge(SymbolicDim other) { + if (!IsDynamic() && !other.IsDynamic() && getValue() != other.getValue()) return false; - if (isDynamic() && !other.isDynamic()) updateValue(other.getValue()); - if (!isDynamic() && other.isDynamic()) other.updateValue(getValue()); + if (IsDynamic() && !other.IsDynamic()) updateValue(other.getValue()); + if (!IsDynamic() && other.IsDynamic()) other.updateValue(getValue()); bool knownNonNegativeFlag = getKnownNonNegative() || other.getKnownNonNegative(); diff --git a/paddle/pir/dialect/shape/ir/shape_op.h b/paddle/pir/dialect/shape/ir/shape_op.h index b1e1a381a6fa9..d111e81c4989d 100644 --- a/paddle/pir/dialect/shape/ir/shape_op.h +++ b/paddle/pir/dialect/shape/ir/shape_op.h @@ -51,11 +51,11 @@ class IR_API SymbolicDim : public Op { void updateKnownNonSizeOne(bool attrValue); void updateKnownNonSizeZero(bool attrValue); - bool isDynamic(); - bool merge(SymbolicDim other); + bool IsDynamic(); + bool Merge(SymbolicDim other); static const std::string getSymbolicDimAttrName() { - return "SymbolicDimAttr"; + return "kSymbolicDimAttr"; } void Verify() {} diff --git a/paddle/pir/dialect/shape/transforms/shape_optimization_pass.cc b/paddle/pir/dialect/shape/transforms/shape_optimization_pass.cc index 53920cf237acb..a4922c69ed0c5 100644 --- a/paddle/pir/dialect/shape/transforms/shape_optimization_pass.cc +++ b/paddle/pir/dialect/shape/transforms/shape_optimization_pass.cc @@ -23,7 +23,7 @@ namespace { -bool insertTieShapeOnValue(pir::OpResult value, +bool InsertTieShapeOnValue(pir::OpResult value, pir::Builder& builder) { // NOLINT auto ty = value.type().dyn_cast(); @@ -37,19 +37,19 @@ bool insertTieShapeOnValue(pir::OpResult value, return true; } -bool insertTieShapeOnRegion(pir::Region* region); +bool InsertTieShapeOnRegion(pir::Region* region); -bool insertTieShapeOnOperation(pir::Operation* op, +bool InsertTieShapeOnOperation(pir::Operation* op, pir::Builder& builder) { // NOLINT if (op->isa()) return true; // TODO(liujinnan): skip the specialized Ops. for (size_t i = 0; i < op->num_regions(); ++i) { - if (!insertTieShapeOnRegion(&(op->region(i)))) return false; + if (!InsertTieShapeOnRegion(&(op->region(i)))) return false; } builder.SetInsertionPointAfter(op); for (pir::OpResult v : op->results()) { - if (!insertTieShapeOnValue(v, builder)) return false; + if (!InsertTieShapeOnValue(v, builder)) return false; } return true; @@ -63,20 +63,20 @@ bool insertTieShapeOnBlock(pir::Block* block) { std::vector op_list; for (pir::Operation* op : *block) op_list.push_back(op); for (pir::Operation* op : op_list) { - if (!insertTieShapeOnOperation(op, builder)) return false; + if (!InsertTieShapeOnOperation(op, builder)) return false; } return true; } -bool insertTieShapeOnRegion(pir::Region* region) { +bool InsertTieShapeOnRegion(pir::Region* region) { for (pir::Block* block : *region) { if (!insertTieShapeOnBlock(block)) return false; } return true; } -bool materializeShapeComputation(pir::ModuleOp m) { - if (!insertTieShapeOnRegion(&(m->region(0)))) return false; +bool MaterializeShapeComputation(pir::ModuleOp m) { + if (!InsertTieShapeOnRegion(&(m->region(0)))) return false; // TODO(liujinnan): add rewitter pattern for reifyInferShape. return true; } @@ -88,7 +88,7 @@ class ShapeOptimizationPass : public pir::Pass { void Run(pir::Operation* op) override { auto module_op = op->dyn_cast(); IR_ENFORCE(module_op, "ShapeOptimizationPass should run on module op."); - materializeShapeComputation(module_op); + MaterializeShapeComputation(module_op); } bool CanApplyOn(pir::Operation* op) const override { diff --git a/paddle/pir/dialect/shape/utils/shape_utils.cc b/paddle/pir/dialect/shape/utils/shape_utils.cc index 1de3f03620961..99ec8f57bc2c2 100644 --- a/paddle/pir/dialect/shape/utils/shape_utils.cc +++ b/paddle/pir/dialect/shape/utils/shape_utils.cc @@ -17,7 +17,7 @@ #include "paddle/fluid/pir/dialect/operator/ir/op_type.h" namespace pir { -bool compareSymbolicDimNames(const std::string& lhs, const std::string& rhs) { +bool CompareSymbolicDimNames(const std::string& lhs, const std::string& rhs) { if (lhs.size() < 1 || (lhs[0] != 'S' && lhs[0] != 'C')) return lhs < rhs; if (rhs.size() < 1 || (rhs[0] != 'S' && rhs[0] != 'C')) return lhs < rhs; int64_t lhsIdx = 0, rhsIdx = 0; @@ -30,14 +30,14 @@ bool compareSymbolicDimNames(const std::string& lhs, const std::string& rhs) { return (lhs[0] < rhs[0]) || (lhs[0] == rhs[0] && lhsIdx < rhsIdx); } -bool compareSymbolicDimProduct(SymbolicDimProduct& lhs, // NOLINT +bool CompareSymbolicDimProduct(SymbolicDimProduct& lhs, // NOLINT SymbolicDimProduct& rhs) { // NOLINT if (lhs.symbols.size() < rhs.symbols.size()) return true; if (lhs.symbols.size() == rhs.symbols.size()) { for (size_t idx = 0; idx < lhs.symbols.size(); ++idx) { const std::string lhsName = lhs.symbols[idx].getSymName(); const std::string rhsName = rhs.symbols[idx].getSymName(); - if (compareSymbolicDimNames(lhsName, rhsName)) return true; + if (CompareSymbolicDimNames(lhsName, rhsName)) return true; if (lhsName != rhsName) return false; } } @@ -60,7 +60,7 @@ const std::string SymbolTable::insert(Operation* symbol) { return name; } -bool SymbolicDimMgr::load() { +bool SymbolicDimMgr::Load() { auto funcOp = symbolTable_.getOp()->dyn_cast(); assert(funcOp); for (auto op_ : *(funcOp.block())) { @@ -70,14 +70,14 @@ bool SymbolicDimMgr::load() { symbolNameSet_.insert(op.getSymName()); } } - return loadShapeConstraintGraph(); + return LoadShapeConstraintGraph(); } -bool SymbolicDimMgr::loadShapeConstraintGraph() { +bool SymbolicDimMgr::LoadShapeConstraintGraph() { // TODO(liujinnan): add more constraint function. currently, only support // tie_product_equal. auto constraint_vec = - symbolTable_.lookup("tie_product_equal"); + symbolTable_.Lookup("tie_product_equal"); if (!constraint_vec.size()) return true; @@ -89,7 +89,7 @@ bool SymbolicDimMgr::loadShapeConstraintGraph() { product.factor *= constOp.value().dyn_cast().data(); continue; } else if (auto dimOp = definingOp->dyn_cast()) { - auto sym = symbolTable_.lookup(dimOp.getName()); + auto sym = symbolTable_.Lookup(dimOp.getName()); if (!sym) return false; product.symbols.push_back(sym); continue; @@ -103,7 +103,7 @@ bool SymbolicDimMgr::loadShapeConstraintGraph() { SymbolicDimProduct lhs, rhs; if (!build_sym_product(op.getLhs(), lhs) || !build_sym_product(op.getRhs(), rhs) || - !mapSymbolicDimProductEqual(lhs, rhs)) + !MapSymbolicDimProductEqual(lhs, rhs)) return false; } return true; @@ -115,24 +115,24 @@ int64_t gcd(int64_t m, int64_t n) { return (m < n) ? gcd(m, n % m) : gcd(m % n, n); } -bool SymbolicDimMgr::mapSymbolicDimProductEqual(const SymbolicDimProduct& lhs, +bool SymbolicDimMgr::MapSymbolicDimProductEqual(const SymbolicDimProduct& lhs, const SymbolicDimProduct& rhs) { SymbolicDimProduct newLhs, newRhs; - std::tie(newLhs, newRhs) = simplifySymbolicDimProductPair(lhs, rhs); + std::tie(newLhs, newRhs) = SimplifySymbolicDimProductPair(lhs, rhs); // early return for identity case. if (newLhs == newRhs) return true; if (newLhs.factor == newRhs.factor && newLhs.symbols.size() == 1 && newRhs.symbols.size() == 1) { - return mapSymbolicDimEqual(newLhs.symbols[0], newRhs.symbols[0]); + return MapSymbolicDimEqual(newLhs.symbols[0], newRhs.symbols[0]); } else if (newLhs.symbols.size() == 0 && newRhs.symbols.size() == 1 && newRhs.factor == 1) { - return mapSymbolicDimEqual(newConstantSymbolicDim(newLhs.factor), + return MapSymbolicDimEqual(NewConstantSymbolicDim(newLhs.factor), newRhs.symbols[0]); } else if (newRhs.symbols.size() == 0 && newLhs.symbols.size() == 1 && newLhs.factor == 1) { - return mapSymbolicDimEqual(newConstantSymbolicDim(newRhs.factor), + return MapSymbolicDimEqual(NewConstantSymbolicDim(newRhs.factor), newLhs.symbols[0]); } @@ -144,10 +144,10 @@ bool SymbolicDimMgr::mapSymbolicDimProductEqual(const SymbolicDimProduct& lhs, } std::pair -SymbolicDimMgr::simplifySymbolicDimProductPair(const SymbolicDimProduct& x, +SymbolicDimMgr::SimplifySymbolicDimProductPair(const SymbolicDimProduct& x, const SymbolicDimProduct& y) { - auto lhs = simplifySymbolicDimProduct(x); - auto rhs = simplifySymbolicDimProduct(y); + auto lhs = SimplifySymbolicDimProduct(x); + auto rhs = SimplifySymbolicDimProduct(y); SymbolicDimProduct newLhs, newRhs; int64_t gcdFactor = gcd(std::abs(lhs.factor), std::abs(rhs.factor)); @@ -190,19 +190,19 @@ SymbolicDimMgr::simplifySymbolicDimProductPair(const SymbolicDimProduct& x, return std::make_pair(std::move(newLhs), std::move(newRhs)); } -SymbolicDimProduct SymbolicDimMgr::simplifySymbolicDimProduct( +SymbolicDimProduct SymbolicDimMgr::SimplifySymbolicDimProduct( const SymbolicDimProduct& x) { std::vector copied; copied.reserve(x.symbols.size()); - for (SymbolicDim op : x.symbols) copied.push_back(getRootSymbolicDim(op)); + for (SymbolicDim op : x.symbols) copied.push_back(GetRootSymbolicDim(op)); sort(copied.begin(), copied.end(), [&](SymbolicDim lhs, SymbolicDim rhs) { - return compareSymbolicDimNames(lhs.getSymName(), rhs.getSymName()); + return CompareSymbolicDimNames(lhs.getSymName(), rhs.getSymName()); }); SymbolicDimProduct newX; newX.factor = x.factor; for (SymbolicDim op : copied) { - if (!op.isDynamic()) { + if (!op.IsDynamic()) { newX.factor *= op.getValue(); } else { newX.symbols.push_back(op); @@ -211,7 +211,7 @@ SymbolicDimProduct SymbolicDimMgr::simplifySymbolicDimProduct( return newX; } -const std::string SymbolicDimMgr::getNextName() { +const std::string SymbolicDimMgr::GetNextName() { std::string name; do { name = "S" + std::to_string(nextSymbolicIdx_++); @@ -231,13 +231,13 @@ SymbolicDimMgr::SymbolicDimMgr(ModuleOp m) : m_(m) { symbolTable_ = SymbolTable(func); } -SymbolicDim SymbolicDimMgr::newSymbolicDim(const std::string& name) { +SymbolicDim SymbolicDimMgr::NewSymbolicDim(const std::string& name) { auto funcOp = symbolTable_.getOp()->dyn_cast(); assert(funcOp); Builder builder = Builder(m_.ir_context(), funcOp.block()); // default settting dim != 0 dialect::SymbolicDim symbol = - builder.Build(name.empty() ? getNextName() : name, + builder.Build(name.empty() ? GetNextName() : name, ShapedTypeInterface::kDynamic, false, false, @@ -248,12 +248,12 @@ SymbolicDim SymbolicDimMgr::newSymbolicDim(const std::string& name) { return symbol; } -SymbolicDim SymbolicDimMgr::newConstantSymbolicDim(int64_t val) { +SymbolicDim SymbolicDimMgr::NewConstantSymbolicDim(int64_t val) { auto it = constantSymbolicDimMap_.find(val); if (it == constantSymbolicDimMap_.end()) { auto name = "C" + std::to_string(val); it = constantSymbolicDimMap_ - .insert(std::make_pair(val, newSymbolicDim(name))) + .insert(std::make_pair(val, NewSymbolicDim(name))) .first; it->second.updateValue(val); if (val == -1) it->second.updateKnownNegativeOne(true); @@ -261,22 +261,22 @@ SymbolicDim SymbolicDimMgr::newConstantSymbolicDim(int64_t val) { if (val != 1) it->second.updateKnownNonSizeOne(true); if (val != 0) it->second.updateKnownNonSizeZero(true); } - return getRootSymbolicDim(it->second); + return GetRootSymbolicDim(it->second); } -std::vector SymbolicDimMgr::createSymbolicDimsForRankedValue( +std::vector SymbolicDimMgr::CreateSymbolicDimsForRankedValue( Value value) { std::vector symbols; auto dims = value.type().dyn_cast().dims(); for (int idx = 0; idx < dims.size(); ++idx) { symbols.push_back(dims[idx] == ShapedTypeInterface::kDynamic - ? newSymbolicDim() - : newConstantSymbolicDim(dims[idx])); + ? NewSymbolicDim() + : NewConstantSymbolicDim(dims[idx])); } return symbols; } -SymbolicDim SymbolicDimMgr::getRootSymbolicDim(SymbolicDim symbol) { +SymbolicDim SymbolicDimMgr::GetRootSymbolicDim(SymbolicDim symbol) { SymbolicDim current = symbol; std::vector path; while (symbolDimUnionSet_[current] != current) { @@ -287,32 +287,32 @@ SymbolicDim SymbolicDimMgr::getRootSymbolicDim(SymbolicDim symbol) { return current; } -bool SymbolicDimMgr::isSymbolicDimEqual(SymbolicDim lhs, SymbolicDim rhs) { - SymbolicDim lhsRoot = getRootSymbolicDim(lhs); - SymbolicDim rhsRoot = getRootSymbolicDim(rhs); +bool SymbolicDimMgr::IsSymbolicDimEqual(SymbolicDim lhs, SymbolicDim rhs) { + SymbolicDim lhsRoot = GetRootSymbolicDim(lhs); + SymbolicDim rhsRoot = GetRootSymbolicDim(rhs); return lhsRoot == rhsRoot; } -bool SymbolicDimMgr::mapSymbolicDimEqual(SymbolicDim lhs, SymbolicDim rhs) { - SymbolicDim lhsRoot = getRootSymbolicDim(lhs); - SymbolicDim rhsRoot = getRootSymbolicDim(rhs); +bool SymbolicDimMgr::MapSymbolicDimEqual(SymbolicDim lhs, SymbolicDim rhs) { + SymbolicDim lhsRoot = GetRootSymbolicDim(lhs); + SymbolicDim rhsRoot = GetRootSymbolicDim(rhs); if (lhsRoot != rhsRoot) { - if (compareSymbolicDimNames(lhsRoot.getSymName(), rhsRoot.getSymName())) { - if (!lhsRoot.merge(rhsRoot)) return false; + if (CompareSymbolicDimNames(lhsRoot.getSymName(), rhsRoot.getSymName())) { + if (!lhsRoot.Merge(rhsRoot)) return false; symbolDimUnionSet_[rhsRoot] = lhsRoot; } else { - if (!rhsRoot.merge(lhsRoot)) return false; + if (!rhsRoot.Merge(lhsRoot)) return false; symbolDimUnionSet_[lhsRoot] = rhsRoot; } } return true; } -SymbolicDimProduct* SymbolicDimMgr::symbolicDimProductDivide( +SymbolicDimProduct* SymbolicDimMgr::SymbolicDimProductDivide( const SymbolicDimProduct& lhs, const SymbolicDimProduct& rhs) { SymbolicDimProduct newLhs, newRhs; - std::tie(newLhs, newRhs) = simplifySymbolicDimProductPair(lhs, rhs); + std::tie(newLhs, newRhs) = SimplifySymbolicDimProductPair(lhs, rhs); if (newLhs.factor == 0 || newRhs.factor == 0) return nullptr; if (newLhs.factor % newRhs.factor != 0) return nullptr; @@ -340,16 +340,16 @@ SymbolicDimProduct* SymbolicDimMgr::symbolicDimProductDivide( return result; } -bool SymbolicDimMgr::isMultipleOfKnownSymbolicDimProductEqualPair( +bool SymbolicDimMgr::IsMultipleOfKnownSymbolicDimProductEqualPair( const SymbolicDimProduct& lhs, const SymbolicDimProduct& rhs) { for (auto& pairOutter : productEqualityMap_) { const SymbolicDimProduct& x = pairOutter.first; - auto factorX = symbolicDimProductDivide(lhs, x); + auto factorX = SymbolicDimProductDivide(lhs, x); if (!factorX) continue; for (auto& pairInner : pairOutter.second) { if (!pairInner.second) continue; const SymbolicDimProduct& y = pairInner.first; - auto factorY = symbolicDimProductDivide(rhs, y); + auto factorY = SymbolicDimProductDivide(rhs, y); if (!factorY || (*factorX) != (*factorY)) continue; return true; } @@ -358,7 +358,7 @@ bool SymbolicDimMgr::isMultipleOfKnownSymbolicDimProductEqualPair( return false; } -bool SymbolicDimMgr::updateProductEqualityMap() { +bool SymbolicDimMgr::UpdateProductEqualityMap() { // early return if nothing is updated. if (productEqualityMapUpdated_) return true; @@ -370,7 +370,7 @@ bool SymbolicDimMgr::updateProductEqualityMap() { if (!pairInner.second) continue; const SymbolicDimProduct& y = pairInner.first; SymbolicDimProduct newX, newY; - std::tie(newX, newY) = simplifySymbolicDimProductPair(x, y); + std::tie(newX, newY) = SimplifySymbolicDimProductPair(x, y); if (newX == newY) continue; newMap[newX][newY] = newMap[newY][newX] = true; productSet.insert(newX); @@ -425,7 +425,7 @@ bool SymbolicDimMgr::updateProductEqualityMap() { for (auto& y : productSet) { if (!productEqualityMap_[x][y]) continue; productEqualityMap_[x][y] = productEqualityMap_[y][x] = false; - if (!isMultipleOfKnownSymbolicDimProductEqualPair(x, y)) { + if (!IsMultipleOfKnownSymbolicDimProductEqualPair(x, y)) { productEqualityMap_[x][y] = productEqualityMap_[y][x] = true; } } @@ -449,25 +449,25 @@ bool SymbolicDimMgr::updateProductEqualityMap() { return true; } -bool SymbolicDimMgr::isSymbolicDimProductEqual(const SymbolicDimProduct& lhs, +bool SymbolicDimMgr::IsSymbolicDimProductEqual(const SymbolicDimProduct& lhs, const SymbolicDimProduct& rhs) { SymbolicDimProduct newLhs, newRhs; - std::tie(newLhs, newRhs) = simplifySymbolicDimProductPair(lhs, rhs); + std::tie(newLhs, newRhs) = SimplifySymbolicDimProductPair(lhs, rhs); // early return for identity case. if (newLhs == newRhs) return true; - IR_ENFORCE(updateProductEqualityMap(), "Update product equality map failed."); - return isMultipleOfKnownSymbolicDimProductEqualPair(newLhs, newRhs); + IR_ENFORCE(UpdateProductEqualityMap(), "Update product equality map failed."); + return IsMultipleOfKnownSymbolicDimProductEqualPair(newLhs, newRhs); } -bool SymbolicDimMgr::save() { +bool SymbolicDimMgr::Save() { using Name2SymbolFn = std::function; auto updateAttrs = [&](ArrayAttribute attrs, Name2SymbolFn fn) { std::vector newAttrs; for (Attribute attr : attrs.AsVector()) { auto sym = fn(attr.dyn_cast().AsString()); assert(sym); - SymbolicDim root = getRootSymbolicDim(sym); + SymbolicDim root = GetRootSymbolicDim(sym); Attribute rootSymbol = StrAttribute::get(m_->ir_context(), root.getSymName()); newAttrs.push_back(rootSymbol); @@ -481,11 +481,11 @@ bool SymbolicDimMgr::save() { auto attrs = op->attribute(SymbolicDim::getSymbolicDimAttrName()); auto symbolicShapeAttr = updateAttrs(attrs, [&](const std::string& name) { - return symbolTable_.lookup(name); + return symbolTable_.Lookup(name); }); op->set_attribute(SymbolicDim::getSymbolicDimAttrName(), symbolicShapeAttr); } - if (!updateProductEqualityMap()) { + if (!UpdateProductEqualityMap()) { return false; } std::unordered_set usedSymbolicOps; @@ -493,7 +493,7 @@ bool SymbolicDimMgr::save() { // TODO(liujinnan): collect uses in value. auto collectUsedSymbols = [&](ArrayAttribute attrs) { for (Attribute attr : attrs.AsVector()) { - auto sym = symbolTable_.lookup( + auto sym = symbolTable_.Lookup( attr.dyn_cast().AsString()); assert(sym); if (usedSymbolicOps.insert(sym).second) @@ -539,7 +539,7 @@ bool SymbolicDimMgr::save() { std::sort(usedSymbolNames.begin(), usedSymbolNames.end(), [&](const std::string& lhs, const std::string& rhs) { - return compareSymbolicDimNames(lhs, rhs); + return CompareSymbolicDimNames(lhs, rhs); }); int numNonConstDims = 0; std::unordered_map nameMapping; @@ -569,10 +569,10 @@ bool SymbolicDimMgr::save() { // TODO(liujinnan): update attributes attached to values. - return saveShapeConstraintGraph(); + return SaveShapeConstraintGraph(); } -bool SymbolicDimMgr::saveShapeConstraintGraph() { +bool SymbolicDimMgr::SaveShapeConstraintGraph() { auto funcOp = symbolTable_.getOp()->dyn_cast(); assert(funcOp); auto op_it = funcOp.block()->rbegin(); @@ -605,10 +605,10 @@ bool SymbolicDimMgr::saveShapeConstraintGraph() { for (auto& p : productEqualityMap_) sortedProductVec.push_back(p.first); std::sort(sortedProductVec.begin(), sortedProductVec.end(), - compareSymbolicDimProduct); + CompareSymbolicDimProduct); for (auto& x : sortedProductVec) { for (auto& y : sortedProductVec) { - if (!compareSymbolicDimProduct(x, y)) continue; + if (!CompareSymbolicDimProduct(x, y)) continue; if (!productEqualityMap_[x][y]) continue; auto lhsOperands = build_operands(x); auto rhsOperands = build_operands(y); @@ -618,17 +618,17 @@ bool SymbolicDimMgr::saveShapeConstraintGraph() { return true; } -bool ShapeAnalysis::isSameNumElements(Value lhs, Value rhs) { +bool ShapeAnalysis::IsSameNumElements(Value lhs, Value rhs) { if (lhs == rhs) return true; auto lhsTy = lhs.type().dyn_cast_interface(); auto rhsTy = rhs.type().dyn_cast_interface(); if (!lhsTy || !rhsTy || !lhsTy.hasRank() || !rhsTy.hasRank()) return false; - return isProductEqual(lhs, 0, lhsTy.getRank(), rhs, 0, rhsTy.getRank()); + return IsProductEqual(lhs, 0, lhsTy.getRank(), rhs, 0, rhsTy.getRank()); } -bool ShapeAnalysis::isProductEqual( +bool ShapeAnalysis::IsProductEqual( Value lhs, int lhsFrom, int lhsTo, Value rhs, int rhsFrom, int rhsTo) { std::vector lhsDimIdxs, rhsDimIdxs; lhsDimIdxs.reserve(lhsTo - lhsFrom); @@ -636,12 +636,12 @@ bool ShapeAnalysis::isProductEqual( for (int i = lhsFrom; i < lhsTo; ++i) lhsDimIdxs.push_back(i); for (int i = rhsFrom; i < rhsTo; ++i) rhsDimIdxs.push_back(i); - return isProductEqual(lhs, lhsDimIdxs, rhs, rhsDimIdxs); + return IsProductEqual(lhs, lhsDimIdxs, rhs, rhsDimIdxs); } SymbolicDimShapeAnalysis::SymbolicDimShapeAnalysis(ModuleOp m) : m_(m), mgr_(m) { - mgr_.load(); + mgr_.Load(); for (auto op : *(m_.block())) { auto tieShapeOp = op->dyn_cast(); if (!tieShapeOp) continue; @@ -652,7 +652,7 @@ SymbolicDimShapeAnalysis::SymbolicDimShapeAnalysis(ModuleOp m) .attribute(SymbolicDim::getSymbolicDimAttrName()) .AsVector(); for (const auto& attr : attrs) { - auto symOp = mgr_.symbolTable().lookup( + auto symOp = mgr_.symbolTable().Lookup( attr.dyn_cast().AsString()); if (!symOp) continue; symbols.push_back(symOp); @@ -660,9 +660,9 @@ SymbolicDimShapeAnalysis::SymbolicDimShapeAnalysis(ModuleOp m) } } -SymbolicDimShapeAnalysis::~SymbolicDimShapeAnalysis() { mgr_.save(); } +SymbolicDimShapeAnalysis::~SymbolicDimShapeAnalysis() { mgr_.Save(); } -bool SymbolicDimShapeAnalysis::isShapeEqual(Value lhs, Value rhs) { +bool SymbolicDimShapeAnalysis::IsShapeEqual(Value lhs, Value rhs) { if (lhs == rhs) return true; auto lhsTy = lhs.type().dyn_cast_interface(); @@ -684,15 +684,15 @@ bool SymbolicDimShapeAnalysis::isShapeEqual(Value lhs, Value rhs) { std::vector lhsSyms; std::vector rhsSyms; for (auto sym : lhsIt->second) { - lhsSyms.push_back(mgr_.getRootSymbolicDim(sym)); + lhsSyms.push_back(mgr_.GetRootSymbolicDim(sym)); } for (auto sym : rhsIt->second) { - rhsSyms.push_back(mgr_.getRootSymbolicDim(sym)); + rhsSyms.push_back(mgr_.GetRootSymbolicDim(sym)); } return lhsSyms == rhsSyms; } -bool SymbolicDimShapeAnalysis::isProductEqual(Value lhs, +bool SymbolicDimShapeAnalysis::IsProductEqual(Value lhs, std::vector lhsDimIdxs, Value rhs, std::vector rhsDimIdxs) { @@ -722,6 +722,6 @@ bool SymbolicDimShapeAnalysis::isProductEqual(Value lhs, return false; } - return mgr_.isSymbolicDimProductEqual(lhsProd, rhsProd); + return mgr_.IsSymbolicDimProductEqual(lhsProd, rhsProd); } } // namespace pir diff --git a/paddle/pir/dialect/shape/utils/shape_utils.h b/paddle/pir/dialect/shape/utils/shape_utils.h index f00bff91a5517..ab728fc6a00a0 100644 --- a/paddle/pir/dialect/shape/utils/shape_utils.h +++ b/paddle/pir/dialect/shape/utils/shape_utils.h @@ -53,7 +53,7 @@ class SymbolTable { template typename std::enable_if::value, SymbolicDim>::type - lookup(const std::string& name) const { + Lookup(const std::string& name) const { auto it = symbolTableMap_.find(name); return it != symbolTableMap_.end() ? it->second->dyn_cast() : SymbolicDim(nullptr); @@ -61,7 +61,7 @@ class SymbolTable { template typename std::enable_if::value, std::vector>::type - lookup(const std::string& name) const { + Lookup(const std::string& name) const { std::vector res; auto it = symbolFuncMap_.find(name); if (it != symbolFuncMap_.end()) { @@ -101,35 +101,35 @@ struct SymProductHasher { class SymbolicDimMgr { public: explicit SymbolicDimMgr(ModuleOp m); - bool load(); - SymbolicDim newSymbolicDim(const std::string& name = {}); - SymbolicDim newConstantSymbolicDim(int64_t val); - std::vector createSymbolicDimsForRankedValue(Value value); - SymbolicDim getRootSymbolicDim(SymbolicDim symbol); - bool isSymbolicDimEqual(SymbolicDim lhs, SymbolicDim rhs); + bool Load(); + SymbolicDim NewSymbolicDim(const std::string& name = {}); + SymbolicDim NewConstantSymbolicDim(int64_t val); + std::vector CreateSymbolicDimsForRankedValue(Value value); + SymbolicDim GetRootSymbolicDim(SymbolicDim symbol); + bool IsSymbolicDimEqual(SymbolicDim lhs, SymbolicDim rhs); SymbolTable& symbolTable() { return symbolTable_; } - bool mapSymbolicDimEqual(SymbolicDim lhs, SymbolicDim rhs); - SymbolicDimProduct simplifySymbolicDimProduct(const SymbolicDimProduct& x); + bool MapSymbolicDimEqual(SymbolicDim lhs, SymbolicDim rhs); + SymbolicDimProduct SimplifySymbolicDimProduct(const SymbolicDimProduct& x); std::pair - simplifySymbolicDimProductPair(const SymbolicDimProduct& x, + SimplifySymbolicDimProductPair(const SymbolicDimProduct& x, const SymbolicDimProduct& y); - SymbolicDimProduct* symbolicDimProductDivide(const SymbolicDimProduct& x, + SymbolicDimProduct* SymbolicDimProductDivide(const SymbolicDimProduct& x, const SymbolicDimProduct& y); - bool save(); + bool Save(); - bool isSymbolicDimProductEqual(const SymbolicDimProduct& lhs, + bool IsSymbolicDimProductEqual(const SymbolicDimProduct& lhs, const SymbolicDimProduct& rhs); - bool mapSymbolicDimProductEqual(const SymbolicDimProduct& lhs, + bool MapSymbolicDimProductEqual(const SymbolicDimProduct& lhs, const SymbolicDimProduct& rhs); private: - const std::string getNextName(); - bool updateProductEqualityMap(); - bool isMultipleOfKnownSymbolicDimProductEqualPair( + const std::string GetNextName(); + bool UpdateProductEqualityMap(); + bool IsMultipleOfKnownSymbolicDimProductEqualPair( const SymbolicDimProduct& lhs, const SymbolicDimProduct& rhs); - bool saveShapeConstraintGraph(); - bool loadShapeConstraintGraph(); + bool SaveShapeConstraintGraph(); + bool LoadShapeConstraintGraph(); private: ModuleOp m_; @@ -157,15 +157,15 @@ class ShapeAnalysis { public: virtual ~ShapeAnalysis() = default; - virtual bool isShapeEqual(Value lhs, Value rhs) = 0; + virtual bool IsShapeEqual(Value lhs, Value rhs) = 0; - virtual bool isProductEqual(Value lhs, + virtual bool IsProductEqual(Value lhs, std::vector lhsDimIdxs, Value rhs, std::vector rhsDimIdxs) = 0; - virtual bool isProductEqual( + virtual bool IsProductEqual( Value lhs, int lhsFrom, int lhsTo, Value rhs, int rhsFrom, int rhsTo); - virtual bool isSameNumElements(Value lhs, Value rhs); + virtual bool IsSameNumElements(Value lhs, Value rhs); }; class SymbolicDimShapeAnalysis : public ShapeAnalysis { @@ -175,9 +175,9 @@ class SymbolicDimShapeAnalysis : public ShapeAnalysis { SymbolicDimMgr& symbolicDimMgr() { return mgr_; } const SymbolicDimMgr& symbolicDimMgr() const { return mgr_; } - bool isShapeEqual(Value lhs, Value rhs) override; + bool IsShapeEqual(Value lhs, Value rhs) override; - bool isProductEqual(Value lhs, + bool IsProductEqual(Value lhs, std::vector lhsDimIdxs, Value rhs, std::vector rhsDimIdxs) override; diff --git a/test/cpp/pir/shape_dialect/symbolic_op_test.cc b/test/cpp/pir/shape_dialect/symbolic_op_test.cc index 87bc34d4657f1..fa9ccbe2346e2 100644 --- a/test/cpp/pir/shape_dialect/symbolic_op_test.cc +++ b/test/cpp/pir/shape_dialect/symbolic_op_test.cc @@ -77,8 +77,8 @@ TEST(assist_struct_test, symbolic_dim) { EXPECT_FALSE(symDim.getKnownNonSizeZero()); EXPECT_FALSE(symDim.getKnownNonNegative()); - EXPECT_FALSE(symDim.isDynamic()); - EXPECT_TRUE(symDim.merge(symDim_)); + EXPECT_FALSE(symDim.IsDynamic()); + EXPECT_TRUE(symDim.Merge(symDim_)); symDim.updateValue(20); symDim.updateSymName("S2"); @@ -87,7 +87,7 @@ TEST(assist_struct_test, symbolic_dim) { symDim.updateKnownNonSizeZero(true); symDim.updateKnownNonNegative(true); - EXPECT_FALSE(symDim.merge(symDim_)); + EXPECT_FALSE(symDim.Merge(symDim_)); EXPECT_EQ(symDim.getValue(), 20); EXPECT_EQ(symDim.getSymName(), "S2"); @@ -123,9 +123,9 @@ TEST(assist_struct_test, symbolic_dim_table) { pir::SymbolTable symbolTable(program.module_op()); EXPECT_EQ(symbolTable.insert(symDim), "S0"); - EXPECT_EQ(symbolTable.lookup("S0"), symDim); + EXPECT_EQ(symbolTable.Lookup("S0"), symDim); EXPECT_EQ(symbolTable.getOp(), program.module_op()); - EXPECT_FALSE(symbolTable.lookup("S1")); + EXPECT_FALSE(symbolTable.Lookup("S1")); } TEST(assist_struct_test, symbolic_dim_mgr_simple) { @@ -138,17 +138,17 @@ TEST(assist_struct_test, symbolic_dim_mgr_simple) { ctx->GetOrRegisterDialect(); pir::SymbolicDimMgr symDimMgr(program.module_op()); - pir::dialect::SymbolicDim symDimS0 = symDimMgr.newSymbolicDim(); - pir::dialect::SymbolicDim symDimS1 = symDimMgr.newSymbolicDim(); - pir::dialect::SymbolicDim symDimC10 = symDimMgr.newConstantSymbolicDim(10); - symDimMgr.mapSymbolicDimEqual(symDimS0, symDimS1); + pir::dialect::SymbolicDim symDimS0 = symDimMgr.NewSymbolicDim(); + pir::dialect::SymbolicDim symDimS1 = symDimMgr.NewSymbolicDim(); + pir::dialect::SymbolicDim symDimC10 = symDimMgr.NewConstantSymbolicDim(10); + symDimMgr.MapSymbolicDimEqual(symDimS0, symDimS1); auto op = CreateDenseTensorOp( ctx, {pir::ShapedTypeInterface::kDynamic, 2}, {"op_attr"}, {"op_name"}); pir::Value res = op->result(0); std::vector symDimVec = - symDimMgr.createSymbolicDimsForRankedValue(res); + symDimMgr.CreateSymbolicDimsForRankedValue(res); EXPECT_EQ(symDimS0.getSymName(), "S0"); EXPECT_EQ(symDimS1.getSymName(), "S1"); @@ -157,13 +157,13 @@ TEST(assist_struct_test, symbolic_dim_mgr_simple) { EXPECT_EQ(symDimC10.getValue(), 10); EXPECT_EQ(symDimVec[0].getSymName(), "S2"); EXPECT_EQ(symDimVec[1].getSymName(), "C2"); - EXPECT_EQ(symDimMgr.symbolTable().lookup("S0"), + EXPECT_EQ(symDimMgr.symbolTable().Lookup("S0"), symDimS0); - EXPECT_EQ(symDimMgr.symbolTable().lookup("C10"), + EXPECT_EQ(symDimMgr.symbolTable().Lookup("C10"), symDimC10); - EXPECT_EQ(symDimMgr.getRootSymbolicDim(symDimS1), symDimS0); - EXPECT_TRUE(symDimMgr.isSymbolicDimEqual(symDimS0, symDimS1)); - EXPECT_FALSE(symDimMgr.isSymbolicDimEqual(symDimS0, symDimC10)); + EXPECT_EQ(symDimMgr.GetRootSymbolicDim(symDimS1), symDimS0); + EXPECT_TRUE(symDimMgr.IsSymbolicDimEqual(symDimS0, symDimS1)); + EXPECT_FALSE(symDimMgr.IsSymbolicDimEqual(symDimS0, symDimC10)); } TEST(assist_struct_test, symbolic_dim_mgr_complex) { @@ -181,21 +181,21 @@ TEST(assist_struct_test, symbolic_dim_mgr_complex) { pir::Builder builder = pir::Builder(ctx, funcOp.block()); - pir::dialect::SymbolicDim symDimS0 = symDimMgr.newSymbolicDim("S0"); - pir::dialect::SymbolicDim symDimS1 = symDimMgr.newSymbolicDim("S1"); - pir::dialect::SymbolicDim symDimS2 = symDimMgr.newSymbolicDim("S2"); - pir::dialect::SymbolicDim symDimS3 = symDimMgr.newSymbolicDim("S3"); - pir::dialect::SymbolicDim symDimS4 = symDimMgr.newSymbolicDim("S4"); - pir::dialect::SymbolicDim symDimS5 = symDimMgr.newSymbolicDim("S5"); - pir::dialect::SymbolicDim symDimS6 = symDimMgr.newSymbolicDim("S6"); - pir::dialect::SymbolicDim symDimS7 = symDimMgr.newSymbolicDim("S7"); - pir::dialect::SymbolicDim symDimS8 = symDimMgr.newSymbolicDim("S8"); - pir::dialect::SymbolicDim symDimS9 = symDimMgr.newSymbolicDim("S9"); - pir::dialect::SymbolicDim symDimS10 = symDimMgr.newSymbolicDim("S10"); - pir::dialect::SymbolicDim symDimS11 = symDimMgr.newSymbolicDim("S11"); - pir::dialect::SymbolicDim symDimS12 = symDimMgr.newSymbolicDim("S12"); - pir::dialect::SymbolicDim symDimC10 = symDimMgr.newConstantSymbolicDim(10); - pir::dialect::SymbolicDim symDimC20 = symDimMgr.newConstantSymbolicDim(20); + pir::dialect::SymbolicDim symDimS0 = symDimMgr.NewSymbolicDim("S0"); + pir::dialect::SymbolicDim symDimS1 = symDimMgr.NewSymbolicDim("S1"); + pir::dialect::SymbolicDim symDimS2 = symDimMgr.NewSymbolicDim("S2"); + pir::dialect::SymbolicDim symDimS3 = symDimMgr.NewSymbolicDim("S3"); + pir::dialect::SymbolicDim symDimS4 = symDimMgr.NewSymbolicDim("S4"); + pir::dialect::SymbolicDim symDimS5 = symDimMgr.NewSymbolicDim("S5"); + pir::dialect::SymbolicDim symDimS6 = symDimMgr.NewSymbolicDim("S6"); + pir::dialect::SymbolicDim symDimS7 = symDimMgr.NewSymbolicDim("S7"); + pir::dialect::SymbolicDim symDimS8 = symDimMgr.NewSymbolicDim("S8"); + pir::dialect::SymbolicDim symDimS9 = symDimMgr.NewSymbolicDim("S9"); + pir::dialect::SymbolicDim symDimS10 = symDimMgr.NewSymbolicDim("S10"); + pir::dialect::SymbolicDim symDimS11 = symDimMgr.NewSymbolicDim("S11"); + pir::dialect::SymbolicDim symDimS12 = symDimMgr.NewSymbolicDim("S12"); + pir::dialect::SymbolicDim symDimC10 = symDimMgr.NewConstantSymbolicDim(10); + pir::dialect::SymbolicDim symDimC20 = symDimMgr.NewConstantSymbolicDim(20); pir::OpResult dimOpS0 = builder.Build("S0").out(); pir::OpResult dimOpS1 = builder.Build("S1").out(); @@ -301,7 +301,7 @@ TEST(assist_struct_test, symbolic_dim_mgr_complex) { tieShapeOp_->set_attribute( pir::dialect::SymbolicDim::getSymbolicDimAttrName(), arrayAttr_); - EXPECT_TRUE(symDimMgr.load()); + EXPECT_TRUE(symDimMgr.Load()); // For check indirect equality: S1 * S4 == S2 * S5 pir::SymbolicDimProduct symDimProductLhs; @@ -359,36 +359,36 @@ TEST(assist_struct_test, symbolic_dim_mgr_complex) { pir::SymbolicDimProduct *divRes = symDimMgr.symbolicDimProductDivide( symDimProductDivLhs, symDimProductDivRhs); - EXPECT_TRUE(symDimMgr.isSymbolicDimEqual(symDimS1, symDimS2)); - EXPECT_TRUE(symDimMgr.isSymbolicDimEqual(symDimS0, symDimS3)); - EXPECT_TRUE(symDimMgr.isSymbolicDimEqual(symDimS4, symDimS5)); + EXPECT_TRUE(symDimMgr.IsSymbolicDimEqual(symDimS1, symDimS2)); + EXPECT_TRUE(symDimMgr.IsSymbolicDimEqual(symDimS0, symDimS3)); + EXPECT_TRUE(symDimMgr.IsSymbolicDimEqual(symDimS4, symDimS5)); EXPECT_EQ(symDimS6.getValue(), 200); - EXPECT_EQ(symDimMgr.symbolTable().lookup("C20"), + EXPECT_EQ(symDimMgr.symbolTable().Lookup("C20"), symDimC20); EXPECT_EQ(symDimS7.getValue(), symDimC10.getValue()); EXPECT_EQ(simplifiedProductS7.factor, 10); EXPECT_EQ(simplifiedProductS7.symbols.size(), static_cast(0)); EXPECT_EQ(newLhs.symbols.size(), static_cast(1)); EXPECT_EQ(newRhs.symbols.size(), static_cast(1)); - EXPECT_EQ(newLhs.symbols[0], symDimMgr.getRootSymbolicDim(symDimS4)); - EXPECT_EQ(newRhs.symbols[0], symDimMgr.getRootSymbolicDim(symDimS3)); + EXPECT_EQ(newLhs.symbols[0], symDimMgr.GetRootSymbolicDim(symDimS4)); + EXPECT_EQ(newRhs.symbols[0], symDimMgr.GetRootSymbolicDim(symDimS3)); EXPECT_EQ(divRes->factor, 2); EXPECT_EQ(divRes->symbols.size(), static_cast(1)); - EXPECT_EQ(divRes->symbols[0], symDimMgr.getRootSymbolicDim(symDimS4)); + EXPECT_EQ(divRes->symbols[0], symDimMgr.GetRootSymbolicDim(symDimS4)); EXPECT_TRUE( - symDimMgr.isSymbolicDimProductEqual(symDimProductLhs, symDimProductRhs)); - EXPECT_TRUE(symDimMgr.isSymbolicDimProductEqual(symDimProductLhs_, + symDimMgr.IsSymbolicDimProductEqual(symDimProductLhs, symDimProductRhs)); + EXPECT_TRUE(symDimMgr.IsSymbolicDimProductEqual(symDimProductLhs_, symDimProductRhs_)); - EXPECT_TRUE(symDimMgr.save()); + EXPECT_TRUE(symDimMgr.Save()); pir::SymbolicDimMgr symDimMgr_(program.module_op()); - EXPECT_TRUE(symDimMgr_.load()); + EXPECT_TRUE(symDimMgr_.Load()); auto attrs = tieShapeOp.attribute( pir::dialect::SymbolicDim::getSymbolicDimAttrName()); EXPECT_FALSE( - symDimMgr_.symbolTable().lookup("S7")); + symDimMgr_.symbolTable().Lookup("S7")); EXPECT_EQ(symDimMgr_.symbolTable() - .lookup("tie_product_equal") + .Lookup("tie_product_equal") .size(), static_cast(1)); @@ -437,10 +437,10 @@ TEST(shape_op, tie_product_equal) { EXPECT_EQ(symbolTable.insert(tie_product_equal), "tie_product_equal"); EXPECT_EQ( - symbolTable.lookup("tie_product_equal") + symbolTable.Lookup("tie_product_equal") .size(), static_cast(1)); - EXPECT_EQ(symbolTable.lookup( + EXPECT_EQ(symbolTable.Lookup( "tie_product_equal")[0], tie_product_equal); EXPECT_EQ(lhs, lhs_ref); @@ -572,19 +572,19 @@ TEST(assist_struct_test, shape_analysis) { pir::dialect::SymbolicDim::getSymbolicDimAttrName(), attrOp5); pir::SymbolicDimShapeAnalysis shapeAnalysis(program.module_op()); - EXPECT_TRUE(shapeAnalysis.isShapeEqual(value3, value4)); - EXPECT_FALSE(shapeAnalysis.isShapeEqual(value1, value2)); - EXPECT_FALSE(shapeAnalysis.isShapeEqual(value1, value3)); - EXPECT_FALSE(shapeAnalysis.isShapeEqual(value1, value5)); - EXPECT_FALSE(shapeAnalysis.isShapeEqual(value3, value5)); - EXPECT_TRUE(shapeAnalysis.isProductEqual(value1, {1}, value3, {0})); - EXPECT_TRUE(shapeAnalysis.isSameNumElements(value4, value3)); - - shapeAnalysis.symbolicDimMgr().mapSymbolicDimEqual(symDimS0, symDimS1); - shapeAnalysis.symbolicDimMgr().mapSymbolicDimEqual(symDimS0, symDimS2); - - EXPECT_TRUE(shapeAnalysis.isShapeEqual(value1, value2)); - EXPECT_FALSE(shapeAnalysis.isShapeEqual(value1, value5)); + EXPECT_TRUE(shapeAnalysis.IsShapeEqual(value3, value4)); + EXPECT_FALSE(shapeAnalysis.IsShapeEqual(value1, value2)); + EXPECT_FALSE(shapeAnalysis.IsShapeEqual(value1, value3)); + EXPECT_FALSE(shapeAnalysis.IsShapeEqual(value1, value5)); + EXPECT_FALSE(shapeAnalysis.IsShapeEqual(value3, value5)); + EXPECT_TRUE(shapeAnalysis.IsProductEqual(value1, {1}, value3, {0})); + EXPECT_TRUE(shapeAnalysis.IsSameNumElements(value4, value3)); + + shapeAnalysis.symbolicDimMgr().MapSymbolicDimEqual(symDimS0, symDimS1); + shapeAnalysis.symbolicDimMgr().MapSymbolicDimEqual(symDimS0, symDimS2); + + EXPECT_TRUE(shapeAnalysis.IsShapeEqual(value1, value2)); + EXPECT_FALSE(shapeAnalysis.IsShapeEqual(value1, value5)); } TEST(shape_op, tensor_dim) { @@ -593,8 +593,8 @@ TEST(shape_op, tensor_dim) { ctx->GetOrRegisterDialect(); pir::Builder builder = pir::Builder(ctx, program.block()); - pir::Operation *op = - CreateDenseTensorOp(ctx, {-100000, 2}, {"op_attr"}, {"op_name"}); + pir::Operation *op = CreateDenseTensorOp( + ctx, {pir::ShapedTypeInterface::kDynamic, 2}, {"op_attr"}, {"op_name"}); pir::OpResult resDenseTensorValue = op->result(0); pir::dialect::TensorDimOp tensorDimOp0 = From e36d8002e8cc5e6956ffef2f3c15890d7202b448 Mon Sep 17 00:00:00 2001 From: liujinnan <1823192871@qq.com> Date: Thu, 14 Sep 2023 06:37:24 +0000 Subject: [PATCH 7/8] fix error. --- test/cpp/pir/shape_dialect/symbolic_op_test.cc | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/test/cpp/pir/shape_dialect/symbolic_op_test.cc b/test/cpp/pir/shape_dialect/symbolic_op_test.cc index fa9ccbe2346e2..a15ab3560b3aa 100644 --- a/test/cpp/pir/shape_dialect/symbolic_op_test.cc +++ b/test/cpp/pir/shape_dialect/symbolic_op_test.cc @@ -325,14 +325,14 @@ TEST(assist_struct_test, symbolic_dim_mgr_complex) { symDimProductRhs_.symbols.push_back(symDimS11); symDimProductRhs_.symbols.push_back(symDimS12); - // For check simplifySymbolicDimProduct, {factor = 1, Sym = {S7}} => {factor = + // For check SimplifySymbolicDimProduct, {factor = 1, Sym = {S7}} => {factor = // 10} pir::SymbolicDimProduct symDimProductS7; symDimProductS7.symbols.push_back(symDimS7); pir::SymbolicDimProduct simplifiedProductS7 = - symDimMgr.simplifySymbolicDimProduct(symDimProductS7); + symDimMgr.SimplifySymbolicDimProduct(symDimProductS7); - // For check simplifySymbolicDimProductPair, X * Y * Y, Y * Y * Z => X, Z + // For check SimplifySymbolicDimProductPair, X * Y * Y, Y * Y * Z => X, Z pir::SymbolicDimProduct symDimProductPairLhs; pir::SymbolicDimProduct symDimProductPairRhs; pir::SymbolicDimProduct newLhs, newRhs; @@ -343,7 +343,7 @@ TEST(assist_struct_test, symbolic_dim_mgr_complex) { symDimProductPairRhs.symbols.push_back(symDimS2); symDimProductPairRhs.symbols.push_back(symDimS3); - std::tie(newLhs, newRhs) = symDimMgr.simplifySymbolicDimProductPair( + std::tie(newLhs, newRhs) = symDimMgr.SimplifySymbolicDimProductPair( symDimProductPairLhs, symDimProductPairRhs); // For check symbolicDimProductDivide, {S4 * S1 * C20} / {S1 * C10} => {factor From fcf88a419f8acddc5f0633cd2104c2fb173554ef Mon Sep 17 00:00:00 2001 From: liujinnan <1823192871@qq.com> Date: Thu, 14 Sep 2023 07:08:39 +0000 Subject: [PATCH 8/8] fix error. --- test/cpp/pir/shape_dialect/symbolic_op_test.cc | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/test/cpp/pir/shape_dialect/symbolic_op_test.cc b/test/cpp/pir/shape_dialect/symbolic_op_test.cc index a15ab3560b3aa..f680982b2d146 100644 --- a/test/cpp/pir/shape_dialect/symbolic_op_test.cc +++ b/test/cpp/pir/shape_dialect/symbolic_op_test.cc @@ -346,7 +346,7 @@ TEST(assist_struct_test, symbolic_dim_mgr_complex) { std::tie(newLhs, newRhs) = symDimMgr.SimplifySymbolicDimProductPair( symDimProductPairLhs, symDimProductPairRhs); - // For check symbolicDimProductDivide, {S4 * S1 * C20} / {S1 * C10} => {factor + // For check SymbolicDimProductDivide, {S4 * S1 * C20} / {S1 * C10} => {factor // = 2 Sym = {S4}} pir::SymbolicDimProduct symDimProductDivLhs; pir::SymbolicDimProduct symDimProductDivRhs; @@ -356,7 +356,7 @@ TEST(assist_struct_test, symbolic_dim_mgr_complex) { symDimProductDivRhs.symbols.push_back(symDimS1); symDimProductDivRhs.symbols.push_back(symDimC10); - pir::SymbolicDimProduct *divRes = symDimMgr.symbolicDimProductDivide( + pir::SymbolicDimProduct *divRes = symDimMgr.SymbolicDimProductDivide( symDimProductDivLhs, symDimProductDivRhs); EXPECT_TRUE(symDimMgr.IsSymbolicDimEqual(symDimS1, symDimS2));