diff --git a/paddle/pir/dialect/shape/ir/shape_dialect.cc b/paddle/pir/dialect/shape/ir/shape_dialect.cc index fba4c69d466f4..611d2d95c4810 100644 --- a/paddle/pir/dialect/shape/ir/shape_dialect.cc +++ b/paddle/pir/dialect/shape/ir/shape_dialect.cc @@ -23,7 +23,12 @@ ShapeDialect::ShapeDialect(IrContext *context) } void ShapeDialect::initialize() { - RegisterOps(); + RegisterOps(); } } // namespace dialect diff --git a/paddle/pir/dialect/shape/ir/shape_op.cc b/paddle/pir/dialect/shape/ir/shape_op.cc index 1ad4484551092..0f560c916e0b6 100644 --- a/paddle/pir/dialect/shape/ir/shape_op.cc +++ b/paddle/pir/dialect/shape/ir/shape_op.cc @@ -14,6 +14,7 @@ #include "paddle/pir/dialect/shape/ir/shape_op.h" #include "paddle/pir/core/builtin_attribute.h" +#include "paddle/pir/core/builtin_op.h" #include "paddle/pir/core/builtin_type.h" namespace pir { @@ -104,15 +105,15 @@ void SymbolicDim::updateKnownNonSizeZero(bool attrValue) { pir::BoolAttribute::get(pir::IrContext::Instance(), attrValue)); } -bool SymbolicDim::isDynamic() { +bool SymbolicDim::IsDynamic() { return getValue() == ShapedTypeInterface::kDynamic; } -bool SymbolicDim::merge(SymbolicDim other) { - if (!isDynamic() && !other.isDynamic() && getValue() != other.getValue()) +bool SymbolicDim::Merge(SymbolicDim other) { + if (!IsDynamic() && !other.IsDynamic() && getValue() != other.getValue()) return false; - if (isDynamic() && !other.isDynamic()) updateValue(other.getValue()); - if (!isDynamic() && other.isDynamic()) other.updateValue(getValue()); + if (IsDynamic() && !other.IsDynamic()) updateValue(other.getValue()); + if (!IsDynamic() && other.IsDynamic()) other.updateValue(getValue()); bool knownNonNegativeFlag = getKnownNonNegative() || other.getKnownNonNegative(); @@ -213,9 +214,26 @@ void TieShapeOp::Build(Builder &builder, const pir::OpResult &input) { argument.inputs = {input}; } +void TieShapeOp::Build(Builder &builder, // NOLINT + OperationArgument &argument, // NOLINT + const pir::OpResult &input, + 
const std::vector &dims) { + argument.inputs = {input}; + for (auto &dim : dims) { + argument.inputs.push_back(dim); + } +} pir::Value TieShapeOp::getValue() { return operand_source(0); } +std::vector TieShapeOp::getShapeDimIndexes() { + std::vector res; + for (uint32_t i = 1; i < num_operands(); i++) { + res.push_back(operand_source(i)); + } + return res; +} + void FuncOp::Build(Builder &builder, OperationArgument &argument) { argument.num_regions = 1; } @@ -226,6 +244,33 @@ pir::Block *FuncOp::block() { return region.front(); } +void TensorDimOp::Build(Builder &builder, + OperationArgument &argument, + const pir::OpResult &source, + const pir::OpResult &index) { + argument.inputs = {source, index}; + argument.output_types.emplace_back( + pir::IndexType::get(pir::IrContext::Instance())); +} + +void TensorDimOp::Build(Builder &builder, + OperationArgument &argument, + const pir::OpResult &source, + int64_t index) { + pir::OpResult indexValue = + builder + .Build( + pir::Int64Attribute::get(pir::IrContext::Instance(), index), + pir::IndexType::get(pir::IrContext::Instance())) + ->result(0); + argument.inputs = {source, indexValue}; + argument.output_types.emplace_back( + pir::IndexType::get(pir::IrContext::Instance())); +} + +pir::Value TensorDimOp::getSource() { return operand_source(0); } + +pir::Value TensorDimOp::getIndex() { return operand_source(1); } } // namespace dialect } // namespace pir @@ -234,3 +279,4 @@ IR_DEFINE_EXPLICIT_TYPE_ID(pir::dialect::DimOp) IR_DEFINE_EXPLICIT_TYPE_ID(pir::dialect::TieProductEqualOp) IR_DEFINE_EXPLICIT_TYPE_ID(pir::dialect::TieShapeOp) IR_DEFINE_EXPLICIT_TYPE_ID(pir::dialect::FuncOp) +IR_DEFINE_EXPLICIT_TYPE_ID(pir::dialect::TensorDimOp) diff --git a/paddle/pir/dialect/shape/ir/shape_op.h b/paddle/pir/dialect/shape/ir/shape_op.h index 70e408c03cfb6..d111e81c4989d 100644 --- a/paddle/pir/dialect/shape/ir/shape_op.h +++ b/paddle/pir/dialect/shape/ir/shape_op.h @@ -51,11 +51,11 @@ class IR_API SymbolicDim : public Op { void 
updateKnownNonSizeOne(bool attrValue); void updateKnownNonSizeZero(bool attrValue); - bool isDynamic(); - bool merge(SymbolicDim other); + bool IsDynamic(); + bool Merge(SymbolicDim other); static const std::string getSymbolicDimAttrName() { - return "SymbolicDimAttr"; + return "kSymbolicDimAttr"; } void Verify() {} @@ -112,7 +112,14 @@ class IR_API TieShapeOp : public Op { static void Build(Builder &builder, // NOLINT OperationArgument &argument, // NOLINT const pir::OpResult &input); + + static void Build(Builder &builder, // NOLINT + OperationArgument &argument, // NOLINT + const pir::OpResult &input, + const std::vector &dims); + pir::Value getValue(); + std::vector getShapeDimIndexes(); void Verify() {} }; @@ -129,6 +136,29 @@ class IR_API FuncOp : public Op { pir::Block *block(); void Verify() {} }; + +class IR_API TensorDimOp : public Op { + public: + using Op::Op; + static const char *name() { return "shape.tensor_dim"; } + + static constexpr const char **attributes_name = nullptr; + static constexpr uint32_t attributes_num = 0; + + static void Build(Builder &builder, // NOLINT + OperationArgument &argument, // NOLINT + const pir::OpResult &source, + const pir::OpResult &index); + static void Build(Builder &builder, // NOLINT + OperationArgument &argument, // NOLINT + const pir::OpResult &source, + int64_t index); + pir::Value getIndex(); + pir::Value getSource(); + pir::OpResult out() { return result(0); } + void Verify() {} +}; + } // namespace dialect } // namespace pir @@ -137,3 +167,4 @@ IR_EXPORT_DECLARE_EXPLICIT_TYPE_ID(pir::dialect::DimOp); IR_EXPORT_DECLARE_EXPLICIT_TYPE_ID(pir::dialect::TieProductEqualOp); IR_EXPORT_DECLARE_EXPLICIT_TYPE_ID(pir::dialect::TieShapeOp); IR_EXPORT_DECLARE_EXPLICIT_TYPE_ID(pir::dialect::FuncOp); +IR_EXPORT_DECLARE_EXPLICIT_TYPE_ID(pir::dialect::TensorDimOp); diff --git a/paddle/pir/dialect/shape/transforms/shape_optimization_pass.cc b/paddle/pir/dialect/shape/transforms/shape_optimization_pass.cc new file mode 100644 
index 0000000000000..a4922c69ed0c5 --- /dev/null +++ b/paddle/pir/dialect/shape/transforms/shape_optimization_pass.cc @@ -0,0 +1,109 @@ +// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "paddle/pir/dialect/shape/transforms/shape_optimization_pass.h" +#include "paddle/fluid/pir/dialect/operator/ir/op_type.h" +#include "paddle/pir/dialect/shape/ir/shape_op.h" + +#include "paddle/pir/core/builtin_op.h" +#include "paddle/pir/core/program.h" +#include "paddle/pir/pass/pass.h" +#include "paddle/pir/pass/pass_registry.h" + +namespace { + +bool InsertTieShapeOnValue(pir::OpResult value, + pir::Builder& builder) { // NOLINT + auto ty = value.type().dyn_cast(); + + if (!ty || ty.dims().size() == 0) return true; + std::vector dimSizes; + for (int64_t dim = 0, rank = ty.dims().size(); dim < rank; ++dim) { + auto dimOp = builder.Build(value, dim); + dimSizes.push_back(dimOp.out()); + } + builder.Build(value, dimSizes); + return true; +} + +bool InsertTieShapeOnRegion(pir::Region* region); + +bool InsertTieShapeOnOperation(pir::Operation* op, + pir::Builder& builder) { // NOLINT + if (op->isa()) return true; + // TODO(liujinnan): skip the specialized Ops. 
+ + for (size_t i = 0; i < op->num_regions(); ++i) { + if (!InsertTieShapeOnRegion(&(op->region(i)))) return false; + } + builder.SetInsertionPointAfter(op); + for (pir::OpResult v : op->results()) { + if (!InsertTieShapeOnValue(v, builder)) return false; + } + + return true; +} + +bool insertTieShapeOnBlock(pir::Block* block) { + pir::Builder builder = + pir::Builder(pir::IrContext::Instance(), block, block->begin()); + // TODO(liujinnan): mapping block arguments + + std::vector op_list; + for (pir::Operation* op : *block) op_list.push_back(op); + for (pir::Operation* op : op_list) { + if (!InsertTieShapeOnOperation(op, builder)) return false; + } + return true; +} + +bool InsertTieShapeOnRegion(pir::Region* region) { + for (pir::Block* block : *region) { + if (!insertTieShapeOnBlock(block)) return false; + } + return true; +} + +bool MaterializeShapeComputation(pir::ModuleOp m) { + if (!InsertTieShapeOnRegion(&(m->region(0)))) return false; + // TODO(liujinnan): add rewriter pattern for reifyInferShape. 
+ return true; +} + +class ShapeOptimizationPass : public pir::Pass { + public: + ShapeOptimizationPass() : pir::Pass("shape_optimization", 0) {} + + void Run(pir::Operation* op) override { + auto module_op = op->dyn_cast(); + IR_ENFORCE(module_op, "ShapeOptimizationPass should run on module op."); + MaterializeShapeComputation(module_op); + } + + bool CanApplyOn(pir::Operation* op) const override { + return op->name() == "builtin.module" && op->num_regions() > 0; + } +}; + +} // namespace + +namespace pir { + +std::unique_ptr CreateShapeOptimizationPass() { + return std::make_unique(); +} + +} // namespace pir + +REGISTER_IR_PASS(shape_optimization, ShapeOptimizationPass); diff --git a/paddle/pir/dialect/shape/transforms/shape_optimization_pass.h b/paddle/pir/dialect/shape/transforms/shape_optimization_pass.h new file mode 100644 index 0000000000000..43bad532c920d --- /dev/null +++ b/paddle/pir/dialect/shape/transforms/shape_optimization_pass.h @@ -0,0 +1,26 @@ +// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +#pragma once + +#include +#include "paddle/pir/core/dll_decl.h" + +namespace pir { + +class Pass; + +IR_API std::unique_ptr CreateShapeOptimizationPass(); + +} // namespace pir diff --git a/paddle/pir/dialect/shape/utils/shape_utils.cc b/paddle/pir/dialect/shape/utils/shape_utils.cc index 1de3f03620961..99ec8f57bc2c2 100644 --- a/paddle/pir/dialect/shape/utils/shape_utils.cc +++ b/paddle/pir/dialect/shape/utils/shape_utils.cc @@ -17,7 +17,7 @@ #include "paddle/fluid/pir/dialect/operator/ir/op_type.h" namespace pir { -bool compareSymbolicDimNames(const std::string& lhs, const std::string& rhs) { +bool CompareSymbolicDimNames(const std::string& lhs, const std::string& rhs) { if (lhs.size() < 1 || (lhs[0] != 'S' && lhs[0] != 'C')) return lhs < rhs; if (rhs.size() < 1 || (rhs[0] != 'S' && rhs[0] != 'C')) return lhs < rhs; int64_t lhsIdx = 0, rhsIdx = 0; @@ -30,14 +30,14 @@ bool compareSymbolicDimNames(const std::string& lhs, const std::string& rhs) { return (lhs[0] < rhs[0]) || (lhs[0] == rhs[0] && lhsIdx < rhsIdx); } -bool compareSymbolicDimProduct(SymbolicDimProduct& lhs, // NOLINT +bool CompareSymbolicDimProduct(SymbolicDimProduct& lhs, // NOLINT SymbolicDimProduct& rhs) { // NOLINT if (lhs.symbols.size() < rhs.symbols.size()) return true; if (lhs.symbols.size() == rhs.symbols.size()) { for (size_t idx = 0; idx < lhs.symbols.size(); ++idx) { const std::string lhsName = lhs.symbols[idx].getSymName(); const std::string rhsName = rhs.symbols[idx].getSymName(); - if (compareSymbolicDimNames(lhsName, rhsName)) return true; + if (CompareSymbolicDimNames(lhsName, rhsName)) return true; if (lhsName != rhsName) return false; } } @@ -60,7 +60,7 @@ const std::string SymbolTable::insert(Operation* symbol) { return name; } -bool SymbolicDimMgr::load() { +bool SymbolicDimMgr::Load() { auto funcOp = symbolTable_.getOp()->dyn_cast(); assert(funcOp); for (auto op_ : *(funcOp.block())) { @@ -70,14 +70,14 @@ bool SymbolicDimMgr::load() { symbolNameSet_.insert(op.getSymName()); } } 
- return loadShapeConstraintGraph(); + return LoadShapeConstraintGraph(); } -bool SymbolicDimMgr::loadShapeConstraintGraph() { +bool SymbolicDimMgr::LoadShapeConstraintGraph() { // TODO(liujinnan): add more constraint function. currently, only support // tie_product_equal. auto constraint_vec = - symbolTable_.lookup("tie_product_equal"); + symbolTable_.Lookup("tie_product_equal"); if (!constraint_vec.size()) return true; @@ -89,7 +89,7 @@ bool SymbolicDimMgr::loadShapeConstraintGraph() { product.factor *= constOp.value().dyn_cast().data(); continue; } else if (auto dimOp = definingOp->dyn_cast()) { - auto sym = symbolTable_.lookup(dimOp.getName()); + auto sym = symbolTable_.Lookup(dimOp.getName()); if (!sym) return false; product.symbols.push_back(sym); continue; @@ -103,7 +103,7 @@ bool SymbolicDimMgr::loadShapeConstraintGraph() { SymbolicDimProduct lhs, rhs; if (!build_sym_product(op.getLhs(), lhs) || !build_sym_product(op.getRhs(), rhs) || - !mapSymbolicDimProductEqual(lhs, rhs)) + !MapSymbolicDimProductEqual(lhs, rhs)) return false; } return true; @@ -115,24 +115,24 @@ int64_t gcd(int64_t m, int64_t n) { return (m < n) ? gcd(m, n % m) : gcd(m % n, n); } -bool SymbolicDimMgr::mapSymbolicDimProductEqual(const SymbolicDimProduct& lhs, +bool SymbolicDimMgr::MapSymbolicDimProductEqual(const SymbolicDimProduct& lhs, const SymbolicDimProduct& rhs) { SymbolicDimProduct newLhs, newRhs; - std::tie(newLhs, newRhs) = simplifySymbolicDimProductPair(lhs, rhs); + std::tie(newLhs, newRhs) = SimplifySymbolicDimProductPair(lhs, rhs); // early return for identity case. 
if (newLhs == newRhs) return true; if (newLhs.factor == newRhs.factor && newLhs.symbols.size() == 1 && newRhs.symbols.size() == 1) { - return mapSymbolicDimEqual(newLhs.symbols[0], newRhs.symbols[0]); + return MapSymbolicDimEqual(newLhs.symbols[0], newRhs.symbols[0]); } else if (newLhs.symbols.size() == 0 && newRhs.symbols.size() == 1 && newRhs.factor == 1) { - return mapSymbolicDimEqual(newConstantSymbolicDim(newLhs.factor), + return MapSymbolicDimEqual(NewConstantSymbolicDim(newLhs.factor), newRhs.symbols[0]); } else if (newRhs.symbols.size() == 0 && newLhs.symbols.size() == 1 && newLhs.factor == 1) { - return mapSymbolicDimEqual(newConstantSymbolicDim(newRhs.factor), + return MapSymbolicDimEqual(NewConstantSymbolicDim(newRhs.factor), newLhs.symbols[0]); } @@ -144,10 +144,10 @@ bool SymbolicDimMgr::mapSymbolicDimProductEqual(const SymbolicDimProduct& lhs, } std::pair -SymbolicDimMgr::simplifySymbolicDimProductPair(const SymbolicDimProduct& x, +SymbolicDimMgr::SimplifySymbolicDimProductPair(const SymbolicDimProduct& x, const SymbolicDimProduct& y) { - auto lhs = simplifySymbolicDimProduct(x); - auto rhs = simplifySymbolicDimProduct(y); + auto lhs = SimplifySymbolicDimProduct(x); + auto rhs = SimplifySymbolicDimProduct(y); SymbolicDimProduct newLhs, newRhs; int64_t gcdFactor = gcd(std::abs(lhs.factor), std::abs(rhs.factor)); @@ -190,19 +190,19 @@ SymbolicDimMgr::simplifySymbolicDimProductPair(const SymbolicDimProduct& x, return std::make_pair(std::move(newLhs), std::move(newRhs)); } -SymbolicDimProduct SymbolicDimMgr::simplifySymbolicDimProduct( +SymbolicDimProduct SymbolicDimMgr::SimplifySymbolicDimProduct( const SymbolicDimProduct& x) { std::vector copied; copied.reserve(x.symbols.size()); - for (SymbolicDim op : x.symbols) copied.push_back(getRootSymbolicDim(op)); + for (SymbolicDim op : x.symbols) copied.push_back(GetRootSymbolicDim(op)); sort(copied.begin(), copied.end(), [&](SymbolicDim lhs, SymbolicDim rhs) { - return 
compareSymbolicDimNames(lhs.getSymName(), rhs.getSymName()); + return CompareSymbolicDimNames(lhs.getSymName(), rhs.getSymName()); }); SymbolicDimProduct newX; newX.factor = x.factor; for (SymbolicDim op : copied) { - if (!op.isDynamic()) { + if (!op.IsDynamic()) { newX.factor *= op.getValue(); } else { newX.symbols.push_back(op); @@ -211,7 +211,7 @@ SymbolicDimProduct SymbolicDimMgr::simplifySymbolicDimProduct( return newX; } -const std::string SymbolicDimMgr::getNextName() { +const std::string SymbolicDimMgr::GetNextName() { std::string name; do { name = "S" + std::to_string(nextSymbolicIdx_++); @@ -231,13 +231,13 @@ SymbolicDimMgr::SymbolicDimMgr(ModuleOp m) : m_(m) { symbolTable_ = SymbolTable(func); } -SymbolicDim SymbolicDimMgr::newSymbolicDim(const std::string& name) { +SymbolicDim SymbolicDimMgr::NewSymbolicDim(const std::string& name) { auto funcOp = symbolTable_.getOp()->dyn_cast(); assert(funcOp); Builder builder = Builder(m_.ir_context(), funcOp.block()); // default settting dim != 0 dialect::SymbolicDim symbol = - builder.Build(name.empty() ? getNextName() : name, + builder.Build(name.empty() ? 
GetNextName() : name, ShapedTypeInterface::kDynamic, false, false, @@ -248,12 +248,12 @@ SymbolicDim SymbolicDimMgr::newSymbolicDim(const std::string& name) { return symbol; } -SymbolicDim SymbolicDimMgr::newConstantSymbolicDim(int64_t val) { +SymbolicDim SymbolicDimMgr::NewConstantSymbolicDim(int64_t val) { auto it = constantSymbolicDimMap_.find(val); if (it == constantSymbolicDimMap_.end()) { auto name = "C" + std::to_string(val); it = constantSymbolicDimMap_ - .insert(std::make_pair(val, newSymbolicDim(name))) + .insert(std::make_pair(val, NewSymbolicDim(name))) .first; it->second.updateValue(val); if (val == -1) it->second.updateKnownNegativeOne(true); @@ -261,22 +261,22 @@ SymbolicDim SymbolicDimMgr::newConstantSymbolicDim(int64_t val) { if (val != 1) it->second.updateKnownNonSizeOne(true); if (val != 0) it->second.updateKnownNonSizeZero(true); } - return getRootSymbolicDim(it->second); + return GetRootSymbolicDim(it->second); } -std::vector SymbolicDimMgr::createSymbolicDimsForRankedValue( +std::vector SymbolicDimMgr::CreateSymbolicDimsForRankedValue( Value value) { std::vector symbols; auto dims = value.type().dyn_cast().dims(); for (int idx = 0; idx < dims.size(); ++idx) { symbols.push_back(dims[idx] == ShapedTypeInterface::kDynamic - ? newSymbolicDim() - : newConstantSymbolicDim(dims[idx])); + ? 
NewSymbolicDim() + : NewConstantSymbolicDim(dims[idx])); } return symbols; } -SymbolicDim SymbolicDimMgr::getRootSymbolicDim(SymbolicDim symbol) { +SymbolicDim SymbolicDimMgr::GetRootSymbolicDim(SymbolicDim symbol) { SymbolicDim current = symbol; std::vector path; while (symbolDimUnionSet_[current] != current) { @@ -287,32 +287,32 @@ SymbolicDim SymbolicDimMgr::getRootSymbolicDim(SymbolicDim symbol) { return current; } -bool SymbolicDimMgr::isSymbolicDimEqual(SymbolicDim lhs, SymbolicDim rhs) { - SymbolicDim lhsRoot = getRootSymbolicDim(lhs); - SymbolicDim rhsRoot = getRootSymbolicDim(rhs); +bool SymbolicDimMgr::IsSymbolicDimEqual(SymbolicDim lhs, SymbolicDim rhs) { + SymbolicDim lhsRoot = GetRootSymbolicDim(lhs); + SymbolicDim rhsRoot = GetRootSymbolicDim(rhs); return lhsRoot == rhsRoot; } -bool SymbolicDimMgr::mapSymbolicDimEqual(SymbolicDim lhs, SymbolicDim rhs) { - SymbolicDim lhsRoot = getRootSymbolicDim(lhs); - SymbolicDim rhsRoot = getRootSymbolicDim(rhs); +bool SymbolicDimMgr::MapSymbolicDimEqual(SymbolicDim lhs, SymbolicDim rhs) { + SymbolicDim lhsRoot = GetRootSymbolicDim(lhs); + SymbolicDim rhsRoot = GetRootSymbolicDim(rhs); if (lhsRoot != rhsRoot) { - if (compareSymbolicDimNames(lhsRoot.getSymName(), rhsRoot.getSymName())) { - if (!lhsRoot.merge(rhsRoot)) return false; + if (CompareSymbolicDimNames(lhsRoot.getSymName(), rhsRoot.getSymName())) { + if (!lhsRoot.Merge(rhsRoot)) return false; symbolDimUnionSet_[rhsRoot] = lhsRoot; } else { - if (!rhsRoot.merge(lhsRoot)) return false; + if (!rhsRoot.Merge(lhsRoot)) return false; symbolDimUnionSet_[lhsRoot] = rhsRoot; } } return true; } -SymbolicDimProduct* SymbolicDimMgr::symbolicDimProductDivide( +SymbolicDimProduct* SymbolicDimMgr::SymbolicDimProductDivide( const SymbolicDimProduct& lhs, const SymbolicDimProduct& rhs) { SymbolicDimProduct newLhs, newRhs; - std::tie(newLhs, newRhs) = simplifySymbolicDimProductPair(lhs, rhs); + std::tie(newLhs, newRhs) = SimplifySymbolicDimProductPair(lhs, rhs); if 
(newLhs.factor == 0 || newRhs.factor == 0) return nullptr; if (newLhs.factor % newRhs.factor != 0) return nullptr; @@ -340,16 +340,16 @@ SymbolicDimProduct* SymbolicDimMgr::symbolicDimProductDivide( return result; } -bool SymbolicDimMgr::isMultipleOfKnownSymbolicDimProductEqualPair( +bool SymbolicDimMgr::IsMultipleOfKnownSymbolicDimProductEqualPair( const SymbolicDimProduct& lhs, const SymbolicDimProduct& rhs) { for (auto& pairOutter : productEqualityMap_) { const SymbolicDimProduct& x = pairOutter.first; - auto factorX = symbolicDimProductDivide(lhs, x); + auto factorX = SymbolicDimProductDivide(lhs, x); if (!factorX) continue; for (auto& pairInner : pairOutter.second) { if (!pairInner.second) continue; const SymbolicDimProduct& y = pairInner.first; - auto factorY = symbolicDimProductDivide(rhs, y); + auto factorY = SymbolicDimProductDivide(rhs, y); if (!factorY || (*factorX) != (*factorY)) continue; return true; } @@ -358,7 +358,7 @@ bool SymbolicDimMgr::isMultipleOfKnownSymbolicDimProductEqualPair( return false; } -bool SymbolicDimMgr::updateProductEqualityMap() { +bool SymbolicDimMgr::UpdateProductEqualityMap() { // early return if nothing is updated. 
if (productEqualityMapUpdated_) return true; @@ -370,7 +370,7 @@ bool SymbolicDimMgr::updateProductEqualityMap() { if (!pairInner.second) continue; const SymbolicDimProduct& y = pairInner.first; SymbolicDimProduct newX, newY; - std::tie(newX, newY) = simplifySymbolicDimProductPair(x, y); + std::tie(newX, newY) = SimplifySymbolicDimProductPair(x, y); if (newX == newY) continue; newMap[newX][newY] = newMap[newY][newX] = true; productSet.insert(newX); @@ -425,7 +425,7 @@ bool SymbolicDimMgr::updateProductEqualityMap() { for (auto& y : productSet) { if (!productEqualityMap_[x][y]) continue; productEqualityMap_[x][y] = productEqualityMap_[y][x] = false; - if (!isMultipleOfKnownSymbolicDimProductEqualPair(x, y)) { + if (!IsMultipleOfKnownSymbolicDimProductEqualPair(x, y)) { productEqualityMap_[x][y] = productEqualityMap_[y][x] = true; } } @@ -449,25 +449,25 @@ bool SymbolicDimMgr::updateProductEqualityMap() { return true; } -bool SymbolicDimMgr::isSymbolicDimProductEqual(const SymbolicDimProduct& lhs, +bool SymbolicDimMgr::IsSymbolicDimProductEqual(const SymbolicDimProduct& lhs, const SymbolicDimProduct& rhs) { SymbolicDimProduct newLhs, newRhs; - std::tie(newLhs, newRhs) = simplifySymbolicDimProductPair(lhs, rhs); + std::tie(newLhs, newRhs) = SimplifySymbolicDimProductPair(lhs, rhs); // early return for identity case. 
if (newLhs == newRhs) return true; - IR_ENFORCE(updateProductEqualityMap(), "Update product equality map failed."); - return isMultipleOfKnownSymbolicDimProductEqualPair(newLhs, newRhs); + IR_ENFORCE(UpdateProductEqualityMap(), "Update product equality map failed."); + return IsMultipleOfKnownSymbolicDimProductEqualPair(newLhs, newRhs); } -bool SymbolicDimMgr::save() { +bool SymbolicDimMgr::Save() { using Name2SymbolFn = std::function; auto updateAttrs = [&](ArrayAttribute attrs, Name2SymbolFn fn) { std::vector newAttrs; for (Attribute attr : attrs.AsVector()) { auto sym = fn(attr.dyn_cast().AsString()); assert(sym); - SymbolicDim root = getRootSymbolicDim(sym); + SymbolicDim root = GetRootSymbolicDim(sym); Attribute rootSymbol = StrAttribute::get(m_->ir_context(), root.getSymName()); newAttrs.push_back(rootSymbol); @@ -481,11 +481,11 @@ bool SymbolicDimMgr::save() { auto attrs = op->attribute(SymbolicDim::getSymbolicDimAttrName()); auto symbolicShapeAttr = updateAttrs(attrs, [&](const std::string& name) { - return symbolTable_.lookup(name); + return symbolTable_.Lookup(name); }); op->set_attribute(SymbolicDim::getSymbolicDimAttrName(), symbolicShapeAttr); } - if (!updateProductEqualityMap()) { + if (!UpdateProductEqualityMap()) { return false; } std::unordered_set usedSymbolicOps; @@ -493,7 +493,7 @@ bool SymbolicDimMgr::save() { // TODO(liujinnan): collect uses in value. 
auto collectUsedSymbols = [&](ArrayAttribute attrs) { for (Attribute attr : attrs.AsVector()) { - auto sym = symbolTable_.lookup( + auto sym = symbolTable_.Lookup( attr.dyn_cast().AsString()); assert(sym); if (usedSymbolicOps.insert(sym).second) @@ -539,7 +539,7 @@ bool SymbolicDimMgr::save() { std::sort(usedSymbolNames.begin(), usedSymbolNames.end(), [&](const std::string& lhs, const std::string& rhs) { - return compareSymbolicDimNames(lhs, rhs); + return CompareSymbolicDimNames(lhs, rhs); }); int numNonConstDims = 0; std::unordered_map nameMapping; @@ -569,10 +569,10 @@ bool SymbolicDimMgr::save() { // TODO(liujinnan): update attributes attached to values. - return saveShapeConstraintGraph(); + return SaveShapeConstraintGraph(); } -bool SymbolicDimMgr::saveShapeConstraintGraph() { +bool SymbolicDimMgr::SaveShapeConstraintGraph() { auto funcOp = symbolTable_.getOp()->dyn_cast(); assert(funcOp); auto op_it = funcOp.block()->rbegin(); @@ -605,10 +605,10 @@ bool SymbolicDimMgr::saveShapeConstraintGraph() { for (auto& p : productEqualityMap_) sortedProductVec.push_back(p.first); std::sort(sortedProductVec.begin(), sortedProductVec.end(), - compareSymbolicDimProduct); + CompareSymbolicDimProduct); for (auto& x : sortedProductVec) { for (auto& y : sortedProductVec) { - if (!compareSymbolicDimProduct(x, y)) continue; + if (!CompareSymbolicDimProduct(x, y)) continue; if (!productEqualityMap_[x][y]) continue; auto lhsOperands = build_operands(x); auto rhsOperands = build_operands(y); @@ -618,17 +618,17 @@ bool SymbolicDimMgr::saveShapeConstraintGraph() { return true; } -bool ShapeAnalysis::isSameNumElements(Value lhs, Value rhs) { +bool ShapeAnalysis::IsSameNumElements(Value lhs, Value rhs) { if (lhs == rhs) return true; auto lhsTy = lhs.type().dyn_cast_interface(); auto rhsTy = rhs.type().dyn_cast_interface(); if (!lhsTy || !rhsTy || !lhsTy.hasRank() || !rhsTy.hasRank()) return false; - return isProductEqual(lhs, 0, lhsTy.getRank(), rhs, 0, rhsTy.getRank()); + return 
IsProductEqual(lhs, 0, lhsTy.getRank(), rhs, 0, rhsTy.getRank()); } -bool ShapeAnalysis::isProductEqual( +bool ShapeAnalysis::IsProductEqual( Value lhs, int lhsFrom, int lhsTo, Value rhs, int rhsFrom, int rhsTo) { std::vector lhsDimIdxs, rhsDimIdxs; lhsDimIdxs.reserve(lhsTo - lhsFrom); @@ -636,12 +636,12 @@ bool ShapeAnalysis::isProductEqual( for (int i = lhsFrom; i < lhsTo; ++i) lhsDimIdxs.push_back(i); for (int i = rhsFrom; i < rhsTo; ++i) rhsDimIdxs.push_back(i); - return isProductEqual(lhs, lhsDimIdxs, rhs, rhsDimIdxs); + return IsProductEqual(lhs, lhsDimIdxs, rhs, rhsDimIdxs); } SymbolicDimShapeAnalysis::SymbolicDimShapeAnalysis(ModuleOp m) : m_(m), mgr_(m) { - mgr_.load(); + mgr_.Load(); for (auto op : *(m_.block())) { auto tieShapeOp = op->dyn_cast(); if (!tieShapeOp) continue; @@ -652,7 +652,7 @@ SymbolicDimShapeAnalysis::SymbolicDimShapeAnalysis(ModuleOp m) .attribute(SymbolicDim::getSymbolicDimAttrName()) .AsVector(); for (const auto& attr : attrs) { - auto symOp = mgr_.symbolTable().lookup( + auto symOp = mgr_.symbolTable().Lookup( attr.dyn_cast().AsString()); if (!symOp) continue; symbols.push_back(symOp); @@ -660,9 +660,9 @@ SymbolicDimShapeAnalysis::SymbolicDimShapeAnalysis(ModuleOp m) } } -SymbolicDimShapeAnalysis::~SymbolicDimShapeAnalysis() { mgr_.save(); } +SymbolicDimShapeAnalysis::~SymbolicDimShapeAnalysis() { mgr_.Save(); } -bool SymbolicDimShapeAnalysis::isShapeEqual(Value lhs, Value rhs) { +bool SymbolicDimShapeAnalysis::IsShapeEqual(Value lhs, Value rhs) { if (lhs == rhs) return true; auto lhsTy = lhs.type().dyn_cast_interface(); @@ -684,15 +684,15 @@ bool SymbolicDimShapeAnalysis::isShapeEqual(Value lhs, Value rhs) { std::vector lhsSyms; std::vector rhsSyms; for (auto sym : lhsIt->second) { - lhsSyms.push_back(mgr_.getRootSymbolicDim(sym)); + lhsSyms.push_back(mgr_.GetRootSymbolicDim(sym)); } for (auto sym : rhsIt->second) { - rhsSyms.push_back(mgr_.getRootSymbolicDim(sym)); + rhsSyms.push_back(mgr_.GetRootSymbolicDim(sym)); } return 
lhsSyms == rhsSyms; } -bool SymbolicDimShapeAnalysis::isProductEqual(Value lhs, +bool SymbolicDimShapeAnalysis::IsProductEqual(Value lhs, std::vector lhsDimIdxs, Value rhs, std::vector rhsDimIdxs) { @@ -722,6 +722,6 @@ bool SymbolicDimShapeAnalysis::isProductEqual(Value lhs, return false; } - return mgr_.isSymbolicDimProductEqual(lhsProd, rhsProd); + return mgr_.IsSymbolicDimProductEqual(lhsProd, rhsProd); } } // namespace pir diff --git a/paddle/pir/dialect/shape/utils/shape_utils.h b/paddle/pir/dialect/shape/utils/shape_utils.h index f00bff91a5517..ab728fc6a00a0 100644 --- a/paddle/pir/dialect/shape/utils/shape_utils.h +++ b/paddle/pir/dialect/shape/utils/shape_utils.h @@ -53,7 +53,7 @@ class SymbolTable { template typename std::enable_if::value, SymbolicDim>::type - lookup(const std::string& name) const { + Lookup(const std::string& name) const { auto it = symbolTableMap_.find(name); return it != symbolTableMap_.end() ? it->second->dyn_cast() : SymbolicDim(nullptr); @@ -61,7 +61,7 @@ class SymbolTable { template typename std::enable_if::value, std::vector>::type - lookup(const std::string& name) const { + Lookup(const std::string& name) const { std::vector res; auto it = symbolFuncMap_.find(name); if (it != symbolFuncMap_.end()) { @@ -101,35 +101,35 @@ struct SymProductHasher { class SymbolicDimMgr { public: explicit SymbolicDimMgr(ModuleOp m); - bool load(); - SymbolicDim newSymbolicDim(const std::string& name = {}); - SymbolicDim newConstantSymbolicDim(int64_t val); - std::vector createSymbolicDimsForRankedValue(Value value); - SymbolicDim getRootSymbolicDim(SymbolicDim symbol); - bool isSymbolicDimEqual(SymbolicDim lhs, SymbolicDim rhs); + bool Load(); + SymbolicDim NewSymbolicDim(const std::string& name = {}); + SymbolicDim NewConstantSymbolicDim(int64_t val); + std::vector CreateSymbolicDimsForRankedValue(Value value); + SymbolicDim GetRootSymbolicDim(SymbolicDim symbol); + bool IsSymbolicDimEqual(SymbolicDim lhs, SymbolicDim rhs); SymbolTable& 
symbolTable() { return symbolTable_; } - bool mapSymbolicDimEqual(SymbolicDim lhs, SymbolicDim rhs); - SymbolicDimProduct simplifySymbolicDimProduct(const SymbolicDimProduct& x); + bool MapSymbolicDimEqual(SymbolicDim lhs, SymbolicDim rhs); + SymbolicDimProduct SimplifySymbolicDimProduct(const SymbolicDimProduct& x); std::pair - simplifySymbolicDimProductPair(const SymbolicDimProduct& x, + SimplifySymbolicDimProductPair(const SymbolicDimProduct& x, const SymbolicDimProduct& y); - SymbolicDimProduct* symbolicDimProductDivide(const SymbolicDimProduct& x, + SymbolicDimProduct* SymbolicDimProductDivide(const SymbolicDimProduct& x, const SymbolicDimProduct& y); - bool save(); + bool Save(); - bool isSymbolicDimProductEqual(const SymbolicDimProduct& lhs, + bool IsSymbolicDimProductEqual(const SymbolicDimProduct& lhs, const SymbolicDimProduct& rhs); - bool mapSymbolicDimProductEqual(const SymbolicDimProduct& lhs, + bool MapSymbolicDimProductEqual(const SymbolicDimProduct& lhs, const SymbolicDimProduct& rhs); private: - const std::string getNextName(); - bool updateProductEqualityMap(); - bool isMultipleOfKnownSymbolicDimProductEqualPair( + const std::string GetNextName(); + bool UpdateProductEqualityMap(); + bool IsMultipleOfKnownSymbolicDimProductEqualPair( const SymbolicDimProduct& lhs, const SymbolicDimProduct& rhs); - bool saveShapeConstraintGraph(); - bool loadShapeConstraintGraph(); + bool SaveShapeConstraintGraph(); + bool LoadShapeConstraintGraph(); private: ModuleOp m_; @@ -157,15 +157,15 @@ class ShapeAnalysis { public: virtual ~ShapeAnalysis() = default; - virtual bool isShapeEqual(Value lhs, Value rhs) = 0; + virtual bool IsShapeEqual(Value lhs, Value rhs) = 0; - virtual bool isProductEqual(Value lhs, + virtual bool IsProductEqual(Value lhs, std::vector lhsDimIdxs, Value rhs, std::vector rhsDimIdxs) = 0; - virtual bool isProductEqual( + virtual bool IsProductEqual( Value lhs, int lhsFrom, int lhsTo, Value rhs, int rhsFrom, int rhsTo); - virtual bool 
isSameNumElements(Value lhs, Value rhs); + virtual bool IsSameNumElements(Value lhs, Value rhs); }; class SymbolicDimShapeAnalysis : public ShapeAnalysis { @@ -175,9 +175,9 @@ class SymbolicDimShapeAnalysis : public ShapeAnalysis { SymbolicDimMgr& symbolicDimMgr() { return mgr_; } const SymbolicDimMgr& symbolicDimMgr() const { return mgr_; } - bool isShapeEqual(Value lhs, Value rhs) override; + bool IsShapeEqual(Value lhs, Value rhs) override; - bool isProductEqual(Value lhs, + bool IsProductEqual(Value lhs, std::vector lhsDimIdxs, Value rhs, std::vector rhsDimIdxs) override; diff --git a/test/cpp/pir/shape_dialect/CMakeLists.txt b/test/cpp/pir/shape_dialect/CMakeLists.txt index 73c635713f99d..d5fe787de4a80 100644 --- a/test/cpp/pir/shape_dialect/CMakeLists.txt +++ b/test/cpp/pir/shape_dialect/CMakeLists.txt @@ -6,3 +6,16 @@ cc_test_old( pd_op_dialect pir gtest) + +cc_test_old( + constraint_pass_test + SRCS + constraint_pass_test.cc + DEPS + gtest + pd_op_dialect + pir) + +set_tests_properties( + constraint_pass_test PROPERTIES ENVIRONMENT + "FLAGS_enable_new_ir_in_executor=true") diff --git a/test/cpp/pir/shape_dialect/constraint_pass_test.cc b/test/cpp/pir/shape_dialect/constraint_pass_test.cc new file mode 100644 index 0000000000000..c99d7493af09a --- /dev/null +++ b/test/cpp/pir/shape_dialect/constraint_pass_test.cc @@ -0,0 +1,96 @@ +// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +#include +#include +#include +#include +#include +#include +#include + +#include "paddle/fluid/pir/dialect/operator/ir/op_dialect.h" +#include "paddle/fluid/pir/dialect/operator/ir/op_type.h" +#include "paddle/phi/core/kernel_registry.h" +#include "paddle/pir/core/builder.h" +#include "paddle/pir/core/builtin_attribute.h" +#include "paddle/pir/core/builtin_dialect.h" +#include "paddle/pir/core/builtin_op.h" +#include "paddle/pir/core/builtin_type.h" +#include "paddle/pir/core/cast_utils.h" +#include "paddle/pir/core/dialect.h" +#include "paddle/pir/core/enforce.h" +#include "paddle/pir/core/ir_context.h" +#include "paddle/pir/core/op_info.h" +#include "paddle/pir/core/parameter.h" +#include "paddle/pir/core/program.h" +#include "paddle/pir/core/value.h" +#include "paddle/pir/dialect/shape/ir/shape_dialect.h" +#include "paddle/pir/dialect/shape/transforms/shape_optimization_pass.h" +#include "paddle/pir/pass/pass.h" +#include "paddle/pir/pass/pass_manager.h" + +pir::AttributeMap CreateAttributeMap( + const std::vector &attribute_names, + const std::vector &attributes) { + pir::IrContext *ctx = pir::IrContext::Instance(); + pir::AttributeMap attr_map; + for (size_t i = 0; i < attribute_names.size(); i++) { + pir::Attribute attr_value = pir::StrAttribute::get(ctx, attributes[i]); + attr_map.insert( + std::pair(attribute_names[i], attr_value)); + } + return attr_map; +} + +pir::Operation *CreateDenseTensorOp( + pir::IrContext *ctx, + const phi::DDim &dims, + const std::vector &attribute_names, + const std::vector &attributes) { + std::vector op_inputs = {}; + pir::Type fp32_dtype = pir::Float32Type::get(ctx); + phi::DataLayout data_layout = phi::DataLayout::NCHW; + phi::LoD lod = {{0, 1, 2}}; + size_t offset = 0; + std::vector op_output_types = { + paddle::dialect::DenseTensorType::get( + ctx, fp32_dtype, dims, data_layout, lod, offset)}; + pir::Operation *op = + pir::Operation::Create(op_inputs, + CreateAttributeMap(attribute_names, attributes), + op_output_types, 
+ pir::OpInfo()); + return op; +} + +TEST(constraint_pass, materialize_shape) { + pir::IrContext *ctx = pir::IrContext::Instance(); + pir::Program program(ctx); + pir::PassManager pm(ctx); + ctx->GetOrRegisterDialect(); + ctx->GetOrRegisterDialect(); + pir::Operation *op0 = + CreateDenseTensorOp(ctx, {-100000, 2}, {"op0_attr"}, {"op0_name"}); + program.block()->push_back(op0); + pir::Operation *op1 = + CreateDenseTensorOp(ctx, {-100000, 2, 2}, {"op1_attr"}, {"op1_name"}); + program.block()->push_back(op1); + + EXPECT_EQ(program.block()->size(), static_cast(2)); + pm.AddPass(pir::CreateShapeOptimizationPass()); + EXPECT_TRUE(pm.Run(&program)); + // 5 ConstantOp + 5 TensorDim + 2 TieShape + op0 + op1 == 14 Ops. + EXPECT_EQ(program.block()->size(), static_cast(14)); +} diff --git a/test/cpp/pir/shape_dialect/symbolic_op_test.cc b/test/cpp/pir/shape_dialect/symbolic_op_test.cc index 87f4623f811ce..f680982b2d146 100644 --- a/test/cpp/pir/shape_dialect/symbolic_op_test.cc +++ b/test/cpp/pir/shape_dialect/symbolic_op_test.cc @@ -77,8 +77,8 @@ TEST(assist_struct_test, symbolic_dim) { EXPECT_FALSE(symDim.getKnownNonSizeZero()); EXPECT_FALSE(symDim.getKnownNonNegative()); - EXPECT_FALSE(symDim.isDynamic()); - EXPECT_TRUE(symDim.merge(symDim_)); + EXPECT_FALSE(symDim.IsDynamic()); + EXPECT_TRUE(symDim.Merge(symDim_)); symDim.updateValue(20); symDim.updateSymName("S2"); @@ -87,7 +87,7 @@ TEST(assist_struct_test, symbolic_dim) { symDim.updateKnownNonSizeZero(true); symDim.updateKnownNonNegative(true); - EXPECT_FALSE(symDim.merge(symDim_)); + EXPECT_FALSE(symDim.Merge(symDim_)); EXPECT_EQ(symDim.getValue(), 20); EXPECT_EQ(symDim.getSymName(), "S2"); @@ -123,9 +123,9 @@ TEST(assist_struct_test, symbolic_dim_table) { pir::SymbolTable symbolTable(program.module_op()); EXPECT_EQ(symbolTable.insert(symDim), "S0"); - EXPECT_EQ(symbolTable.lookup("S0"), symDim); + EXPECT_EQ(symbolTable.Lookup("S0"), symDim); EXPECT_EQ(symbolTable.getOp(), program.module_op()); - 
EXPECT_FALSE(symbolTable.lookup("S1")); + EXPECT_FALSE(symbolTable.Lookup("S1")); } TEST(assist_struct_test, symbolic_dim_mgr_simple) { @@ -138,17 +138,17 @@ TEST(assist_struct_test, symbolic_dim_mgr_simple) { ctx->GetOrRegisterDialect(); pir::SymbolicDimMgr symDimMgr(program.module_op()); - pir::dialect::SymbolicDim symDimS0 = symDimMgr.newSymbolicDim(); - pir::dialect::SymbolicDim symDimS1 = symDimMgr.newSymbolicDim(); - pir::dialect::SymbolicDim symDimC10 = symDimMgr.newConstantSymbolicDim(10); - symDimMgr.mapSymbolicDimEqual(symDimS0, symDimS1); + pir::dialect::SymbolicDim symDimS0 = symDimMgr.NewSymbolicDim(); + pir::dialect::SymbolicDim symDimS1 = symDimMgr.NewSymbolicDim(); + pir::dialect::SymbolicDim symDimC10 = symDimMgr.NewConstantSymbolicDim(10); + symDimMgr.MapSymbolicDimEqual(symDimS0, symDimS1); auto op = CreateDenseTensorOp( ctx, {pir::ShapedTypeInterface::kDynamic, 2}, {"op_attr"}, {"op_name"}); pir::Value res = op->result(0); std::vector symDimVec = - symDimMgr.createSymbolicDimsForRankedValue(res); + symDimMgr.CreateSymbolicDimsForRankedValue(res); EXPECT_EQ(symDimS0.getSymName(), "S0"); EXPECT_EQ(symDimS1.getSymName(), "S1"); @@ -157,13 +157,13 @@ TEST(assist_struct_test, symbolic_dim_mgr_simple) { EXPECT_EQ(symDimC10.getValue(), 10); EXPECT_EQ(symDimVec[0].getSymName(), "S2"); EXPECT_EQ(symDimVec[1].getSymName(), "C2"); - EXPECT_EQ(symDimMgr.symbolTable().lookup("S0"), + EXPECT_EQ(symDimMgr.symbolTable().Lookup("S0"), symDimS0); - EXPECT_EQ(symDimMgr.symbolTable().lookup("C10"), + EXPECT_EQ(symDimMgr.symbolTable().Lookup("C10"), symDimC10); - EXPECT_EQ(symDimMgr.getRootSymbolicDim(symDimS1), symDimS0); - EXPECT_TRUE(symDimMgr.isSymbolicDimEqual(symDimS0, symDimS1)); - EXPECT_FALSE(symDimMgr.isSymbolicDimEqual(symDimS0, symDimC10)); + EXPECT_EQ(symDimMgr.GetRootSymbolicDim(symDimS1), symDimS0); + EXPECT_TRUE(symDimMgr.IsSymbolicDimEqual(symDimS0, symDimS1)); + EXPECT_FALSE(symDimMgr.IsSymbolicDimEqual(symDimS0, symDimC10)); } 
TEST(assist_struct_test, symbolic_dim_mgr_complex) { @@ -181,21 +181,21 @@ TEST(assist_struct_test, symbolic_dim_mgr_complex) { pir::Builder builder = pir::Builder(ctx, funcOp.block()); - pir::dialect::SymbolicDim symDimS0 = symDimMgr.newSymbolicDim("S0"); - pir::dialect::SymbolicDim symDimS1 = symDimMgr.newSymbolicDim("S1"); - pir::dialect::SymbolicDim symDimS2 = symDimMgr.newSymbolicDim("S2"); - pir::dialect::SymbolicDim symDimS3 = symDimMgr.newSymbolicDim("S3"); - pir::dialect::SymbolicDim symDimS4 = symDimMgr.newSymbolicDim("S4"); - pir::dialect::SymbolicDim symDimS5 = symDimMgr.newSymbolicDim("S5"); - pir::dialect::SymbolicDim symDimS6 = symDimMgr.newSymbolicDim("S6"); - pir::dialect::SymbolicDim symDimS7 = symDimMgr.newSymbolicDim("S7"); - pir::dialect::SymbolicDim symDimS8 = symDimMgr.newSymbolicDim("S8"); - pir::dialect::SymbolicDim symDimS9 = symDimMgr.newSymbolicDim("S9"); - pir::dialect::SymbolicDim symDimS10 = symDimMgr.newSymbolicDim("S10"); - pir::dialect::SymbolicDim symDimS11 = symDimMgr.newSymbolicDim("S11"); - pir::dialect::SymbolicDim symDimS12 = symDimMgr.newSymbolicDim("S12"); - pir::dialect::SymbolicDim symDimC10 = symDimMgr.newConstantSymbolicDim(10); - pir::dialect::SymbolicDim symDimC20 = symDimMgr.newConstantSymbolicDim(20); + pir::dialect::SymbolicDim symDimS0 = symDimMgr.NewSymbolicDim("S0"); + pir::dialect::SymbolicDim symDimS1 = symDimMgr.NewSymbolicDim("S1"); + pir::dialect::SymbolicDim symDimS2 = symDimMgr.NewSymbolicDim("S2"); + pir::dialect::SymbolicDim symDimS3 = symDimMgr.NewSymbolicDim("S3"); + pir::dialect::SymbolicDim symDimS4 = symDimMgr.NewSymbolicDim("S4"); + pir::dialect::SymbolicDim symDimS5 = symDimMgr.NewSymbolicDim("S5"); + pir::dialect::SymbolicDim symDimS6 = symDimMgr.NewSymbolicDim("S6"); + pir::dialect::SymbolicDim symDimS7 = symDimMgr.NewSymbolicDim("S7"); + pir::dialect::SymbolicDim symDimS8 = symDimMgr.NewSymbolicDim("S8"); + pir::dialect::SymbolicDim symDimS9 = symDimMgr.NewSymbolicDim("S9"); + 
pir::dialect::SymbolicDim symDimS10 = symDimMgr.NewSymbolicDim("S10"); + pir::dialect::SymbolicDim symDimS11 = symDimMgr.NewSymbolicDim("S11"); + pir::dialect::SymbolicDim symDimS12 = symDimMgr.NewSymbolicDim("S12"); + pir::dialect::SymbolicDim symDimC10 = symDimMgr.NewConstantSymbolicDim(10); + pir::dialect::SymbolicDim symDimC20 = symDimMgr.NewConstantSymbolicDim(20); pir::OpResult dimOpS0 = builder.Build("S0").out(); pir::OpResult dimOpS1 = builder.Build("S1").out(); @@ -301,7 +301,7 @@ TEST(assist_struct_test, symbolic_dim_mgr_complex) { tieShapeOp_->set_attribute( pir::dialect::SymbolicDim::getSymbolicDimAttrName(), arrayAttr_); - EXPECT_TRUE(symDimMgr.load()); + EXPECT_TRUE(symDimMgr.Load()); // For check indirect equality: S1 * S4 == S2 * S5 pir::SymbolicDimProduct symDimProductLhs; @@ -325,14 +325,14 @@ TEST(assist_struct_test, symbolic_dim_mgr_complex) { symDimProductRhs_.symbols.push_back(symDimS11); symDimProductRhs_.symbols.push_back(symDimS12); - // For check simplifySymbolicDimProduct, {factor = 1, Sym = {S7}} => {factor = + // For check SimplifySymbolicDimProduct, {factor = 1, Sym = {S7}} => {factor = // 10} pir::SymbolicDimProduct symDimProductS7; symDimProductS7.symbols.push_back(symDimS7); pir::SymbolicDimProduct simplifiedProductS7 = - symDimMgr.simplifySymbolicDimProduct(symDimProductS7); + symDimMgr.SimplifySymbolicDimProduct(symDimProductS7); - // For check simplifySymbolicDimProductPair, X * Y * Y, Y * Y * Z => X, Z + // For check SimplifySymbolicDimProductPair, X * Y * Y, Y * Y * Z => X, Z pir::SymbolicDimProduct symDimProductPairLhs; pir::SymbolicDimProduct symDimProductPairRhs; pir::SymbolicDimProduct newLhs, newRhs; @@ -343,10 +343,10 @@ TEST(assist_struct_test, symbolic_dim_mgr_complex) { symDimProductPairRhs.symbols.push_back(symDimS2); symDimProductPairRhs.symbols.push_back(symDimS3); - std::tie(newLhs, newRhs) = symDimMgr.simplifySymbolicDimProductPair( + std::tie(newLhs, newRhs) = symDimMgr.SimplifySymbolicDimProductPair( 
symDimProductPairLhs, symDimProductPairRhs); - // For check symbolicDimProductDivide, {S4 * S1 * C20} / {S1 * C10} => {factor + // For check SymbolicDimProductDivide, {S4 * S1 * C20} / {S1 * C10} => {factor // = 2 Sym = {S4}} pir::SymbolicDimProduct symDimProductDivLhs; pir::SymbolicDimProduct symDimProductDivRhs; @@ -356,39 +356,39 @@ TEST(assist_struct_test, symbolic_dim_mgr_complex) { symDimProductDivRhs.symbols.push_back(symDimS1); symDimProductDivRhs.symbols.push_back(symDimC10); - pir::SymbolicDimProduct *divRes = symDimMgr.symbolicDimProductDivide( + pir::SymbolicDimProduct *divRes = symDimMgr.SymbolicDimProductDivide( symDimProductDivLhs, symDimProductDivRhs); - EXPECT_TRUE(symDimMgr.isSymbolicDimEqual(symDimS1, symDimS2)); - EXPECT_TRUE(symDimMgr.isSymbolicDimEqual(symDimS0, symDimS3)); - EXPECT_TRUE(symDimMgr.isSymbolicDimEqual(symDimS4, symDimS5)); + EXPECT_TRUE(symDimMgr.IsSymbolicDimEqual(symDimS1, symDimS2)); + EXPECT_TRUE(symDimMgr.IsSymbolicDimEqual(symDimS0, symDimS3)); + EXPECT_TRUE(symDimMgr.IsSymbolicDimEqual(symDimS4, symDimS5)); EXPECT_EQ(symDimS6.getValue(), 200); - EXPECT_EQ(symDimMgr.symbolTable().lookup("C20"), + EXPECT_EQ(symDimMgr.symbolTable().Lookup("C20"), symDimC20); EXPECT_EQ(symDimS7.getValue(), symDimC10.getValue()); EXPECT_EQ(simplifiedProductS7.factor, 10); EXPECT_EQ(simplifiedProductS7.symbols.size(), static_cast(0)); EXPECT_EQ(newLhs.symbols.size(), static_cast(1)); EXPECT_EQ(newRhs.symbols.size(), static_cast(1)); - EXPECT_EQ(newLhs.symbols[0], symDimMgr.getRootSymbolicDim(symDimS4)); - EXPECT_EQ(newRhs.symbols[0], symDimMgr.getRootSymbolicDim(symDimS3)); + EXPECT_EQ(newLhs.symbols[0], symDimMgr.GetRootSymbolicDim(symDimS4)); + EXPECT_EQ(newRhs.symbols[0], symDimMgr.GetRootSymbolicDim(symDimS3)); EXPECT_EQ(divRes->factor, 2); EXPECT_EQ(divRes->symbols.size(), static_cast(1)); - EXPECT_EQ(divRes->symbols[0], symDimMgr.getRootSymbolicDim(symDimS4)); + EXPECT_EQ(divRes->symbols[0], symDimMgr.GetRootSymbolicDim(symDimS4)); 
EXPECT_TRUE( - symDimMgr.isSymbolicDimProductEqual(symDimProductLhs, symDimProductRhs)); - EXPECT_TRUE(symDimMgr.isSymbolicDimProductEqual(symDimProductLhs_, + symDimMgr.IsSymbolicDimProductEqual(symDimProductLhs, symDimProductRhs)); + EXPECT_TRUE(symDimMgr.IsSymbolicDimProductEqual(symDimProductLhs_, symDimProductRhs_)); - EXPECT_TRUE(symDimMgr.save()); + EXPECT_TRUE(symDimMgr.Save()); pir::SymbolicDimMgr symDimMgr_(program.module_op()); - EXPECT_TRUE(symDimMgr_.load()); + EXPECT_TRUE(symDimMgr_.Load()); auto attrs = tieShapeOp.attribute( pir::dialect::SymbolicDim::getSymbolicDimAttrName()); EXPECT_FALSE( - symDimMgr_.symbolTable().lookup("S7")); + symDimMgr_.symbolTable().Lookup("S7")); EXPECT_EQ(symDimMgr_.symbolTable() - .lookup("tie_product_equal") + .Lookup("tie_product_equal") .size(), static_cast(1)); @@ -437,10 +437,10 @@ TEST(shape_op, tie_product_equal) { EXPECT_EQ(symbolTable.insert(tie_product_equal), "tie_product_equal"); EXPECT_EQ( - symbolTable.lookup("tie_product_equal") + symbolTable.Lookup("tie_product_equal") .size(), static_cast(1)); - EXPECT_EQ(symbolTable.lookup( + EXPECT_EQ(symbolTable.Lookup( "tie_product_equal")[0], tie_product_equal); EXPECT_EQ(lhs, lhs_ref); @@ -572,17 +572,48 @@ TEST(assist_struct_test, shape_analysis) { pir::dialect::SymbolicDim::getSymbolicDimAttrName(), attrOp5); pir::SymbolicDimShapeAnalysis shapeAnalysis(program.module_op()); - EXPECT_TRUE(shapeAnalysis.isShapeEqual(value3, value4)); - EXPECT_FALSE(shapeAnalysis.isShapeEqual(value1, value2)); - EXPECT_FALSE(shapeAnalysis.isShapeEqual(value1, value3)); - EXPECT_FALSE(shapeAnalysis.isShapeEqual(value1, value5)); - EXPECT_FALSE(shapeAnalysis.isShapeEqual(value3, value5)); - EXPECT_TRUE(shapeAnalysis.isProductEqual(value1, {1}, value3, {0})); - EXPECT_TRUE(shapeAnalysis.isSameNumElements(value4, value3)); - - shapeAnalysis.symbolicDimMgr().mapSymbolicDimEqual(symDimS0, symDimS1); - shapeAnalysis.symbolicDimMgr().mapSymbolicDimEqual(symDimS0, symDimS2); - - 
EXPECT_TRUE(shapeAnalysis.isShapeEqual(value1, value2)); - EXPECT_FALSE(shapeAnalysis.isShapeEqual(value1, value5)); + EXPECT_TRUE(shapeAnalysis.IsShapeEqual(value3, value4)); + EXPECT_FALSE(shapeAnalysis.IsShapeEqual(value1, value2)); + EXPECT_FALSE(shapeAnalysis.IsShapeEqual(value1, value3)); + EXPECT_FALSE(shapeAnalysis.IsShapeEqual(value1, value5)); + EXPECT_FALSE(shapeAnalysis.IsShapeEqual(value3, value5)); + EXPECT_TRUE(shapeAnalysis.IsProductEqual(value1, {1}, value3, {0})); + EXPECT_TRUE(shapeAnalysis.IsSameNumElements(value4, value3)); + + shapeAnalysis.symbolicDimMgr().MapSymbolicDimEqual(symDimS0, symDimS1); + shapeAnalysis.symbolicDimMgr().MapSymbolicDimEqual(symDimS0, symDimS2); + + EXPECT_TRUE(shapeAnalysis.IsShapeEqual(value1, value2)); + EXPECT_FALSE(shapeAnalysis.IsShapeEqual(value1, value5)); +} + +TEST(shape_op, tensor_dim) { + pir::IrContext *ctx = pir::IrContext::Instance(); + pir::Program program(ctx); + ctx->GetOrRegisterDialect(); + pir::Builder builder = pir::Builder(ctx, program.block()); + + pir::Operation *op = CreateDenseTensorOp( + ctx, {pir::ShapedTypeInterface::kDynamic, 2}, {"op_attr"}, {"op_name"}); + pir::OpResult resDenseTensorValue = op->result(0); + + pir::dialect::TensorDimOp tensorDimOp0 = + builder.Build(resDenseTensorValue, 0); + pir::OpResult res0 = tensorDimOp0.out(); + + pir::OpResult indexValue = + builder + .Build( + pir::Int64Attribute::get(pir::IrContext::Instance(), 1), + pir::IndexType::get(pir::IrContext::Instance())) + ->result(0); + pir::dialect::TensorDimOp tensorDimOp1 = + builder.Build(resDenseTensorValue, indexValue); + pir::OpResult res1 = tensorDimOp1.out(); + + EXPECT_EQ(res0.type(), pir::IndexType::get(ctx)); + EXPECT_EQ(res1.type(), pir::IndexType::get(ctx)); + EXPECT_EQ(tensorDimOp0.getSource(), resDenseTensorValue); + EXPECT_EQ(tensorDimOp1.getSource(), resDenseTensorValue); + EXPECT_EQ(tensorDimOp1.getIndex(), indexValue); }