Skip to content

Commit

Permalink
[skip ci] use a struct_metal style struct compiler
Browse files Browse the repository at this point in the history
  • Loading branch information
archibate committed Feb 19, 2020
1 parent c6d27ee commit fddb46f
Show file tree
Hide file tree
Showing 11 changed files with 282 additions and 96 deletions.
1 change: 0 additions & 1 deletion cmake/TaichiCore.cmake
Original file line number Diff line number Diff line change
Expand Up @@ -88,7 +88,6 @@ if (TI_WITH_CUDA)
target_link_libraries(${LIBRARY_NAME} ${llvm_ptx_libs})
endif()

#target_link_libraries(${LIBRARY_NAME} /usr/lib/libGL.so GL)
target_link_libraries(${LIBRARY_NAME} /usr/lib/libGLEW.so GLEW)
target_link_libraries(${LIBRARY_NAME} /usr/lib/libglfw.so glfw)

Expand Down
3 changes: 2 additions & 1 deletion python/taichi/core/util.py
Original file line number Diff line number Diff line change
Expand Up @@ -204,7 +204,8 @@ def build():
print('Warning: taichi_core.so already removed. This may be caused by '
'simultaneously starting two taichi instances.')
pass
shutil.copy('libtaichi_core.so', 'taichi_core.so')
shutil.copy('libtaichi_core.so', '/tmp/taichi_core.so')
os.symlink('/tmp/taichi_core.so', 'taichi_core.so')
try:
import_tc_core()
except Exception as e:
Expand Down
133 changes: 43 additions & 90 deletions taichi/backends/codegen_opengl.cpp
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
#include "codegen_opengl.h"
#include <taichi/platform/opengl/opengl_api.h>
#include <taichi/platform/opengl/opengl_data_types.h>

#include <string>
#include <taichi/ir.h>
Expand All @@ -8,73 +9,16 @@ TLANG_NAMESPACE_BEGIN
namespace opengl {
namespace {

std::string opengl_data_type_name(DataType dt)
{
switch (dt) {
case DataType::f32:
return "float";
case DataType::i32:
return "int";
case DataType::u32:
return "uint";
default:
TI_NOT_IMPLEMENTED;
break;
}
return "";
}

std::string opengl_unary_op_type_symbol(UnaryOpType type)
{
switch (type)
{
case UnaryOpType::neg:
return "-";
case UnaryOpType::sqrt:
return "sqrt";
case UnaryOpType::floor:
return "floor";
case UnaryOpType::ceil:
return "ceil";
case UnaryOpType::abs:
return "abs";
case UnaryOpType::sgn:
return "sign";
case UnaryOpType::sin:
return "sin";
case UnaryOpType::asin:
return "asin";
case UnaryOpType::cos:
return "cos";
case UnaryOpType::acos:
return "acos";
case UnaryOpType::tan:
return "tan";
case UnaryOpType::tanh:
return "tanh";
case UnaryOpType::exp:
return "exp";
case UnaryOpType::log:
return "log";
default:
TI_NOT_IMPLEMENTED;
}
return "";
}

bool is_opengl_binary_op_infix(BinaryOpType type)
{
return !((type == BinaryOpType::min) || (type == BinaryOpType::max) ||
(type == BinaryOpType::atan2) || (type == BinaryOpType::pow));
}

class KernelGen : public IRVisitor
{
Kernel *kernel;

public:
KernelGen(Kernel *kernel, std::string kernel_name)
: kernel(kernel), kernel_name_(kernel_name),
KernelGen(Kernel *kernel, std::string kernel_name,
const StructCompiledResult *struct_compiled)
: kernel(kernel),
struct_compiled_(struct_compiled),
kernel_name_(kernel_name),
glsl_kernel_prefix_(kernel_name)
{
allow_undefined_visitor = true;
Expand All @@ -86,6 +30,7 @@ class KernelGen : public IRVisitor
std::string indent_;
bool is_top_level_{true};

const StructCompiledResult *struct_compiled_;
const SNode *root_snode_;
GetRootStmt *root_stmt_;
std::string kernel_name_;
Expand Down Expand Up @@ -116,9 +61,8 @@ class KernelGen : public IRVisitor
{
emit("#version 430 core");
emit("#extension GL_ARB_compute_shader: enable");
emit("");
emit("{}", struct_compiled_->source_code);
emit("layout(local_size_x = 1, local_size_y = 1, local_size_z = 1) in;");
emit("");
emit("layout(std430, binding = 0) buffer data");
emit("{{");
emit(" int _args_[{}];", taichi_max_num_args);
Expand All @@ -129,7 +73,7 @@ class KernelGen : public IRVisitor

void generate_bottom()
{
// TODO: <kernel_name>() really necessary? How about just main()?
// TODO(archibate): <kernel_name>() really necessary? How about just main()?
emit("void main()");
emit("{{");
emit(" {}();", glsl_kernel_name_);
Expand All @@ -156,43 +100,50 @@ class KernelGen : public IRVisitor
emit("const uint {} = {};", stmt->raw_name(), val);
}

void visit(OffsetAndExtractBitsStmt *stmt) override
{
emit("uint {} = ((({} + {}) >> {}) & ((1 << {}) - 1));",
stmt->raw_name(), stmt->offset, stmt->input->raw_name(),
stmt->bit_begin, stmt->bit_end - stmt->bit_begin);
}

void visit(GetRootStmt *stmt) override
{
// Should we assert |root_stmt_| is assigned only once?
root_stmt_ = stmt;
emit("const uint {} = 0;", stmt->raw_name());
emit("{} {} = 0;", root_snode_type_name_, stmt->raw_name());
}

void visit(SNodeLookupStmt *stmt) override
{
std::string parent;
Stmt *parent;
std::string parent_type;
if (stmt->input_snode) {
parent = stmt->input_snode->raw_name();
parent = stmt->input_snode;
parent_type = stmt->snode->node_type_name;
} else {
TI_ASSERT(root_stmt_ != nullptr);
parent = root_stmt_->raw_name();
parent = root_stmt_;
parent_type = root_snode_type_name_;
}

int stride = 1; // XXX
emit("const uint {} = {} + {} * {};",
stmt->raw_name(), parent, stride, stmt->input_index->raw_name());
}

void visit(OffsetAndExtractBitsStmt *stmt) override
{
emit("uint {} = ((({} + {}) >> {}) & ((1 << {}) - 1));",
stmt->raw_name(), stmt->offset, stmt->input->raw_name(),
stmt->bit_begin, stmt->bit_end - stmt->bit_begin);
emit("{}_ch {} = {}_children({}, {});", stmt->snode->node_type_name,
stmt->raw_name(), parent_type, parent->raw_name(),
stmt->input_index->raw_name());
}

void visit(GetChStmt *stmt) override
{
if (stmt->output_snode->is_place()) {
emit("const uint {} = {} + 1 * {}; // placed",
stmt->raw_name(), stmt->input_ptr->raw_name(), stmt->chid);
emit("{} /* place {} */ {} = {}_get{}({});",
stmt->output_snode->node_type_name,
opengl_data_type_name(stmt->output_snode->dt),
stmt->raw_name(), stmt->input_snode->node_type_name,
stmt->chid, stmt->input_ptr->raw_name());
} else {
emit("const uint {} = {} + 1 * {};",
stmt->raw_name(), stmt->input_ptr->raw_name(), stmt->chid);
emit("{} {} = {}_get{}({});", stmt->output_snode->node_type_name,
stmt->raw_name(), stmt->input_snode->node_type_name,
stmt->chid, stmt->input_ptr->raw_name());
}
}

Expand Down Expand Up @@ -299,7 +250,8 @@ class KernelGen : public IRVisitor
const_stmt->raw_name(), const_stmt->val[0].stringify());
}

void visit(ArgLoadStmt *stmt) override {
void visit(ArgLoadStmt *stmt) override
{
const auto dt = opengl_data_type_name(stmt->element_type());
if (stmt->is_ptr) {
emit("const {} {} = _args_[{}]; // is_ptr", dt, stmt->raw_name(), stmt->arg_id);
Expand All @@ -308,7 +260,8 @@ class KernelGen : public IRVisitor
}
}

void visit(ArgStoreStmt *stmt) override {
void visit(ArgStoreStmt *stmt) override
{
const auto dt = metal_data_type_name(stmt->element_type());
TI_ASSERT(!stmt->is_ptr);
emit("_args_[{}] = {};", stmt->arg_id, stmt->val->raw_name());
Expand Down Expand Up @@ -360,7 +313,7 @@ class KernelGen : public IRVisitor
root_snode_ = &root_snode;
root_snode_type_name_ = root_snode.node_type_name;
generate_header();
irpass::print(kernel->ir);
//irpass::print(kernel->ir);
kernel->ir->accept(this);
generate_bottom();
}
Expand Down Expand Up @@ -419,7 +372,7 @@ void OpenglCodeGen::lower()

if (kernel_->grad) {
irpass::demote_atomics(ir);
irpass::full_simplify(ir);
irpass::full_simplify(ir, prog_->config);
irpass::typecheck(ir);
if (print_ir) {
TI_TRACE("Before make_adjoint:");
Expand Down Expand Up @@ -469,7 +422,7 @@ void OpenglCodeGen::lower()
irpass::print(ir);
}

irpass::full_simplify(ir);
irpass::full_simplify(ir, prog_->config);
if (print_ir) {
TI_TRACE("Simplified II:");
irpass::re_id(ir);
Expand Down Expand Up @@ -504,10 +457,10 @@ void save_data(Context &ctx, void *data)

FunctionType OpenglCodeGen::gen(void)
{
KernelGen codegen(kernel_, kernel_name_);
KernelGen codegen(kernel_, kernel_name_, struct_compiled_);
codegen.run(*prog_->snode_root);
const std::string kernel_source_code = codegen.kernel_source_code();
TI_INFO("\n{}", kernel_source_code);
//TI_INFO("\n{}", kernel_source_code);

return [kernel_source_code](Context &ctx) {
void *data, *data_r;
Expand Down
6 changes: 4 additions & 2 deletions taichi/backends/codegen_opengl.h
Original file line number Diff line number Diff line change
Expand Up @@ -17,8 +17,9 @@ namespace opengl {

class OpenglCodeGen {
public:
OpenglCodeGen(const std::string &kernel_name)
: kernel_name_(kernel_name)
OpenglCodeGen(const std::string &kernel_name,
const StructCompiledResult *struct_compiled)
: kernel_name_(kernel_name), struct_compiled_(struct_compiled)
{}

FunctionType compile(Program &program, Kernel &kernel);
Expand All @@ -31,6 +32,7 @@ class OpenglCodeGen {

Program *prog_;
Kernel *kernel_;
const StructCompiledResult *struct_compiled_;
size_t global_tmps_buffer_size_{0};
};

Expand Down
91 changes: 91 additions & 0 deletions taichi/backends/struct_opengl.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,91 @@
#include "struct_opengl.h"

TLANG_NAMESPACE_BEGIN
namespace opengl {

// Compiles the SNode tree rooted at |node| into its textual type definitions.
//
// |node| must be the root SNode (asserted). All SNodes are first collected
// depth-first, parent before children; definitions are then generated in the
// reverse of that order, so every descendant is emitted before its ancestors.
// Returns the accumulated source text together with the computed root size.
OpenglStructCompiler::CompiledResult OpenglStructCompiler::run(SNode &node)
{
  TI_ASSERT(node.type == SNodeType::root);
  collect_snodes(node);
  // infer_snode_properties(node) is deliberately not called here: the host
  // side has already run it.

  // Walk the collected nodes back-to-front.
  for (auto it = snodes_.rbegin(); it != snodes_.rend(); ++it) {
    generate_types(**it);
  }

  CompiledResult compiled;
  compiled.source_code = std::move(src_code_);
  compiled.root_size = compute_snode_size(node);
  return compiled;
}

// Appends |snode| and then, recursively, each of its children to |snodes_|
// (depth-first, parent before children).
void OpenglStructCompiler::collect_snodes(SNode &snode) {
  snodes_.push_back(&snode);
  for (auto &child : snode.ch) {
    collect_snodes(*child);
  }
}
// TODO(archibate): do we really need to mirror struct_metal this closely?
// Emits the preprocessor definitions that describe a single SNode.
//
// Place nodes get a type alias and a `<name>_stride` equal to
// data_type_size(dt). Every other node first gets a `<name>_ch` alias, one
// `<name>_get<i>(a_)` accessor per child (child i is offset by the summed
// strides of children 0..i-1) and a `<name>_ch_stride` holding the sum of all
// child strides. Dense and root nodes additionally get `<name>_n`,
// `<name>_stride` and `<name>_children(a_, i)`; any other SNode type is
// reported as unsupported.
void OpenglStructCompiler::generate_types(const SNode &snode) {
  const auto &node_name = snode.node_type_name;

  // Leaf case: a placed value only needs its alias and element stride.
  if (snode.is_place()) {
    emit("");
    const auto dt_name = opengl_data_type_name(snode.dt);
    emit("#define {} uint // place {}", node_name, dt_name);
    emit("#define {}_stride {} // sizeof({})", node_name, data_type_size(snode.dt), dt_name);
    emit("");
    return;
  }

  // Children container: alias, per-child accessors and combined stride.
  const std::string ch_name = node_name + "_ch";
  emit("#define {} uint", ch_name);

  // Running "A_stride + B_stride + ..." expression over the children seen so
  // far; it is exactly the offset of the next child.
  std::string offset_expr;
  int child_idx = 0;
  for (const auto &child : snode.ch) {
    const auto &child_name = child->node_type_name;
    if (child_idx == 0) {
      emit("#define {}_get{}(a_) (a_) // {}",
           node_name, child_idx, child_name);
      offset_expr = child_name + "_stride";
    } else {
      emit("#define {}_get{}(a_) ((a_) + ({})) // {}",
           node_name, child_idx, offset_expr, child_name);
      offset_expr += " + " + child_name + "_stride";
    }
    ++child_idx;
  }
  if (offset_expr.empty()) {
    // No children were visited; fall back to a zero stride.
    offset_expr = "0";
  }
  emit("#define {}_stride ({})", ch_name, offset_expr);
  emit("");

  const bool is_dense = (snode.type == SNodeType::dense);
  if (is_dense || snode.type == SNodeType::root) {
    emit("#define {} uint // {}", node_name, snode_type_name(snode.type));
    // The root holds a single cell; a dense node holds snode.n of them.
    const int n = is_dense ? snode.n : 1;
    emit("#define {}_n {}", node_name, n);
    emit("#define {}_stride ({}_ch_stride * {}_n)", node_name, node_name, node_name);
    emit("#define {}_children(a_, i) ((a_) + {}_ch_stride * (i))", node_name, node_name);
  } else {
    TI_ERROR("SNodeType={} not supported on OpenGL",
             snode_type_name(snode.type));
    TI_NOT_IMPLEMENTED;
  }
  emit("");
}

// Recursively computes the size of the subtree rooted at |sn|.
//
// A place node contributes data_type_size(dt). Any other node contributes the
// summed size of its children, multiplied by sn.n when the node is dense and
// by 1 otherwise.
size_t OpenglStructCompiler::compute_snode_size(const SNode &sn) {
  if (!sn.is_place()) {
    size_t children_size = 0;
    for (const auto &child : sn.ch) {
      children_size += compute_snode_size(*child);
    }
    const int repeat = (sn.type == SNodeType::dense) ? sn.n : 1;
    return repeat * children_size;
  }
  return data_type_size(sn.dt);
}


} // namespace opengl
TLANG_NAMESPACE_END
Loading

0 comments on commit fddb46f

Please sign in to comment.