
Commit

Add ShapeOptimizationPass for SymbolicDim constraint construction. (#57069)

* Add a constraint pass for SymbolicDim construction.

* Add UT.
liuruyan authored Sep 15, 2023
1 parent 2ad4d49 commit d499021
Showing 10 changed files with 531 additions and 174 deletions.
7 changes: 6 additions & 1 deletion paddle/pir/dialect/shape/ir/shape_dialect.cc
@@ -23,7 +23,12 @@ ShapeDialect::ShapeDialect(IrContext *context)
}

void ShapeDialect::initialize() {
-  RegisterOps<SymbolicDim, DimOp, TieProductEqualOp, TieShapeOp, FuncOp>();
+  RegisterOps<SymbolicDim,
+              DimOp,
+              TieProductEqualOp,
+              TieShapeOp,
+              FuncOp,
+              TensorDimOp>();
}

} // namespace dialect
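Note: with TensorDimOp added to RegisterOps, the op becomes buildable once the dialect itself is registered. A minimal sketch of that registration, assuming the usual pir::IrContext entry points and a shape_dialect.h header alongside the .cc above (neither is part of this diff):

#include "paddle/pir/core/ir_context.h"
#include "paddle/pir/dialect/shape/ir/shape_dialect.h"

// Sketch only: make the shape dialect (and thus shape.tensor_dim)
// available in the global IR context.
void RegisterShapeDialect() {
  pir::IrContext *ctx = pir::IrContext::Instance();
  ctx->GetOrRegisterDialect<pir::dialect::ShapeDialect>();
}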
56 changes: 51 additions & 5 deletions paddle/pir/dialect/shape/ir/shape_op.cc
@@ -14,6 +14,7 @@

#include "paddle/pir/dialect/shape/ir/shape_op.h"
#include "paddle/pir/core/builtin_attribute.h"
#include "paddle/pir/core/builtin_op.h"
#include "paddle/pir/core/builtin_type.h"

namespace pir {
@@ -104,15 +105,15 @@ void SymbolicDim::updateKnownNonSizeZero(bool attrValue) {
pir::BoolAttribute::get(pir::IrContext::Instance(), attrValue));
}

-bool SymbolicDim::isDynamic() {
+bool SymbolicDim::IsDynamic() {
  return getValue() == ShapedTypeInterface::kDynamic;
}

-bool SymbolicDim::merge(SymbolicDim other) {
-  if (!isDynamic() && !other.isDynamic() && getValue() != other.getValue())
+bool SymbolicDim::Merge(SymbolicDim other) {
+  if (!IsDynamic() && !other.IsDynamic() && getValue() != other.getValue())
    return false;
-  if (isDynamic() && !other.isDynamic()) updateValue(other.getValue());
-  if (!isDynamic() && other.isDynamic()) other.updateValue(getValue());
+  if (IsDynamic() && !other.IsDynamic()) updateValue(other.getValue());
+  if (!IsDynamic() && other.IsDynamic()) other.updateValue(getValue());

bool knownNonNegativeFlag =
getKnownNonNegative() || other.getKnownNonNegative();
@@ -213,9 +214,26 @@ void TieShapeOp::Build(Builder &builder,
const pir::OpResult &input) {
argument.inputs = {input};
}
void TieShapeOp::Build(Builder &builder, // NOLINT
OperationArgument &argument, // NOLINT
const pir::OpResult &input,
const std::vector<pir::OpResult> &dims) {
argument.inputs = {input};
for (auto &dim : dims) {
argument.inputs.push_back(dim);
}
}

pir::Value TieShapeOp::getValue() { return operand_source(0); }

std::vector<pir::Value> TieShapeOp::getShapeDimIndexes() {
std::vector<pir::Value> res;
for (uint32_t i = 1; i < num_operands(); i++) {
res.push_back(operand_source(i));
}
return res;
}

void FuncOp::Build(Builder &builder, OperationArgument &argument) {
argument.num_regions = 1;
}
@@ -226,6 +244,33 @@ pir::Block *FuncOp::block() {
return region.front();
}

void TensorDimOp::Build(Builder &builder,
OperationArgument &argument,
const pir::OpResult &source,
const pir::OpResult &index) {
argument.inputs = {source, index};
argument.output_types.emplace_back(
pir::IndexType::get(pir::IrContext::Instance()));
}

void TensorDimOp::Build(Builder &builder,
OperationArgument &argument,
const pir::OpResult &source,
int64_t index) {
  pir::OpResult indexValue =
      builder
          .Build<pir::ConstantOp>(
              pir::Int64Attribute::get(pir::IrContext::Instance(), index),
              pir::IndexType::get(pir::IrContext::Instance()))
          ->result(0);
argument.inputs = {source, indexValue};
argument.output_types.emplace_back(
pir::IndexType::get(pir::IrContext::Instance()));
}

pir::Value TensorDimOp::getSource() { return operand_source(0); }

pir::Value TensorDimOp::getIndex() { return operand_source(1); }
} // namespace dialect
} // namespace pir

@@ -234,3 +279,4 @@ IR_DEFINE_EXPLICIT_TYPE_ID(pir::dialect::DimOp)
IR_DEFINE_EXPLICIT_TYPE_ID(pir::dialect::TieProductEqualOp)
IR_DEFINE_EXPLICIT_TYPE_ID(pir::dialect::TieShapeOp)
IR_DEFINE_EXPLICIT_TYPE_ID(pir::dialect::FuncOp)
IR_DEFINE_EXPLICIT_TYPE_ID(pir::dialect::TensorDimOp)
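Note: the new Build overloads compose as follows: the int64_t overload of TensorDimOp::Build materializes the index as a ConstantOp of IndexType, and TieShapeOp's new overload accepts the resulting dim values after the source value. A minimal usage sketch, assuming a pir::Builder `builder`, a source OpResult `tensor`, and its rank `rank` (illustrative names, not from this diff); this mirrors what InsertTieShapeOnValue in the new pass does:

// Tie each dimension of `tensor` to a shape.tensor_dim result.
std::vector<pir::OpResult> dims;
for (int64_t i = 0; i < rank; ++i) {
  auto dim_op = builder.Build<pir::dialect::TensorDimOp>(tensor, i);
  dims.push_back(dim_op.out());
}
builder.Build<pir::dialect::TieShapeOp>(tensor, dims);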
37 changes: 34 additions & 3 deletions paddle/pir/dialect/shape/ir/shape_op.h
@@ -51,11 +51,11 @@ class IR_API SymbolicDim : public Op<SymbolicDim> {
void updateKnownNonSizeOne(bool attrValue);
void updateKnownNonSizeZero(bool attrValue);

-  bool isDynamic();
-  bool merge(SymbolicDim other);
+  bool IsDynamic();
+  bool Merge(SymbolicDim other);

  static const std::string getSymbolicDimAttrName() {
-    return "SymbolicDimAttr";
+    return "kSymbolicDimAttr";
}

void Verify() {}
@@ -112,7 +112,14 @@ class IR_API TieShapeOp : public Op<TieShapeOp> {
static void Build(Builder &builder, // NOLINT
OperationArgument &argument, // NOLINT
const pir::OpResult &input);

static void Build(Builder &builder, // NOLINT
OperationArgument &argument, // NOLINT
const pir::OpResult &input,
const std::vector<pir::OpResult> &dims);

pir::Value getValue();
std::vector<pir::Value> getShapeDimIndexes();
void Verify() {}
};

@@ -129,6 +136,29 @@ class IR_API FuncOp : public Op<FuncOp> {
pir::Block *block();
void Verify() {}
};

class IR_API TensorDimOp : public Op<TensorDimOp> {
public:
using Op::Op;
static const char *name() { return "shape.tensor_dim"; }

static constexpr const char **attributes_name = nullptr;
static constexpr uint32_t attributes_num = 0;

static void Build(Builder &builder, // NOLINT
OperationArgument &argument, // NOLINT
const pir::OpResult &source,
const pir::OpResult &index);
static void Build(Builder &builder, // NOLINT
OperationArgument &argument, // NOLINT
const pir::OpResult &source,
int64_t index);
pir::Value getIndex();
pir::Value getSource();
pir::OpResult out() { return result(0); }
void Verify() {}
};

} // namespace dialect
} // namespace pir

@@ -137,3 +167,4 @@ IR_EXPORT_DECLARE_EXPLICIT_TYPE_ID(pir::dialect::DimOp);
IR_EXPORT_DECLARE_EXPLICIT_TYPE_ID(pir::dialect::TieProductEqualOp);
IR_EXPORT_DECLARE_EXPLICIT_TYPE_ID(pir::dialect::TieShapeOp);
IR_EXPORT_DECLARE_EXPLICIT_TYPE_ID(pir::dialect::FuncOp);
IR_EXPORT_DECLARE_EXPLICIT_TYPE_ID(pir::dialect::TensorDimOp);
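Note: the accessors declared above make the tie explicit on the consumer side: operand 0 is the tied value, operands 1..N-1 are the per-dimension index values. A short sketch of walking them back (hypothetical helper, not in this diff):

#include "paddle/pir/dialect/shape/ir/shape_op.h"

// Sketch: recover the tied value and its dim operands from a
// shape.tie_shape op.
void InspectTieShape(pir::dialect::TieShapeOp tie_shape) {
  pir::Value tied = tie_shape.getValue();  // operand 0
  for (pir::Value dim : tie_shape.getShapeDimIndexes()) {
    // Each `dim` is typically the result of a shape.tensor_dim op.
    (void)dim;
  }
  (void)tied;
}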
109 changes: 109 additions & 0 deletions paddle/pir/dialect/shape/transforms/shape_optimization_pass.cc
@@ -0,0 +1,109 @@
// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include "paddle/pir/dialect/shape/transforms/shape_optimization_pass.h"
#include "paddle/fluid/pir/dialect/operator/ir/op_type.h"
#include "paddle/pir/dialect/shape/ir/shape_op.h"

#include "paddle/pir/core/builtin_op.h"
#include "paddle/pir/core/program.h"
#include "paddle/pir/pass/pass.h"
#include "paddle/pir/pass/pass_registry.h"

namespace {

bool InsertTieShapeOnValue(pir::OpResult value,
pir::Builder& builder) { // NOLINT
auto ty = value.type().dyn_cast<paddle::dialect::DenseTensorType>();

if (!ty || ty.dims().size() == 0) return true;
std::vector<pir::OpResult> dimSizes;
for (int64_t dim = 0, rank = ty.dims().size(); dim < rank; ++dim) {
auto dimOp = builder.Build<pir::dialect::TensorDimOp>(value, dim);
dimSizes.push_back(dimOp.out());
}
builder.Build<pir::dialect::TieShapeOp>(value, dimSizes);
return true;
}

bool InsertTieShapeOnRegion(pir::Region* region);

bool InsertTieShapeOnOperation(pir::Operation* op,
pir::Builder& builder) { // NOLINT
if (op->isa<pir::dialect::TieShapeOp>()) return true;
// TODO(liujinnan): skip the specialized Ops.

for (size_t i = 0; i < op->num_regions(); ++i) {
if (!InsertTieShapeOnRegion(&(op->region(i)))) return false;
}
builder.SetInsertionPointAfter(op);
for (pir::OpResult v : op->results()) {
if (!InsertTieShapeOnValue(v, builder)) return false;
}

return true;
}

bool InsertTieShapeOnBlock(pir::Block* block) {
pir::Builder builder =
pir::Builder(pir::IrContext::Instance(), block, block->begin());
// TODO(liujinnan): mapping block arguments

std::vector<pir::Operation*> op_list;
for (pir::Operation* op : *block) op_list.push_back(op);
for (pir::Operation* op : op_list) {
if (!InsertTieShapeOnOperation(op, builder)) return false;
}
return true;
}

bool InsertTieShapeOnRegion(pir::Region* region) {
for (pir::Block* block : *region) {
if (!InsertTieShapeOnBlock(block)) return false;
}
return true;
}

bool MaterializeShapeComputation(pir::ModuleOp m) {
if (!InsertTieShapeOnRegion(&(m->region(0)))) return false;
// TODO(liujinnan): add rewriter pattern for reifyInferShape.
return true;
}

class ShapeOptimizationPass : public pir::Pass {
public:
ShapeOptimizationPass() : pir::Pass("shape_optimization", 0) {}

void Run(pir::Operation* op) override {
auto module_op = op->dyn_cast<pir::ModuleOp>();
IR_ENFORCE(module_op, "ShapeOptimizationPass should run on module op.");
MaterializeShapeComputation(module_op);
}

bool CanApplyOn(pir::Operation* op) const override {
return op->name() == "builtin.module" && op->num_regions() > 0;
}
};

} // namespace

namespace pir {

std::unique_ptr<Pass> CreateShapeOptimizationPass() {
return std::make_unique<ShapeOptimizationPass>();
}

} // namespace pir

REGISTER_IR_PASS(shape_optimization, ShapeOptimizationPass);
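Note: CreateShapeOptimizationPass is the public entry point, and REGISTER_IR_PASS additionally makes the pass retrievable by its "shape_optimization" name. A hedged sketch of driving it over a program, assuming pir::PassManager's usual AddPass/Run interface (the pass manager usage is not part of this diff):

#include "paddle/pir/core/ir_context.h"
#include "paddle/pir/core/program.h"
#include "paddle/pir/pass/pass_manager.h"
#include "paddle/pir/dialect/shape/transforms/shape_optimization_pass.h"

// Sketch: after this runs, every DenseTensorType result in the
// module is followed by shape.tensor_dim ops and a shape.tie_shape.
bool RunShapeOptimization(pir::Program *program) {
  pir::PassManager pm(pir::IrContext::Instance());
  pm.AddPass(pir::CreateShapeOptimizationPass());
  return pm.Run(program);
}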
26 changes: 26 additions & 0 deletions paddle/pir/dialect/shape/transforms/shape_optimization_pass.h
@@ -0,0 +1,26 @@
// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#pragma once

#include <memory>
#include "paddle/pir/core/dll_decl.h"

namespace pir {

class Pass;

IR_API std::unique_ptr<Pass> CreateShapeOptimizationPass();

} // namespace pir