Skip to content

Commit

Permalink
[CINN] Make New ASTGen Without ISL and Stage (#57207)
Browse files Browse the repository at this point in the history
Make ast_gen_ius::AstGen a replacement for the old ISL-based AstGen, and ast_gen_ius::TensorGroup a replacement for the tensor-relation data structures, the old Stage and StageMap; however, we remove the old-style schedule functions of Stage.
  • Loading branch information
zhhsplendid authored Sep 19, 2023
1 parent 0eaf59a commit a79ba09
Show file tree
Hide file tree
Showing 12 changed files with 333 additions and 89 deletions.
112 changes: 96 additions & 16 deletions paddle/cinn/ast_gen_ius/ast_gen.cc
Original file line number Diff line number Diff line change
Expand Up @@ -18,34 +18,114 @@
#include "paddle/cinn/ir/operation.h"
#include "paddle/cinn/ir/tensor.h"
#include "paddle/cinn/ir/utils/ir_printer.h"
#include "paddle/cinn/lang/compute.h"

namespace cinn {
namespace ast_gen_ius {

/**
 * Converts the compute body of `tensor` into a Store statement at the given
 * axis indices. If `body` is a Reduce node, the stored value combines the
 * tensor's current value with the reduce body according to the reduce type
 * (sum/mul/max/min/all/any); otherwise the body is stored directly.
 *
 * Note: a stray leftover line from the previous one-argument
 * `AstGen::Build(const ir::Tensor&)` signature (a diff artifact) was removed
 * from above this function; it made the file uncompilable.
 */
ir::Expr ConvertReduceBody(ir::Expr body,
                           ir::Tensor tensor,
                           const std::vector<Expr>& axis_exprs) {
  ir::Reduce* reduce_node = body.As<ir::Reduce>();
  if (!reduce_node) {
    // Not a reduction: plain element-wise store.
    return ir::Store::Make(tensor, body, axis_exprs);
  }

  switch (reduce_node->reduce_type) {
    case ir::Reduce::kSum:
      return ir::Store::Make(
          tensor, tensor(axis_exprs) + reduce_node->body, axis_exprs);
    case ir::Reduce::kMul:
      return ir::Store::Make(
          tensor, tensor(axis_exprs) * reduce_node->body, axis_exprs);
    case ir::Reduce::kMax:
      return ir::Store::Make(
          tensor,
          ir::Max::Make(tensor(axis_exprs), reduce_node->body),
          axis_exprs);
    case ir::Reduce::kMin:
      return ir::Store::Make(
          tensor,
          ir::Min::Make(tensor(axis_exprs), reduce_node->body),
          axis_exprs);
    case ir::Reduce::kAll:
      return ir::Store::Make(
          tensor, tensor(axis_exprs) && reduce_node->body, axis_exprs);
    case ir::Reduce::kAny:
      return ir::Store::Make(
          tensor, tensor(axis_exprs) || reduce_node->body, axis_exprs);
    default:
      // Unsupported reduce types abort here.
      CINN_NOT_IMPLEMENTED
  }
}

/**
 * Builds the AST (nested For loops around a Store) that computes `tensor`.
 *
 * For a reduce tensor, this also creates a companion "<name>__reduce_init"
 * tensor that initializes the output over the non-reduce domain, registers it
 * in `tensor_group` (inserted, buffer-shared with `tensor`, and made a control
 * dependency of `tensor`), and emits {init; reduce-loops} inside the common
 * axis loops. For a non-reduce tensor it simply wraps the Store in one For
 * loop per axis.
 *
 * Fix: the previous text of this span was a corrupted diff overlay — it
 * contained a duplicated CHECK_EQ, the stale old single-argument body fused
 * before the reduce branch, and an unreachable trailing `return body;`.
 * This is the reconstructed, compilable new-version function.
 */
ir::Expr AstGen::Build(const ir::Tensor& tensor, TensorGroup* tensor_group) {
  const std::vector<ir::Var>& axis = tensor->axis();
  const std::vector<ir::Expr>& shape = tensor->shape;
  size_t axis_len = axis.size();
  CHECK_EQ(shape.size(), axis_len) << "Internal Error: Tensor has different "
                                      "shape and axis length in AstGen";
  std::vector<ir::Expr> axis_exprs;
  for (const auto& a : axis) {
    axis_exprs.push_back(a);
  }

  if (tensor->is_reduce_tensor()) {
    // Make an init Tensor for domain without reduce axis
    Expr init_value = tensor->GetReduceInitVal();
    // TODO(zhhsplendid): Clean the handcoded "__reduce_init" string
    std::string reduce_init_name = tensor->name + "__reduce_init";
    const std::vector<Expr>& domain = tensor->domain_without_reduce_axis();
    ir::Tensor init_tensor = lang::Compute(
        domain,
        [=](const std::vector<Expr>& axis) { return init_value; },
        reduce_init_name);
    tensor_group->Insert(init_tensor);
    tensor_group->MarkShareMemBuffer(tensor, init_tensor);
    tensor_group->CtrlDepend(tensor, init_tensor);
    Expr init_body = ir::Store::Make(init_tensor, init_value, axis_exprs);

    // For the remaining reduce axis, make reduce body
    const std::vector<ir::Var>& reduce_axis = tensor->reduce_axis;
    ir::Expr reduce_body =
        ConvertReduceBody(tensor->body(), tensor, axis_exprs);
    // Wrap the reduce store in one For per reduce axis, innermost first.
    for (int i = static_cast<int>(reduce_axis.size()) - 1; i >= 0; --i) {
      reduce_body = ir::For::Make(reduce_axis[i],
                                  reduce_axis[i]->lower_bound,
                                  reduce_axis[i]->upper_bound,
                                  ir::ForType::Serial,
                                  ir::DeviceAPI::Host,
                                  ir::Block::Make({reduce_body}));
    }

    // Put the two parts together, then wrap with the common axis loops.
    ir::Expr body = ir::Block::Make({init_body, reduce_body});
    for (int i = static_cast<int>(axis_len) - 1; i >= 0; --i) {
      ir::Var loop_var = axis[i];
      ir::Expr loop_extent = shape[i];
      // The innermost loop reuses the two-statement Block directly instead of
      // nesting it in another single-element Block.
      body = ir::For::Make(
          loop_var,
          Expr(0),
          loop_extent,
          ir::ForType::Serial,
          ir::DeviceAPI::Host,
          i == static_cast<int>(axis_len) - 1 ? body : ir::Block::Make({body}));
    }
    return body;
  } else {
    // Non-reduce tensor: a single Store wrapped by one For loop per axis.
    ir::Expr body = ir::Store::Make(tensor, tensor->body(), axis_exprs);
    for (int i = static_cast<int>(axis_len) - 1; i >= 0; --i) {
      ir::Var loop_var = axis[i];
      ir::Expr loop_extent = shape[i];
      body = ir::For::Make(loop_var,
                           Expr(0),
                           loop_extent,
                           ir::ForType::Serial,
                           ir::DeviceAPI::Host,
                           ir::Block::Make({body}));
    }
    return body;
  }
}

} // namespace ast_gen_ius
Expand Down
2 changes: 1 addition & 1 deletion paddle/cinn/ast_gen_ius/ast_gen.h
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@ namespace ast_gen_ius {

/**
 * Generates the AST body (nested loops + stores) for a tensor's computation.
 * Replaces the old ISL-based AST generation. The removed stale one-argument
 * Build declaration was a diff artifact with no matching definition.
 */
class AstGen {
 public:
  /**
   * Builds and returns the AST expression computing `tensor`. Auxiliary
   * tensors created during generation (e.g. reduce-init tensors) are
   * registered into `tensor_group`.
   */
  static ir::Expr Build(const ir::Tensor& tensor, TensorGroup* tensor_group);
};

} // namespace ast_gen_ius
Expand Down
4 changes: 3 additions & 1 deletion paddle/cinn/ast_gen_ius/ast_gen_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
#include <vector>

#include "paddle/cinn/ast_gen_ius/ast_gen.h"
#include "paddle/cinn/ast_gen_ius/tensor_group.h"
#include "paddle/cinn/ir/ir.h"
#include "paddle/cinn/ir/ir_base.h"
#include "paddle/cinn/ir/tensor.h"
Expand All @@ -36,7 +37,8 @@ TEST(AstGen, Build) {
shape,
[&](const std::vector<Expr>& indice) { return lang::Relu(A(indice), 0); },
"relu_test");
Expr out = AstGen::Build(B);
TensorGroup tensor_group({B});
Expr out = AstGen::Build(B, &tensor_group);
LOG(INFO) << out;
}

Expand Down
27 changes: 10 additions & 17 deletions paddle/cinn/ast_gen_ius/tensor_group.cc
Original file line number Diff line number Diff line change
Expand Up @@ -75,10 +75,12 @@ std::vector<ir::Tensor> TensorGroup::GetGenFuncTopoOrder(
}

std::vector<ir::Tensor> ret;
std::vector<std::string> stack;

// Using set instead of vector/stack in order to get fix alaphbeta order topo
std::set<std::string> node_set;
for (const auto& name_tensor : name_to_tensor_) {
if (!in_degree.count(name_tensor.first)) {
stack.emplace_back(name_tensor.first);
node_set.insert(name_tensor.first);
}
}

Expand All @@ -90,9 +92,9 @@ std::vector<ir::Tensor> TensorGroup::GetGenFuncTopoOrder(
input_arg_names.erase(name);
}

while (!stack.empty()) {
const std::string& cur = stack.back();
stack.pop_back();
while (!node_set.empty()) {
const std::string cur = *(node_set.begin());
node_set.erase(node_set.begin());

if (!input_arg_names.count(cur)) {
ret.push_back(name_to_tensor_[cur]);
Expand All @@ -103,23 +105,14 @@ std::vector<ir::Tensor> TensorGroup::GetGenFuncTopoOrder(
if (dep_tensor_names.count(cur)) {
--in_degree[dep_pair.first];
if (in_degree[dep_pair.first] == 0) {
stack.emplace_back(dep_pair.first);
node_set.insert(dep_pair.first);
}
}
}
}
return ret;
}

// Returns whether the tensor with `tensor_name` has been marked as needing a
// reduce-init statement.
bool TensorGroup::HasMarkedReduceInit(const std::string& tensor_name) const {
  return tensor_name_needs_reduce_init_.find(tensor_name) !=
         tensor_name_needs_reduce_init_.end();
}

// Marks the tensor named `tensor_name` as needing a reduce-init statement and
// returns that tensor.
//
// Bug fix: the function is declared to return ir::Tensor but the original body
// had no return statement, which is undefined behavior in C++ for a
// value-returning function. Return the marked tensor so callers get a value.
ir::Tensor TensorGroup::MarkReduceInit(const std::string& tensor_name) {
  // TODO(zhhsplendid): add check that `tensor_name` is already in the group
  tensor_name_needs_reduce_init_.insert(tensor_name);
  return name_to_tensor_[tensor_name];
}

void TensorGroup::CtrlDepend(const ir::Tensor& tensor,
const ir::Tensor& to_dep) {
ctrl_dep_[tensor->name].insert(to_dep->name);
Expand Down Expand Up @@ -156,8 +149,8 @@ std::string TensorGroup::GetShareMemRootName(const std::string& tensor_name) {
return share_memory_tensor_[tensor_name];
}

// Marks that `tensor` and `to_share` share a memory buffer, by unioning their
// Union-Find roots in share_memory_tensor_. No actual allocation happens here.
//
// Fix: the span previously contained BOTH the old `ShareMemoryBuffer`
// signature and the renamed `MarkShareMemBuffer` one fused together (a diff
// overlay artifact), which could not compile; only the renamed definition is
// kept.
void TensorGroup::MarkShareMemBuffer(const ir::Tensor& tensor,
                                     const ir::Tensor& to_share) {
  share_memory_tensor_[GetShareMemRootName(to_share->name)] =
      GetShareMemRootName(tensor->name);
}
Expand Down
71 changes: 56 additions & 15 deletions paddle/cinn/ast_gen_ius/tensor_group.h
Original file line number Diff line number Diff line change
Expand Up @@ -28,54 +28,95 @@
namespace cinn {
namespace ast_gen_ius {

/* Collection used for Tensors, used in AST generation */
/**
* Collection which maintains the relation between Tensor(s) such as control
* dependency, memory sharing ... it is used in AST generation
*/
class TensorGroup {
public:
/**
* Constructor for a TensorGroup, the argument tensors should be output tensor
* arguments of the AST body to be generated. The dependent tensors of the
* output tensors will be collected during construction.
*/
explicit TensorGroup(const std::vector<ir::Tensor>& tensors);

/**
* Destructor.
*/
~TensorGroup();

/**
* Returns true if TensorGroup collection contains a tensor with input name.
*/
bool Contain(const std::string& name) const;

/**
* Insert a Tensor into TensorGroup collection.
*/
void Insert(const ir::Tensor& tensor);

/**
* Returns the Tensor in TensorGroup collection with the given name.
*/
ir::Tensor Get(const std::string& name);

/**
* Returns all Tensors in TensorGroup.
*/
std::set<ir::Tensor> GetAllTensors();

/**
* Mark `tensor` depends on `to_dep`.
*/
void CtrlDepend(const ir::Tensor& tensor, const ir::Tensor& to_dep);

/**
* Get all tensors which the tensor with given name depends on.
*/
std::set<ir::Tensor> GetCrtlDepTensors(const std::string& tensor_name);

/**
* Get Union-Find set algorithm root tensor name which shares memory with the
* tensor whose name is the input.
*/
std::string GetShareMemRootName(const std::string& tensor_name);

void ShareMemoryBuffer(const ir::Tensor& tensor, const ir::Tensor& to_share);
/**
* Mark two tensors share memory, it only marks using Union-Find set
* algorithm, doesn't do really memory sharing/allocation
*/
void MarkShareMemBuffer(const ir::Tensor& tensor, const ir::Tensor& to_share);

/**
* Allocate buffers for Tensors in TensorGroup, it handles the shared memory
* using Union-Find set algorithm.
*/
absl::flat_hash_map<std::string, ir::Tensor> AllocateBuffers();

// Returns tensors in topological order and remove those args
// Becuase the order is used for generating function body, we don't have to
// generate args
/**
* Returns tensors in topological order and remove those args
* Becuase the order is used for generating function body, we don't have to
* generate args
*/
std::vector<ir::Tensor> GetGenFuncTopoOrder(
const std::vector<ir::Tensor>& func_args = {});

bool HasMarkedReduceInit(const std::string& tensor_name) const;

// Marks a tensor needs to do reduce init
ir::Tensor MarkReduceInit(const std::string& tensor_name);

private:
/** collection of output tensor names */
std::set<std::string> output_tensor_names_;

/** collection of all tensors in this TensorGroup */
absl::flat_hash_map<std::string, ir::Tensor> name_to_tensor_;

// Stores vector of tensor names, which the key tensor depends on
/** Stores vector of tensor names, which the key tensor depends on */
std::unordered_map<std::string, std::unordered_set<std::string>> ctrl_dep_;

// Keeps Union Find Set style, each tensor name whose buffer is shared maps to
// the same name tensor
/**
* Keeps Union Find Set style, each tensor name whose buffer is shared, maps
* to the same name tensor.
*/
std::unordered_map<std::string, std::string> share_memory_tensor_;

std::unordered_set<std::string> tensor_name_needs_reduce_init_;
};

} // namespace ast_gen_ius
Expand Down
Loading

0 comments on commit a79ba09

Please sign in to comment.