From 62ba74fad496d5133b2937199ac97afd21bc1ac9 Mon Sep 17 00:00:00 2001 From: Courtesy-Xs Date: Wed, 6 Sep 2023 07:08:43 +0000 Subject: [PATCH 1/7] rewrite tensor writer --- paddle/cinn/ir/utils/ir_nodes_collector.cc | 19 +++++++++++++++++++ paddle/cinn/ir/utils/ir_nodes_collector.h | 2 ++ paddle/cinn/optim/vectorize_loops.cc | 4 ++-- 3 files changed, 23 insertions(+), 2 deletions(-) diff --git a/paddle/cinn/ir/utils/ir_nodes_collector.cc b/paddle/cinn/ir/utils/ir_nodes_collector.cc index e99da88a1dd35e..c4ae3d0c8e0e0e 100644 --- a/paddle/cinn/ir/utils/ir_nodes_collector.cc +++ b/paddle/cinn/ir/utils/ir_nodes_collector.cc @@ -207,5 +207,24 @@ std::set CollectReferencedTensors( return ts0; } +std::set CollectTensorNeedsWrite(const Expr* e) { + std::set tensor_written; + IrNodesCollector::handler_t handler = [&](const Expr* x) { + tensor_written.insert(x->As()->name); + }; + IrNodesCollector::teller_t teller = [](const Expr* x) { + if (x->As() && x->As()->tensor.As()) { + return true; + } + if (x->As() && x->As()->is_call_node()) { + return true; + } + return false; + }; + IrNodesCollector collector(std::move(teller), std::move(handler), true); + collector.Visit(e); + return tensor_written; +} + } // namespace ir } // namespace cinn diff --git a/paddle/cinn/ir/utils/ir_nodes_collector.h b/paddle/cinn/ir/utils/ir_nodes_collector.h index 75ed3fa9e64f4d..085c1c099550ed 100755 --- a/paddle/cinn/ir/utils/ir_nodes_collector.h +++ b/paddle/cinn/ir/utils/ir_nodes_collector.h @@ -65,5 +65,7 @@ std::map CollectTensorMap( return true; }); +std::set CollectTensorNeedsWrite(const Expr* e); + } // namespace ir } // namespace cinn diff --git a/paddle/cinn/optim/vectorize_loops.cc b/paddle/cinn/optim/vectorize_loops.cc index 745bec47b45073..3b7202df83dc46 100644 --- a/paddle/cinn/optim/vectorize_loops.cc +++ b/paddle/cinn/optim/vectorize_loops.cc @@ -185,7 +185,7 @@ class CudaVectorizer : public IRMutator { const Var iter_var_; // the loop var of the vecotrized loop const int factor_; // the factor for vectorize - TensorWriteTeller write_teller_; + std::set write_teller_; TensorVectorizeTeller vectorized_teller_; absl::flat_hash_map tensor2vectorized_vars_; @@ -215,7 +215,7 @@ class CudaVectorizer : public IRMutator { } void Visit(Expr *expr) { - write_teller_.Collect(expr); + write_teller_ = ir::CollectTensorNeedsWrite(expr); vectorized_teller_.Collect(expr); IRMutator::Visit(expr, expr); } From f2279b5021e942ce7019947d056c34d2a4ce1964 Mon Sep 17 00:00:00 2001 From: Courtesy-Xs Date: Wed, 6 Sep 2023 07:19:14 +0000 Subject: [PATCH 2/7] add modification --- paddle/cinn/ir/lowered_func.cc | 10 ++++------ paddle/cinn/lang/lower_impl.cc | 12 +++++------- 2 files changed, 9 insertions(+), 13 deletions(-) diff --git a/paddle/cinn/ir/lowered_func.cc b/paddle/cinn/ir/lowered_func.cc index 84e8fb3e974e7a..77243cfa83aaa6 100644 --- a/paddle/cinn/ir/lowered_func.cc +++ b/paddle/cinn/ir/lowered_func.cc @@ -209,8 +209,7 @@ void _LoweredFunc_::AllocTempBuffer() {} void _LoweredFunc_::PrepareBufferCastExprs(bool with_expr_gen_tensor) { buffer_data_cast_exprs.clear(); // collect write. - optim::TensorWriteTeller write_teller; - write_teller.Collect(&body); + auto write_teller = ir::CollectTensorNeedsWrite(&body); auto tensors = CollectAllTensorReference(with_expr_gen_tensor); std::sort(tensors.begin(), @@ -224,7 +223,7 @@ void _LoweredFunc_::PrepareBufferCastExprs(bool with_expr_gen_tensor) { if (!tensor->buffer.defined()) continue; Type value_type = tensor->type().ElementOf(); - bool is_const = !write_teller.IsWrite(tensor->name); + bool is_const = !write_teller.count(tensor->name); value_type.set_cpp_handle(); value_type.set_cpp_const(is_const); Var variable = _Var_::Make(tensor->name, value_type); @@ -250,8 +249,7 @@ std::vector _LoweredFunc_::CudaAliasVarExprs() const { } // collect write. std::vector res; - optim::TensorWriteTeller write_teller; - write_teller.Collect(&body); + auto write_teller = ir::CollectTensorNeedsWrite(&body); auto tensors = CollectAllTensorReference(); std::sort(tensors.begin(), @@ -269,7 +267,7 @@ std::vector _LoweredFunc_::CudaAliasVarExprs() const { continue; } Type value_type = tensor->type().ElementOf(); - bool is_const = !write_teller.IsWrite(tensor->name); + bool is_const = !write_teller.count(tensor->name); value_type.set_cpp_handle(); value_type.set_cpp_const(is_const); Var variable = _Var_::Make(tensor->name, value_type); diff --git a/paddle/cinn/lang/lower_impl.cc b/paddle/cinn/lang/lower_impl.cc index f313d52938a93a..46750039b3a0e0 100644 --- a/paddle/cinn/lang/lower_impl.cc +++ b/paddle/cinn/lang/lower_impl.cc @@ -342,8 +342,7 @@ std::vector LowerImpl::GenerateFunctionArgumentList( CheckArgsUnique(); std::vector args; - optim::TensorWriteTeller teller; - teller.Collect(&fn_body); + auto teller = ir::CollectTensorNeedsWrite(&fn_body); std::set arg_names; @@ -358,7 +357,7 @@ std::vector LowerImpl::GenerateFunctionArgumentList( for (auto& tensor : tensor_args_) { auto* tensor_node = tensor.As(); - bool is_output = teller.IsWrite(tensor->name); + bool is_output = teller.count(tensor->name); VLOG(1) << "tensor argument " << tensor->name << " buffer " << tensor->buffer->name; @@ -396,8 +395,7 @@ std::vector LowerImpl::GenFuncArgForSplitKernel( std::vector in_args; std::vector out_args; - optim::TensorWriteTeller teller; - teller.Collect(&func_iterator); + auto teller = ir::CollectTensorNeedsWrite(&func_iterator); std::set arg_names; std::set all_tensor_names; @@ -448,7 +446,7 @@ std::vector LowerImpl::GenFuncArgForSplitKernel( VLOG(3) << "In tensor_args_, it has : " << tensor->name; if (temp_tensor_names.count(tensor->name) > 0) continue; if (all_tensor_names.count(tensor->name) == 0) continue; - bool is_output = teller.IsWrite(tensor->name); + bool is_output = teller.count(tensor->name); VLOG(3) << "tensor argument " << tensor->name << " buffer " << tensor->buffer->name; @@ -485,7 +483,7 @@ std::vector LowerImpl::GenFuncArgForSplitKernel( VLOG(3) << "Tensor " << tensor->name; if (tensor->buffer.defined() && !arg_names.count(tensor->buffer->name)) { bool is_output = - teller.IsWrite(tensor->name) && teller.IsWrite(tensor->name); + teller.count(tensor->name) && teller.count(tensor->name); if (is_output) out_args.emplace_back(tensor->buffer, ir::Argument::IO::kOutput); } From b1618f03d24c2e46a291c886b4678965153170c1 Mon Sep 17 00:00:00 2001 From: Courtesy-Xs Date: Wed, 6 Sep 2023 07:21:26 +0000 Subject: [PATCH 3/7] fix unique target find --- paddle/cinn/ir/utils/ir_nodes_collector.cc | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/paddle/cinn/ir/utils/ir_nodes_collector.cc b/paddle/cinn/ir/utils/ir_nodes_collector.cc index c4ae3d0c8e0e0e..afca8fa4934239 100644 --- a/paddle/cinn/ir/utils/ir_nodes_collector.cc +++ b/paddle/cinn/ir/utils/ir_nodes_collector.cc @@ -221,7 +221,7 @@ std::set CollectTensorNeedsWrite(const Expr* e) { } return false; }; - IrNodesCollector collector(std::move(teller), std::move(handler), true); + IrNodesCollector collector(std::move(teller), std::move(handler), false); collector.Visit(e); return tensor_written; } From 9432d0e81b9b346fb1cbbc4d8b248c4a3aaa6b44 Mon Sep 17 00:00:00 2001 From: Courtesy-Xs Date: Wed, 6 Sep 2023 09:16:04 +0000 Subject: [PATCH 4/7] fix bug --- paddle/cinn/optim/vectorize_loops.cc | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/paddle/cinn/optim/vectorize_loops.cc b/paddle/cinn/optim/vectorize_loops.cc index 3b7202df83dc46..25a543788e4b4d 100644 --- a/paddle/cinn/optim/vectorize_loops.cc +++ b/paddle/cinn/optim/vectorize_loops.cc @@ -289,7 +289,7 @@ class CudaVectorizer : public IRMutator { const std::vector &indices, bool is_store) { auto *node = tensor.As(); - bool is_const = !write_teller_.IsWrite(node->name); + bool is_const = !write_teller_.count(node->name); // generate the corresponding vector type Type scalar_type = tensor->type().ElementOf(); From 09f4509ac0be6f3ca1278a505683e07258362a7f Mon Sep 17 00:00:00 2001 From: Courtesy-Xs Date: Fri, 8 Sep 2023 03:03:22 +0000 Subject: [PATCH 5/7] delete tensor_write file --- paddle/cinn/ir/utils/ir_nodes_collector.cc | 8 ++- paddle/cinn/optim/tensor_write_tell.cc | 19 ------- paddle/cinn/optim/tensor_write_tell.h | 58 ---------------------- 3 files changed, 7 insertions(+), 78 deletions(-) delete mode 100644 paddle/cinn/optim/tensor_write_tell.cc delete mode 100644 paddle/cinn/optim/tensor_write_tell.h diff --git a/paddle/cinn/ir/utils/ir_nodes_collector.cc b/paddle/cinn/ir/utils/ir_nodes_collector.cc index afca8fa4934239..feaae1a7f4bc54 100644 --- a/paddle/cinn/ir/utils/ir_nodes_collector.cc +++ b/paddle/cinn/ir/utils/ir_nodes_collector.cc @@ -210,7 +210,13 @@ std::set CollectReferencedTensors( std::set CollectTensorNeedsWrite(const Expr* e) { std::set tensor_written; IrNodesCollector::handler_t handler = [&](const Expr* x) { - tensor_written.insert(x->As()->name); + if (x->As()) { + tensor_written.insert( + x->As()->tensor.As()->name); + } + if (x->As()) { + tensor_written.insert(x->As()->name); + } }; IrNodesCollector::teller_t teller = [](const Expr* x) { if (x->As() && x->As()->tensor.As()) { diff --git a/paddle/cinn/optim/tensor_write_tell.cc b/paddle/cinn/optim/tensor_write_tell.cc deleted file mode 100644 index 9f0f5747c3f3da..00000000000000 --- a/paddle/cinn/optim/tensor_write_tell.cc +++ /dev/null @@ -1,19 +0,0 @@ -// Copyright (c) 2021 CINN Authors. All Rights Reserved. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#include "paddle/cinn/optim/tensor_write_tell.h" - -namespace cinn { -namespace optim {} // namespace optim -} // namespace cinn diff --git a/paddle/cinn/optim/tensor_write_tell.h b/paddle/cinn/optim/tensor_write_tell.h deleted file mode 100644 index f8ee114561a302..00000000000000 --- a/paddle/cinn/optim/tensor_write_tell.h +++ /dev/null @@ -1,58 +0,0 @@ -// Copyright (c) 2021 CINN Authors. All Rights Reserved. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#pragma once -#include -#include - -#include "paddle/cinn/ir/ir.h" -#include "paddle/cinn/ir/utils/ir_mutator.h" - -namespace cinn { -namespace optim { - -struct TensorWriteTeller : public ir::IRMutator { - //! Collect the write info in \p op. - void Collect(const Expr* op) { Visit(op, op); } - - bool IsWrite(const std::string& tensor_name) const { - return tensor_written.count(tensor_name); - } - - private: - std::set tensor_written; - - void Visit(const Expr* expr, const Expr* op) override { - IRMutator::Visit(expr, op); - } - - void Visit(const ir::Store* expr, const Expr* op) override { - auto* node = op->As(); - CHECK(node); - auto* tensor = node->tensor.As(); - CHECK(tensor); - tensor_written.insert(tensor->name); - IRMutator::Visit(expr, op); - } - - void Visit(const ir::_Tensor_* op, const Expr* expr) override { - auto* node = expr->As(); - if (node->is_call_node()) { - tensor_written.insert(node->name); - } - } -}; - -} // namespace optim -} // namespace cinn From 49dc30253c0ffa04c0244805f95b591a8a8baa54 Mon Sep 17 00:00:00 2001 From: Courtesy-Xs Date: Fri, 8 Sep 2023 03:51:19 +0000 Subject: [PATCH 6/7] fix bugs about delete tensor_write --- paddle/cinn/ir/lowered_func.cc | 1 - paddle/cinn/lang/lower_impl.h | 1 - paddle/cinn/optim/CMakeLists.txt | 1 - paddle/cinn/optim/vectorize_loops.cc | 1 - 4 files changed, 4 deletions(-) diff --git a/paddle/cinn/ir/lowered_func.cc b/paddle/cinn/ir/lowered_func.cc index 77243cfa83aaa6..5a897e7c334a59 100644 --- a/paddle/cinn/ir/lowered_func.cc +++ b/paddle/cinn/ir/lowered_func.cc @@ -27,7 +27,6 @@ #include "paddle/cinn/ir/buffer.h" #include "paddle/cinn/ir/utils/ir_printer.h" #include "paddle/cinn/ir/utils/ir_visitor.h" -#include "paddle/cinn/optim/tensor_write_tell.h" #include "paddle/cinn/runtime/intrinsic.h" #include "paddle/cinn/utils/string.h" diff --git a/paddle/cinn/lang/lower_impl.h b/paddle/cinn/lang/lower_impl.h index 3e52279b19566a..925050a03e35d3 100644 --- a/paddle/cinn/lang/lower_impl.h +++ b/paddle/cinn/lang/lower_impl.h @@ -34,7 +34,6 @@ #include "paddle/cinn/optim/optimize.h" #include "paddle/cinn/optim/remove_nested_block.h" #include "paddle/cinn/optim/replace_call_with_expr.h" -#include "paddle/cinn/optim/tensor_write_tell.h" #include "paddle/cinn/optim/transform_gpu_forloop.h" #include "paddle/cinn/optim/transform_polyfor_to_for.h" #include "paddle/cinn/poly/ast_gen.h" diff --git a/paddle/cinn/optim/CMakeLists.txt b/paddle/cinn/optim/CMakeLists.txt index 45c38c26327170..45f2fbcdd8327d 100755 --- a/paddle/cinn/optim/CMakeLists.txt +++ b/paddle/cinn/optim/CMakeLists.txt @@ -7,7 +7,6 @@ gather_srcs( replace_call_with_expr.cc ir_replace.cc replace_var_with_expr.cc - tensor_write_tell.cc ir_simplify.cc optimize.cc vectorize_loops.cc diff --git a/paddle/cinn/optim/vectorize_loops.cc b/paddle/cinn/optim/vectorize_loops.cc index 25a543788e4b4d..2f3a9b29a35677 100644 --- a/paddle/cinn/optim/vectorize_loops.cc +++ b/paddle/cinn/optim/vectorize_loops.cc @@ -31,7 +31,6 @@ #include "paddle/cinn/ir/utils/ir_printer.h" #include "paddle/cinn/optim/ir_replace.h" #include "paddle/cinn/optim/ir_simplify.h" -#include "paddle/cinn/optim/tensor_write_tell.h" #include "paddle/cinn/optim/unroll_loops.h" #include "paddle/cinn/utils/functional.h" From 23cda6aa81553e85d9985668d79031a5df0494e8 Mon Sep 17 00:00:00 2001 From: Courtesy-Xs Date: Fri, 8 Sep 2023 06:31:36 +0000 Subject: [PATCH 7/7] add relative headfile for lower_impl.h --- paddle/cinn/lang/lower_impl.h | 1 + 1 file changed, 1 insertion(+) diff --git a/paddle/cinn/lang/lower_impl.h b/paddle/cinn/lang/lower_impl.h index 925050a03e35d3..638edd33ba638e 100644 --- a/paddle/cinn/lang/lower_impl.h +++ b/paddle/cinn/lang/lower_impl.h @@ -27,6 +27,7 @@ #include "paddle/cinn/common/graph_utils.h" #include "paddle/cinn/ir/buffer.h" +#include "paddle/cinn/ir/utils/ir_mutator.h" #include "paddle/cinn/ir/utils/ir_printer.h" #include "paddle/cinn/optim/buffer_assign.h" #include "paddle/cinn/optim/compute_inline_expand.h"