Skip to content

Commit

Permalink
[Prim][PIR] Sink Forward Prim (PaddlePaddle#58130)
Browse files Browse the repository at this point in the history
* decomp sink

* polish code

* fix flag

* fix code

* fix code

* fix code2

* fix code2

* fix code3
  • Loading branch information
cyber-pioneer authored and jiahy0825 committed Oct 26, 2023
1 parent d1fc667 commit 62150b2
Show file tree
Hide file tree
Showing 13 changed files with 368 additions and 7 deletions.
Original file line number Diff line number Diff line change
@@ -0,0 +1,22 @@
# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

# =====================================
# DecompInterface gen op list
# =====================================


decomp_interface_declare_gen_op_list = ['mean']

decomp_interface_implementation_gen_op_list = ["mean"]
4 changes: 4 additions & 0 deletions paddle/fluid/pir/dialect/op_generator/op_gen.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
import sys

import yaml
from decomp_interface_gen_op_list import decomp_interface_declare_gen_op_list
from op_build_gen import gen_build_func_str, gen_build_func_str_by_invoke
from op_interface_gen import (
gen_exclusive_interface_str,
Expand Down Expand Up @@ -58,6 +59,7 @@
#include "paddle/fluid/pir/dialect/operator/interface/op_yaml_info.h"
#include "paddle/fluid/pir/dialect/operator/interface/infermeta.h"
#include "paddle/fluid/pir/dialect/operator/interface/vjp.h"
#include "paddle/fluid/pir/dialect/operator/interface/decomp.h"
#include "paddle/fluid/pir/dialect/operator/trait/inplace.h"
#include "paddle/fluid/pir/dialect/operator/trait/custom_vjp.h"
#include "paddle/fluid/framework/infershape_utils.h"
Expand Down Expand Up @@ -1036,6 +1038,8 @@ def OpGenerator(
and op_info.op_phi_name[0] not in vjp_interface_black_list
):
op_interfaces += ["paddle::dialect::VjpInterface"]
if op_info.op_phi_name[0] in decomp_interface_declare_gen_op_list:
op_interfaces += ["paddle::dialect::DecompInterface"]
exclusive_interface_str = gen_exclusive_interface_str(
op_info, op_info_items
)
Expand Down
4 changes: 4 additions & 0 deletions paddle/fluid/pir/dialect/op_generator/op_interface_gen.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,8 @@
# See the License for the specific language governing permissions and
# limitations under the License.

from decomp_interface_gen_op_list import decomp_interface_declare_gen_op_list

# generator interfaces
from vjp_interface_black_list import vjp_interface_black_list

Expand Down Expand Up @@ -316,4 +318,6 @@ def gen_exclusive_interface_str(op_info, op_info_items):
)
if op_info.op_phi_name[0] not in vjp_interface_black_list:
exclusive_interface_str += "\n static std::vector<std::vector<pir::OpResult>> Vjp(pir::Operation* op, const std::vector<std::vector<pir::Value>>& out_grads, const std::vector<std::vector<bool>>& stop_gradients);"
if op_info.op_phi_name[0] in decomp_interface_declare_gen_op_list:
exclusive_interface_str += "\n static std::vector<std::vector<pir::OpResult>> Decomp(pir::Operation* op);"
return exclusive_interface_str
52 changes: 52 additions & 0 deletions paddle/fluid/pir/dialect/operator/interface/decomp.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,52 @@
// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once

#include "paddle/pir/core/op_base.h"

namespace paddle {
namespace dialect {
class DecompInterface : public pir::OpInterfaceBase<DecompInterface> {
public:
struct Concept {
explicit Concept(
std::vector<std::vector<pir::OpResult>> (*decomp)(pir::Operation* op))
: decomp_(decomp) {}
std::vector<std::vector<pir::OpResult>> (*decomp_)(pir::Operation* op);
};

template <class ConcreteOp>
struct Model : public Concept {
static std::vector<std::vector<pir::OpResult>> Decomp(pir::Operation* op) {
return ConcreteOp::Decomp(op);
}
Model() : Concept(Decomp) {}
};

/// Constructor
DecompInterface(pir::Operation* op, Concept* impl)
: pir::OpInterfaceBase<DecompInterface>(op), impl_(impl) {}

std::vector<std::vector<pir::OpResult>> Decomp(pir::Operation* op) {
return impl_->decomp_(op);
}

private:
Concept* impl_;
};

} // namespace dialect
} // namespace paddle

IR_DECLARE_EXPLICIT_TYPE_ID(paddle::dialect::DecompInterface)
2 changes: 2 additions & 0 deletions paddle/fluid/pir/dialect/operator/interface/interface.cc
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
// See the License for the specific language governing permissions and
// limitations under the License.

#include "paddle/fluid/pir/dialect/operator/interface/decomp.h"
#include "paddle/fluid/pir/dialect/operator/interface/infermeta.h"
#include "paddle/fluid/pir/dialect/operator/interface/op_yaml_info.h"
#include "paddle/fluid/pir/dialect/operator/interface/vjp.h"
Expand All @@ -37,3 +38,4 @@ std::vector<std::vector<pir::OpResult>> VjpInterface::Vjp(
IR_DEFINE_EXPLICIT_TYPE_ID(paddle::dialect::InferMetaInterface)
IR_DEFINE_EXPLICIT_TYPE_ID(paddle::dialect::OpYamlInfoInterface)
IR_DEFINE_EXPLICIT_TYPE_ID(paddle::dialect::VjpInterface)
IR_DEFINE_EXPLICIT_TYPE_ID(paddle::dialect::DecompInterface)
2 changes: 1 addition & 1 deletion paddle/fluid/pir/dialect/operator/ir/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -209,6 +209,6 @@ target_include_directories(pd_op_dialect_api INTERFACE ${PD_DIALECT_BINARY_DIR})

cc_library(
pd_op_dialect
SRCS op_dialect.cc manual_op_vjp.cc ${op_vjp_source_file}
SRCS op_dialect.cc manual_op_decomp.cc manual_op_vjp.cc ${op_vjp_source_file}
DEPS pd_op_dialect_api param_to_variable primitive_vjp_experimental
pd_op_dialect_utils op_yaml_info_parser)
65 changes: 65 additions & 0 deletions paddle/fluid/pir/dialect/operator/ir/manual_op_decomp.cc
Original file line number Diff line number Diff line change
@@ -0,0 +1,65 @@
// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include "paddle/fluid/pir/dialect/operator/ir/op_attribute.h"
#include "paddle/fluid/pir/dialect/operator/ir/pd_op.h"
#include "paddle/fluid/pir/dialect/operator/utils/utils.h"
#include "paddle/fluid/primitive/composite/composite.h"
#include "paddle/fluid/primitive/type/lazy_tensor.h"
#include "paddle/phi/api/include/tensor.h"
#include "paddle/phi/common/int_array.h"
#include "paddle/pir/core/builtin_op.h"
#include "paddle/pir/core/op_base.h"

// TODO(chenzhuo)
// this file will be generated in pd_op_decomp.cc

namespace paddle {
namespace dialect {
using IntArray = paddle::experimental::IntArray;

std::vector<std::vector<pir::OpResult>> MeanOp::Decomp(pir::Operation* op) {
MeanOp op_obj = op->dyn_cast<MeanOp>();
(void)op_obj;

VLOG(4) << "Decomp Prepare inputs of mean";

Tensor x(std::make_shared<primitive::LazyTensor>(op_obj.x()));

VLOG(4) << "Decomp prepare attributes of mean";

IntArray axis = op->attribute("axis")
.dyn_cast<paddle::dialect::IntArrayAttribute>()
.data();

bool keepdim = op->attribute("keepdim").dyn_cast<pir::BoolAttribute>().data();
VLOG(4) << "Decomp mean keep_dim " << keepdim;

VLOG(4) << "Decomp prepare call mean's decomp interface";

Tensor op_res =
paddle::primitive::details::mean_decomp<primitive::LazyTensor>(
x, axis, keepdim);

auto org_res = op->results();
std::vector<std::vector<pir::OpResult>> res(org_res.size());
res[0].push_back(
std::static_pointer_cast<primitive::LazyTensor>(op_res.impl())
->value()
.dyn_cast<pir::OpResult>());
return res;
}

} // namespace dialect
} // namespace paddle
48 changes: 46 additions & 2 deletions paddle/fluid/primitive/composite/composite.h
Original file line number Diff line number Diff line change
Expand Up @@ -14,11 +14,55 @@

#pragma once

namespace paddle {
#include "paddle/fluid/primitive/primitive/primitive.h"
#include "paddle/fluid/primitive/type/lazy_tensor.h"
#include "paddle/fluid/primitive/utils/utils.h"

namespace paddle {
namespace primitive {
namespace details {

template <typename T>
Tensor mean_decomp(const Tensor& x, const IntArray& axis, bool keepdim) {
auto org_dtype = x.dtype();
auto x_tmp = x;
bool need_cast = org_dtype == phi::DataType::FLOAT16 ||
org_dtype == phi::DataType::BFLOAT16;
if (need_cast) {
x_tmp = cast<T>(x, phi::DataType::FLOAT32);
}
std::vector<int64_t> x_dim = phi::vectorize<int64_t>(x_tmp.dims());
int64_t axis_size = axis.size();
int64_t x_dim_size = x_dim.size();
auto axis_ = std::vector<int64_t>();
if (axis_size == 0) {
for (int64_t i = 0; i < x_dim_size; i++) {
axis_.push_back(i);
}
} else {
axis_ = axis.GetData();
for (int64_t i = 0; i < axis_size; i++) {
if (axis[i] < 0) {
axis_[i] = axis[i] + x_dim_size;
}
}
}

int64_t value = 1;
for (size_t i = 0; i < axis_.size(); i++) {
value *= x_dim[axis_[i]];
}
auto sum_x = sum<T>(x_tmp, IntArray(axis_), x_tmp.dtype(), keepdim);
auto res = divide<T>(
sum_x, full<T>(phi::vectorize(sum_x.dims()), value, sum_x.dtype()));
if (need_cast) {
return cast<T>(res, org_dtype);
} else {
return res;
}
}

namespace experimental {}
} // namespace details

} // namespace primitive
} // namespace paddle
38 changes: 38 additions & 0 deletions paddle/fluid/pybind/pybind.cc
Original file line number Diff line number Diff line change
Expand Up @@ -195,6 +195,7 @@ limitations under the License. */
#include "paddle/fluid/eager/api/utils/global_utils.h"
#include "paddle/fluid/eager/nan_inf_utils.h"
#include "paddle/fluid/imperative/layout_autotune.h"
#include "paddle/fluid/pir/dialect/operator/interface/decomp.h"
#include "paddle/fluid/pir/dialect/operator/interface/vjp.h"
#include "paddle/fluid/pir/dialect/operator/trait/custom_vjp.h"
#include "paddle/fluid/prim/utils/eager/eager_tensor_operants.h"
Expand Down Expand Up @@ -766,6 +767,42 @@ void BindVjp(pybind11::module *m) {
out (bool): True means that the op has custom vjp rules, False means it does not.
)DOC");
}

void BindDecomp(pybind11::module *m) {
m->def("call_decomp", [](pir::Operation &fwd_op) {
py::list res;
paddle::dialect::DecompInterface decomp_interface =
fwd_op.dyn_cast<paddle::dialect::DecompInterface>();
PADDLE_ENFORCE(
decomp_interface,
phi::errors::InvalidArgument(
"The decomp function is not registered in %s op ", fwd_op.name()));
std::vector<std::vector<pir::OpResult>> decomp_res =
decomp_interface.Decomp(&fwd_op);
for (size_t i = 0; i < decomp_res.size(); ++i) {
py::list sub_res;
for (size_t j = 0; j < decomp_res[i].size(); ++j) {
if (!decomp_res[i][j]) {
sub_res.append(nullptr);
} else {
sub_res.append(decomp_res[i][j]);
}
}
res.append(sub_res);
}
return res;
});

m->def("has_decomp", [](pir::Operation &fwd_op) {
pir::IrContext *ctx = pir::IrContext::Instance();
pir::OpInfo fwd_op_info = ctx->GetRegisteredOpInfo(fwd_op.name());
auto decomp_interface_impl =
fwd_op_info.GetInterfaceImpl<paddle::dialect::DecompInterface>();
if (decomp_interface_impl == nullptr) return false;
return true;
});
}

PYBIND11_MODULE(libpaddle, m) {
BindImperative(&m);
BindEager(&m);
Expand Down Expand Up @@ -2940,6 +2977,7 @@ All parameter, weight, gradient are variables in Paddle.

BindPIR(&m);
BindVjp(&m);
BindDecomp(&m);
}
} // namespace pybind
} // namespace paddle
22 changes: 20 additions & 2 deletions python/paddle/decomposition/decomp.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
import typing

from paddle import pir
from paddle.base.core import call_decomp, has_decomp
from paddle.base.libpaddle.pir import Block, Operation, Program
from paddle.framework import core

Expand All @@ -30,6 +31,18 @@ def _build_tensor_tuple(xs):
return TypeError(f"Type {type(xs)} is not supported.")


def _analyse_decomp_results(orig_outs, decomp_outs):
assert len(orig_outs) == len(decomp_outs)
res = []
for org_item, new_item in zip(orig_outs, decomp_outs):
if isinstance(org_item, pir.OpResult):
assert len(new_item) == 1 and isinstance(new_item[0], pir.OpResult)
res.append(new_item[0])
else:
res.append(new_item)
return res


def _prepare_python_api_arguments(op):
"""
For standard api of operator, its inputs should keep consistent with organization of its inputs and attrs.
Expand Down Expand Up @@ -215,7 +228,8 @@ def _decompose_subgraph(block, orig_vars, dst_vars, op_filter):
for idx, op in enumerate(ops_list):
op_name = op.name()
decom_rule = register.get_decomp_rule(op_name)
lower = decom_rule and op_filter(op)
has_sink_decomp_rule = has_decomp(op)
lower = (decom_rule or has_sink_decomp_rule) and op_filter(op)

if op.name() == "builtin.combine":
temp_op = op
Expand All @@ -231,7 +245,11 @@ def _decompose_subgraph(block, orig_vars, dst_vars, op_filter):
pir.set_insertion_point(op)
input_args = _prepare_python_api_arguments(op)
orig_outs = op.results()
new_outs = _build_tensor_tuple(decom_rule(*input_args))
if has_sink_decomp_rule:
decomp_outs = call_decomp(op)
new_outs = _analyse_decomp_results(orig_outs, decomp_outs)
else:
new_outs = _build_tensor_tuple(decom_rule(*input_args))

# Todo: To cover such case: some outputs are no longer needed after decomposition.
_check_op_results(
Expand Down
1 change: 0 additions & 1 deletion python/paddle/decomposition/rules.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,6 @@
from .register import register_decomp


@register_decomp('pd_op.mean')
def mean(x, axis, keepdim):
"""define composite rule of op mean"""
x_shape = x.shape
Expand Down
2 changes: 1 addition & 1 deletion test/prim/pir_prim/CMakeLists.txt
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
set(TEST_PRIM_PURE_NEW_IR_CASES
test_prim_program test_prim_simpnet test_prim_custom_vjp test_prim_jit
test_pir_prim_flags)
test_pir_prim_flags test_sink_decomp)

foreach(target ${TEST_PRIM_PURE_NEW_IR_CASES})
py_test_modules(${target} MODULES ${target} ENVS GLOG_v=1
Expand Down
Loading

0 comments on commit 62150b2

Please sign in to comment.