【CINN】Rewrite tensor_writer_teller (#57019)
* rewrite tensor writer

* add modification

* fix unique target find

* fix bug

* delete tensor_write file

* fix bugs from deleting tensor_write

* add related header file for lower_impl.h
Courtesy-Xs authored Sep 12, 2023
1 parent 6d936f9 commit 3a34a9c
Showing 9 changed files with 40 additions and 97 deletions.
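The commit replaces the stateful helper class optim::TensorWriteTeller with a free function, ir::CollectTensorNeedsWrite, that returns a std::set<std::string> containing the names of all tensors that are written to. A minimal before/after sketch of the call-site pattern, reconstructed from the hunks below (illustrative only, not verbatim from the commit; `body` and `tensor` are assumed to be in scope):

// Before this commit: build a teller object, then query it per tensor.
optim::TensorWriteTeller write_teller;
write_teller.Collect(&body);
bool is_const = !write_teller.IsWrite(tensor->name);

// After this commit: collect the written-tensor names once, then query with std::set::count.
auto write_teller = ir::CollectTensorNeedsWrite(&body);
bool is_const = !write_teller.count(tensor->name);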
11 changes: 4 additions & 7 deletions paddle/cinn/ir/lowered_func.cc
@@ -27,7 +27,6 @@
#include "paddle/cinn/ir/buffer.h"
#include "paddle/cinn/ir/utils/ir_printer.h"
#include "paddle/cinn/ir/utils/ir_visitor.h"
#include "paddle/cinn/optim/tensor_write_tell.h"
#include "paddle/cinn/runtime/intrinsic.h"
#include "paddle/cinn/utils/string.h"

@@ -209,8 +208,7 @@ void _LoweredFunc_::AllocTempBuffer() {}
void _LoweredFunc_::PrepareBufferCastExprs(bool with_expr_gen_tensor) {
buffer_data_cast_exprs.clear();
// collect write.
optim::TensorWriteTeller write_teller;
write_teller.Collect(&body);
auto write_teller = ir::CollectTensorNeedsWrite(&body);

auto tensors = CollectAllTensorReference(with_expr_gen_tensor);
std::sort(tensors.begin(),
@@ -224,7 +222,7 @@ void _LoweredFunc_::PrepareBufferCastExprs(bool with_expr_gen_tensor) {
if (!tensor->buffer.defined()) continue;

Type value_type = tensor->type().ElementOf();
bool is_const = !write_teller.IsWrite(tensor->name);
bool is_const = !write_teller.count(tensor->name);
value_type.set_cpp_handle();
value_type.set_cpp_const(is_const);
Var variable = _Var_::Make(tensor->name, value_type);
@@ -250,8 +248,7 @@ std::vector<Expr> _LoweredFunc_::CudaAliasVarExprs() const {
}
// collect write.
std::vector<Expr> res;
optim::TensorWriteTeller write_teller;
write_teller.Collect(&body);
auto write_teller = ir::CollectTensorNeedsWrite(&body);

auto tensors = CollectAllTensorReference();
std::sort(tensors.begin(),
@@ -269,7 +266,7 @@ std::vector<Expr> _LoweredFunc_::CudaAliasVarExprs() const {
continue;
}
Type value_type = tensor->type().ElementOf();
bool is_const = !write_teller.IsWrite(tensor->name);
bool is_const = !write_teller.count(tensor->name);
value_type.set_cpp_handle();
value_type.set_cpp_const(is_const);
Var variable = _Var_::Make(tensor->name, value_type);
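In PrepareBufferCastExprs and CudaAliasVarExprs above, the write set determines whether a buffer handle is emitted as a const pointer. A minimal sketch of that decision, assuming `tensor` and `write_teller` are in scope as in the hunks above:

Type value_type = tensor->type().ElementOf();
bool is_const = !write_teller.count(tensor->name);  // never stored to, so treat as read-only
value_type.set_cpp_handle();                        // mark the type as a C++ pointer handle
value_type.set_cpp_const(is_const);                 // read-only buffers become const pointers
Var variable = _Var_::Make(tensor->name, value_type);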
25 changes: 25 additions & 0 deletions paddle/cinn/ir/utils/ir_nodes_collector.cc
@@ -207,5 +207,30 @@ std::set<Expr> CollectReferencedTensors(
return ts0;
}

std::set<std::string> CollectTensorNeedsWrite(const Expr* e) {
std::set<std::string> tensor_written;
IrNodesCollector::handler_t handler = [&](const Expr* x) {
if (x->As<ir::Store>()) {
tensor_written.insert(
x->As<ir::Store>()->tensor.As<ir::_Tensor_>()->name);
}
if (x->As<ir::_Tensor_>()) {
tensor_written.insert(x->As<ir::_Tensor_>()->name);
}
};
IrNodesCollector::teller_t teller = [](const Expr* x) {
if (x->As<ir::Store>() && x->As<ir::Store>()->tensor.As<ir::_Tensor_>()) {
return true;
}
if (x->As<ir::_Tensor_>() && x->As<ir::_Tensor_>()->is_call_node()) {
return true;
}
return false;
};
IrNodesCollector collector(std::move(teller), std::move(handler), false);
collector.Visit(e);
return tensor_written;
}

} // namespace ir
} // namespace cinn
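The new collector follows the existing teller/handler pattern of IrNodesCollector: the teller selects Store nodes whose target is a tensor (direct writes) and tensors that are call nodes (presumably because extern-call results are written through call arguments), and the handler records the tensor name. A small, hypothetical usage sketch (ReportWrittenTensors is not part of the commit; it assumes the usual CINN headers and glog are available):

// Hypothetical helper: log which tensors a lowered function body writes to.
void ReportWrittenTensors(const ir::Expr& body) {
  std::set<std::string> written = ir::CollectTensorNeedsWrite(&body);
  for (const std::string& name : written) {
    VLOG(3) << "tensor " << name << " is written";
  }
}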
2 changes: 2 additions & 0 deletions paddle/cinn/ir/utils/ir_nodes_collector.h
@@ -65,5 +65,7 @@ std::map<std::string, Expr> CollectTensorMap(
return true;
});

std::set<std::string> CollectTensorNeedsWrite(const Expr* e);

} // namespace ir
} // namespace cinn
12 changes: 5 additions & 7 deletions paddle/cinn/lang/lower_impl.cc
@@ -342,8 +342,7 @@ std::vector<ir::Argument> LowerImpl::GenerateFunctionArgumentList(
CheckArgsUnique();

std::vector<ir::Argument> args;
optim::TensorWriteTeller teller;
teller.Collect(&fn_body);
auto teller = ir::CollectTensorNeedsWrite(&fn_body);

std::set<std::string> arg_names;

@@ -358,7 +357,7 @@ std::vector<ir::Argument> LowerImpl::GenerateFunctionArgumentList(

for (auto& tensor : tensor_args_) {
auto* tensor_node = tensor.As<ir::_Tensor_>();
bool is_output = teller.IsWrite(tensor->name);
bool is_output = teller.count(tensor->name);
VLOG(1) << "tensor argument " << tensor->name << " buffer "
<< tensor->buffer->name;

@@ -396,8 +395,7 @@ std::vector<ir::Argument> LowerImpl::GenFuncArgForSplitKernel(

std::vector<ir::Argument> in_args;
std::vector<ir::Argument> out_args;
optim::TensorWriteTeller teller;
teller.Collect(&func_iterator);
auto teller = ir::CollectTensorNeedsWrite(&func_iterator);
std::set<std::string> arg_names;
std::set<std::string> all_tensor_names;

@@ -448,7 +446,7 @@ std::vector<ir::Argument> LowerImpl::GenFuncArgForSplitKernel(
VLOG(3) << "In tensor_args_, it has : " << tensor->name;
if (temp_tensor_names.count(tensor->name) > 0) continue;
if (all_tensor_names.count(tensor->name) == 0) continue;
bool is_output = teller.IsWrite(tensor->name);
bool is_output = teller.count(tensor->name);
VLOG(3) << "tensor argument " << tensor->name << " buffer "
<< tensor->buffer->name;

@@ -485,7 +483,7 @@ std::vector<ir::Argument> LowerImpl::GenFuncArgForSplitKernel(
VLOG(3) << "Tensor " << tensor->name;
if (tensor->buffer.defined() && !arg_names.count(tensor->buffer->name)) {
bool is_output =
teller.IsWrite(tensor->name) && teller.IsWrite(tensor->name);
teller.count(tensor->name) && teller.count(tensor->name);
if (is_output)
out_args.emplace_back(tensor->buffer, ir::Argument::IO::kOutput);
}
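In GenerateFunctionArgumentList and GenFuncArgForSplitKernel above, the same write set classifies each tensor argument's buffer as an input or an output. A minimal sketch of that classification, assuming `teller`, `tensor`, and `args` are in scope as in the hunks above (ir::Argument::IO::kInput is assumed as the counterpart of kOutput):

bool is_output = teller.count(tensor->name);  // tensors that are written become output buffers
auto io = is_output ? ir::Argument::IO::kOutput : ir::Argument::IO::kInput;
args.emplace_back(tensor->buffer, io);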
2 changes: 1 addition & 1 deletion paddle/cinn/lang/lower_impl.h
@@ -27,13 +27,13 @@

#include "paddle/cinn/common/graph_utils.h"
#include "paddle/cinn/ir/buffer.h"
#include "paddle/cinn/ir/utils/ir_mutator.h"
#include "paddle/cinn/ir/utils/ir_printer.h"
#include "paddle/cinn/optim/buffer_assign.h"
#include "paddle/cinn/optim/compute_inline_expand.h"
#include "paddle/cinn/optim/fold_cinn_call_arguments.h"
#include "paddle/cinn/optim/optimize.h"
#include "paddle/cinn/optim/replace_call_with_expr.h"
#include "paddle/cinn/optim/tensor_write_tell.h"
#include "paddle/cinn/optim/transform_gpu_forloop.h"
#include "paddle/cinn/optim/transform_polyfor_to_for.h"
#include "paddle/cinn/poly/ast_gen.h"
1 change: 0 additions & 1 deletion paddle/cinn/optim/CMakeLists.txt
@@ -6,7 +6,6 @@ gather_srcs(
replace_call_with_expr.cc
ir_replace.cc
replace_var_with_expr.cc
tensor_write_tell.cc
ir_simplify.cc
optimize.cc
vectorize_loops.cc
19 changes: 0 additions & 19 deletions paddle/cinn/optim/tensor_write_tell.cc

This file was deleted.

58 changes: 0 additions & 58 deletions paddle/cinn/optim/tensor_write_tell.h

This file was deleted.

7 changes: 3 additions & 4 deletions paddle/cinn/optim/vectorize_loops.cc
@@ -31,7 +31,6 @@
#include "paddle/cinn/ir/utils/ir_printer.h"
#include "paddle/cinn/optim/ir_replace.h"
#include "paddle/cinn/optim/ir_simplify.h"
#include "paddle/cinn/optim/tensor_write_tell.h"
#include "paddle/cinn/optim/unroll_loops.h"
#include "paddle/cinn/utils/functional.h"

@@ -185,7 +184,7 @@ class CudaVectorizer : public IRMutator<Expr *> {
const Var iter_var_; // the loop var of the vectorized loop
const int factor_; // the factor for vectorize

TensorWriteTeller write_teller_;
std::set<std::string> write_teller_;
TensorVectorizeTeller vectorized_teller_;

absl::flat_hash_map<std::string, Var> tensor2vectorized_vars_;
@@ -215,7 +214,7 @@ class CudaVectorizer : public IRMutator<Expr *> {
}

void Visit(Expr *expr) {
write_teller_.Collect(expr);
write_teller_ = ir::CollectTensorNeedsWrite(expr);
vectorized_teller_.Collect(expr);
IRMutator<Expr *>::Visit(expr, expr);
}
@@ -289,7 +288,7 @@ class CudaVectorizer : public IRMutator<Expr *> {
const std::vector<Expr> &indices,
bool is_store) {
auto *node = tensor.As<ir::_Tensor_>();
bool is_const = !write_teller_.IsWrite(node->name);
bool is_const = !write_teller_.count(node->name);

// generate the corresponding vector type
Type scalar_type = tensor->type().ElementOf();
