Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

【CINN】Integate cast_simplify into ir_simplify #56958

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 0 additions & 1 deletion paddle/cinn/common/cas.cc
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,6 @@
#include "paddle/cinn/ir/utils/ir_nodes_collector.h"
#include "paddle/cinn/ir/utils/ir_printer.h"
#include "paddle/cinn/ir/utils/ir_visitor.h"
#include "paddle/cinn/optim/cast_simplify.h"
#include "paddle/cinn/utils/string.h"

namespace cinn {
Expand Down
3 changes: 1 addition & 2 deletions paddle/cinn/common/ir_util.cc
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,6 @@
#include "paddle/cinn/ir/op/ir_operators.h"
#include "paddle/cinn/ir/utils/ir_mutator.h"
#include "paddle/cinn/ir/utils/ir_printer.h"
#include "paddle/cinn/optim/cast_simplify.h"

namespace cinn {
namespace common {
Expand Down Expand Up @@ -147,7 +146,7 @@ Expr IndiceToAbsOffset(const std::vector<Expr> &shape,
for (int i = 0; i < shape.size(); i++) {
CHECK_EQ(shape[i].type(), Int(32));
Expr indice_prod = indices[i];
optim::CastSimplify(&indice_prod);
optim::SimplifyCast(&indice_prod);
for (int j = i + 1; j < shape.size(); j++) {
indice_prod = RampRelatedMul(indice_prod, shape[j]);
}
Expand Down
1 change: 0 additions & 1 deletion paddle/cinn/optim/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,6 @@ gather_srcs(
compute_inline_expand.cc
buffer_assign.cc
replace_const_param_to_integer.cc
cast_simplify.cc
lower_intrin.cc
cast_bool_to_int8.cc
collect_undefined_vars.cc
Expand Down
117 changes: 0 additions & 117 deletions paddle/cinn/optim/cast_simplify.cc

This file was deleted.

31 changes: 0 additions & 31 deletions paddle/cinn/optim/cast_simplify.h

This file was deleted.

12 changes: 5 additions & 7 deletions paddle/cinn/optim/cast_simplify_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -12,29 +12,27 @@
// See the License for the specific language governing permissions and
// limitations under the License.

#include "paddle/cinn/optim/cast_simplify.h"

#include <gtest/gtest.h>

#include "paddle/cinn/ir/op/ir_operators.h"
#include "paddle/cinn/ir/utils/ir_printer.h"

#include "paddle/cinn/optim/ir_simplify.h"
namespace cinn::optim {

TEST(CastSimplify, same_type) {
Var n("n");
Expr a = ir::Cast::Make(Int(32), n);
LOG(INFO) << n->type();
LOG(INFO) << a;
CastSimplify(&a);
SimplifyCast(&a);
ASSERT_EQ(utils::GetStreamCnt(a), "n");
}

TEST(CastSimplify, Imm_int) {
Expr a = ir::Cast::Make(Int(64), Expr(1));
Expr c = ir::Cast::Make(Int(32), a);
LOG(INFO) << c;
CastSimplify(&c);
SimplifyCast(&c);
LOG(INFO) << c;
ASSERT_EQ(utils::GetStreamCnt(c), "1");
ASSERT_EQ(c.type(), Int(32));
Expand All @@ -44,7 +42,7 @@ TEST(CastSimplify, Imm_double) {
Expr a = ir::Cast::Make(Float(64), Expr(2.33));
Expr c = ir::Cast::Make(Int(32), a);
LOG(INFO) << c;
CastSimplify(&c);
SimplifyCast(&c);
LOG(INFO) << c;
ASSERT_EQ(utils::GetStreamCnt(c), "2");
ASSERT_EQ(c.type(), Int(32));
Expand All @@ -54,7 +52,7 @@ TEST(CastSimplify, Imm_uint) {
Expr a = ir::Cast::Make(UInt(64), Expr(1));
Expr c = ir::Cast::Make(UInt(32), a);
LOG(INFO) << c;
CastSimplify(&c);
SimplifyCast(&c);
LOG(INFO) << c;
ASSERT_EQ(utils::GetStreamCnt(c), "1");
ASSERT_EQ(c.type(), UInt(32));
Expand Down
90 changes: 88 additions & 2 deletions paddle/cinn/optim/ir_simplify.cc
Original file line number Diff line number Diff line change
Expand Up @@ -29,13 +29,14 @@
#include "paddle/cinn/ir/utils/ir_mutator.h"
#include "paddle/cinn/ir/utils/ir_printer.h"
#include "paddle/cinn/ir/utils/ir_visitor.h"
#include "paddle/cinn/optim/cast_simplify.h"
#include "paddle/cinn/utils/string.h"

namespace cinn {
namespace optim {
using namespace ir; // NOLINT
using common::bfloat16;
using common::ExprToGinacConverter;
using common::float16;
using utils::GetStreamCnt;
using utils::Replace;

Expand Down Expand Up @@ -359,11 +360,95 @@ struct SimplifyForLoopsMutator : public ir::IRMutator<> {
}
};

template <typename CastType, typename T>
CastType NormCastValue(T value) {
if (type_of<CastType>().is_uint() || type_of<T>().is_uint()) {
// not support uint
return static_cast<CastType>(value);
}

if (std::isinf(value)) {
return std::numeric_limits<CastType>::infinity();
} else if (std::isnan(value)) {
return std::numeric_limits<CastType>::signaling_NaN();
} else if (value >= static_cast<T>(std::numeric_limits<CastType>::max())) {
return std::numeric_limits<CastType>::max();
} else if (value <= static_cast<T>(std::numeric_limits<CastType>::lowest())) {
return std::numeric_limits<CastType>::lowest();
}
return static_cast<CastType>(value);
}

struct SimplifyCastMutator : public ir::IRMutator<> {
void operator()(Expr* expr) { ir::IRMutator<ir::Expr*>::Visit(expr, expr); }

void Visit(const ir::Cast* op, Expr* expr) {
auto* node = expr->As<ir::Cast>();

ir::IRMutator<ir::Expr*>::Visit(&node->v(), &node->v());

if (op->type() == op->v().type()) {
*expr = op->v();
return;
}

#define __CAST_TO_TYPE(type__) \
if (auto* i = op->v().As<ir::IntImm>()) { \
*expr = Expr(static_cast<type__>(i->value)); \
} else if (auto* f = op->v().As<ir::FloatImm>()) { \
*expr = Expr(static_cast<type__>(NormCastValue<type__>(f->value))); \
} else if (auto* u = op->v().As<ir::UIntImm>()) { \
*expr = Expr(static_cast<type__>(u->value)); \
} else { \
CINN_NOT_IMPLEMENTED \
}

if (op->v().is_constant()) {
if (op->type() == type_of<int8_t>()) {
__CAST_TO_TYPE(int8_t)
} else if (op->type() == type_of<int16_t>()) {
__CAST_TO_TYPE(int16_t)
} else if (op->type() == type_of<int32_t>()) {
__CAST_TO_TYPE(int32_t)
} else if (op->type() == type_of<int64_t>()) {
__CAST_TO_TYPE(int64_t)
} else if (op->type() == type_of<uint8_t>()) {
__CAST_TO_TYPE(uint8_t)
} else if (op->type() == type_of<uint16_t>()) {
__CAST_TO_TYPE(uint16_t)
} else if (op->type() == type_of<uint32_t>()) {
__CAST_TO_TYPE(uint32_t)
} else if (op->type() == type_of<uint64_t>()) {
__CAST_TO_TYPE(uint64_t)
} else if (op->type() == type_of<float>()) {
__CAST_TO_TYPE(float)
} else if (op->type() == type_of<double>()) {
__CAST_TO_TYPE(double)
} else if (op->type() == type_of<bool>()) {
__CAST_TO_TYPE(bool)
} else if (op->type() == type_of<uint32_t>()) {
__CAST_TO_TYPE(uint32_t)
} else if (op->type() == type_of<uint64_t>()) {
__CAST_TO_TYPE(uint64_t)
} else if (op->type() == type_of<bfloat16>()) {
// Cannot simplify!!! pass
__CAST_TO_TYPE(bfloat16)
} else if (op->type() == type_of<float16>()) {
// Cannot simplify!!! pass
__CAST_TO_TYPE(float16)
} else {
CINN_NOT_IMPLEMENTED
}
}
#undef __CAST_TO_TYPE
}
};

} // namespace

void Simplify(Expr* expr) {
VLOG(3) << "Begin Simplify " << *expr;
optim::CastSimplify(expr);
SimplifyCastMutator()(expr);
SimplifyRampMutator()(expr);
SimplifyLoadMutator()(expr);
SimplifyStoreMutator()(expr);
Expand All @@ -376,6 +461,7 @@ void Simplify(Expr* expr) {
ReplaceFracWithDivMutator()(expr);
}

void SimplifyCast(Expr* expr) { SimplifyCastMutator()(expr); }
void SimplifyForLoops(Expr* expr) { SimplifyForLoopsMutator()(expr); }
void SimplifyBlocks(Expr* expr) { SimplifyBlocksMutator()(expr); }

Expand Down
2 changes: 2 additions & 0 deletions paddle/cinn/optim/ir_simplify.h
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,8 @@ namespace optim {
*/
void Simplify(Expr *expr);

void SimplifyCast(Expr *expr);

void SimplifyForLoops(Expr *expr);

void SimplifyBlocks(Expr *expr);
Expand Down
1 change: 0 additions & 1 deletion paddle/cinn/optim/optimize.cc
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,6 @@
#include "paddle/cinn/ir/utils/ir_printer.h"
#include "paddle/cinn/optim/call_arg_list_to_pod_value.h"
#include "paddle/cinn/optim/cast_bool_to_int8.h"
#include "paddle/cinn/optim/cast_simplify.h"
#include "paddle/cinn/optim/eliminate_broadcast_in_forloop.h"
#include "paddle/cinn/optim/extern_call_process.h"
#include "paddle/cinn/optim/fold_cinn_call_arguments.h"
Expand Down