[type] [refactor] Remove DataType::data_type (#1960)
taichi-gardener authored Oct 15, 2020
1 parent 6b3db04 commit 3c2e5e9
Showing 23 changed files with 176 additions and 195 deletions.
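Every hunk in this commit makes the same mechanical substitution: call sites that used to unwrap the return type through ret_type.data_type now use ret_type directly, and comparisons are written against PrimitiveType values (stmt->ret_type == PrimitiveType::f32) rather than against the nested member. A minimal C++ sketch of the pattern, using hypothetical stand-in types rather than the real Taichi declarations:

// Hypothetical stand-ins, not the actual Taichi classes; this only
// illustrates the call-site change applied throughout this commit.
#include <string>

enum class PrimitiveTypeSketch { i32, f32, f64 };

struct DataTypeSketch {
  PrimitiveTypeSketch prim;
  // Comparing the DataType handle against a primitive type directly is what
  // lets call sites drop the old .data_type unwrapping.
  bool operator==(PrimitiveTypeSketch p) const { return prim == p; }
};

struct StmtSketch {
  DataTypeSketch ret_type;  // previously exposed a nested .data_type member
};

std::string data_type_format(DataTypeSketch dt) {
  return dt == PrimitiveTypeSketch::f32 ? "%f" : "%d";
}

std::string format_for(StmtSketch *stmt) {
  // old: data_type_format(stmt->ret_type.data_type);
  return data_type_format(stmt->ret_type);  // new: pass the handle directly
}

The diffs below apply this substitution backend by backend (CC, CPU/LLVM, CUDA, Metal, OpenGL). Removed lines are prefixed with "-" and added lines with "+".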
8 changes: 3 additions & 5 deletions taichi/backends/cc/codegen_cc.cpp
@@ -401,7 +401,7 @@ class CCTransformer : public IRVisitor {

if (std::holds_alternative<Stmt *>(content)) {
auto arg_stmt = std::get<Stmt *>(content);
- format += data_type_format(arg_stmt->ret_type.data_type);
+ format += data_type_format(arg_stmt->ret_type);
values.push_back(arg_stmt->raw_name());

} else {
@@ -527,10 +527,8 @@ class CCTransformer : public IRVisitor {
}

void visit(RandStmt *stmt) override {
- auto var = define_var(cc_data_type_name(stmt->ret_type.data_type),
- stmt->raw_name());
- emit("{} = Ti_rand_{}();", var,
- data_type_short_name(stmt->ret_type.data_type));
+ auto var = define_var(cc_data_type_name(stmt->ret_type), stmt->raw_name());
+ emit("{} = Ti_rand_{}();", var, data_type_short_name(stmt->ret_type));
}

void visit(StackAllocaStmt *stmt) override {
4 changes: 2 additions & 2 deletions taichi/backends/cpu/codegen_cpu.cpp
@@ -98,13 +98,13 @@ class CodeGenLLVMCPU : public CodeGenLLVM {

for (auto s : stmt->arg_stmts) {
TI_ASSERT(s->width() == 1);
- arg_types.push_back(tlctx->get_data_type(s->ret_type.data_type));
+ arg_types.push_back(tlctx->get_data_type(s->ret_type));
arg_values.push_back(llvm_val[s]);
}

for (auto s : stmt->output_stmts) {
TI_ASSERT(s->width() == 1);
- auto t = tlctx->get_data_type(s->ret_type.data_type);
+ auto t = tlctx->get_data_type(s->ret_type);
auto ptr = llvm::PointerType::get(t, 0);
arg_types.push_back(ptr);
arg_values.push_back(llvm_val[s]);
41 changes: 20 additions & 21 deletions taichi/backends/cuda/codegen_cuda.cpp
@@ -115,11 +115,11 @@ class CodeGenLLVMCUDA : public CodeGenLLVM {
if (std::holds_alternative<Stmt *>(content)) {
auto arg_stmt = std::get<Stmt *>(content);

- formats += data_type_format(arg_stmt->ret_type.data_type);
+ formats += data_type_format(arg_stmt->ret_type);

- auto value_type = tlctx->get_data_type(arg_stmt->ret_type.data_type);
+ auto value_type = tlctx->get_data_type(arg_stmt->ret_type);
auto value = llvm_val[arg_stmt];
- if (arg_stmt->ret_type.data_type == PrimitiveType::f32) {
+ if (arg_stmt->ret_type == PrimitiveType::f32) {
value_type = tlctx->get_data_type(PrimitiveType::f64);
value = builder->CreateFPExt(value, value_type);
}
@@ -158,7 +158,7 @@ class CodeGenLLVMCUDA : public CodeGenLLVM {
void emit_extra_unary(UnaryOpStmt *stmt) override {
// functions from libdevice
auto input = llvm_val[stmt->operand];
- auto input_taichi_type = stmt->operand->ret_type.data_type;
+ auto input_taichi_type = stmt->operand->ret_type;
auto op = stmt->op_type;

#define UNARY_STD(x) \
@@ -232,58 +232,58 @@ class CodeGenLLVMCUDA : public CodeGenLLVM {
for (int l = 0; l < stmt->width(); l++) {
llvm::Value *old_value;
if (stmt->op_type == AtomicOpType::add) {
- if (is_integral(stmt->val->ret_type.data_type)) {
+ if (is_integral(stmt->val->ret_type)) {
old_value = builder->CreateAtomicRMW(
llvm::AtomicRMWInst::BinOp::Add, llvm_val[stmt->dest],
llvm_val[stmt->val],
llvm::AtomicOrdering::SequentiallyConsistent);
- } else if (stmt->val->ret_type.data_type == PrimitiveType::f32) {
+ } else if (stmt->val->ret_type == PrimitiveType::f32) {
old_value = builder->CreateAtomicRMW(
llvm::AtomicRMWInst::FAdd, llvm_val[stmt->dest],
llvm_val[stmt->val], AtomicOrdering::SequentiallyConsistent);
- } else if (stmt->val->ret_type.data_type == PrimitiveType::f64) {
+ } else if (stmt->val->ret_type == PrimitiveType::f64) {
old_value = builder->CreateAtomicRMW(
llvm::AtomicRMWInst::FAdd, llvm_val[stmt->dest],
llvm_val[stmt->val], AtomicOrdering::SequentiallyConsistent);
} else {
TI_NOT_IMPLEMENTED
}
} else if (stmt->op_type == AtomicOpType::min) {
- if (is_integral(stmt->val->ret_type.data_type)) {
+ if (is_integral(stmt->val->ret_type)) {
old_value = builder->CreateAtomicRMW(
llvm::AtomicRMWInst::BinOp::Min, llvm_val[stmt->dest],
llvm_val[stmt->val],
llvm::AtomicOrdering::SequentiallyConsistent);
- } else if (stmt->val->ret_type.data_type == PrimitiveType::f32) {
+ } else if (stmt->val->ret_type == PrimitiveType::f32) {
old_value =
builder->CreateCall(get_runtime_function("atomic_min_f32"),
{llvm_val[stmt->dest], llvm_val[stmt->val]});
- } else if (stmt->val->ret_type.data_type == PrimitiveType::f64) {
+ } else if (stmt->val->ret_type == PrimitiveType::f64) {
old_value =
builder->CreateCall(get_runtime_function("atomic_min_f64"),
{llvm_val[stmt->dest], llvm_val[stmt->val]});
} else {
TI_NOT_IMPLEMENTED
}
} else if (stmt->op_type == AtomicOpType::max) {
- if (is_integral(stmt->val->ret_type.data_type)) {
+ if (is_integral(stmt->val->ret_type)) {
old_value = builder->CreateAtomicRMW(
llvm::AtomicRMWInst::BinOp::Max, llvm_val[stmt->dest],
llvm_val[stmt->val],
llvm::AtomicOrdering::SequentiallyConsistent);
- } else if (stmt->val->ret_type.data_type == PrimitiveType::f32) {
+ } else if (stmt->val->ret_type == PrimitiveType::f32) {
old_value =
builder->CreateCall(get_runtime_function("atomic_max_f32"),
{llvm_val[stmt->dest], llvm_val[stmt->val]});
- } else if (stmt->val->ret_type.data_type == PrimitiveType::f64) {
+ } else if (stmt->val->ret_type == PrimitiveType::f64) {
old_value =
builder->CreateCall(get_runtime_function("atomic_max_f64"),
{llvm_val[stmt->dest], llvm_val[stmt->val]});
} else {
TI_NOT_IMPLEMENTED
}
} else if (stmt->op_type == AtomicOpType::bit_and) {
- if (is_integral(stmt->val->ret_type.data_type)) {
+ if (is_integral(stmt->val->ret_type)) {
old_value = builder->CreateAtomicRMW(
llvm::AtomicRMWInst::BinOp::And, llvm_val[stmt->dest],
llvm_val[stmt->val],
@@ -292,7 +292,7 @@
TI_NOT_IMPLEMENTED
}
} else if (stmt->op_type == AtomicOpType::bit_or) {
- if (is_integral(stmt->val->ret_type.data_type)) {
+ if (is_integral(stmt->val->ret_type)) {
old_value = builder->CreateAtomicRMW(
llvm::AtomicRMWInst::BinOp::Or, llvm_val[stmt->dest],
llvm_val[stmt->val],
@@ -301,7 +301,7 @@
TI_NOT_IMPLEMENTED
}
} else if (stmt->op_type == AtomicOpType::bit_xor) {
- if (is_integral(stmt->val->ret_type.data_type)) {
+ if (is_integral(stmt->val->ret_type)) {
old_value = builder->CreateAtomicRMW(
llvm::AtomicRMWInst::BinOp::Xor, llvm_val[stmt->dest],
llvm_val[stmt->val],
@@ -317,10 +317,9 @@
}

void visit(RandStmt *stmt) override {
- llvm_val[stmt] =
- create_call(fmt::format("cuda_rand_{}",
- data_type_short_name(stmt->ret_type.data_type)),
- {get_context()});
+ llvm_val[stmt] = create_call(
+ fmt::format("cuda_rand_{}", data_type_short_name(stmt->ret_type)),
+ {get_context()});
}
void visit(RangeForStmt *for_stmt) override {
create_naive_range_for(for_stmt);
@@ -397,7 +396,7 @@ class CodeGenLLVMCUDA : public CodeGenLLVM {
if (should_cache_as_read_only) {
// Issue an CUDA "__ldg" instruction so that data are cached in
// the CUDA read-only data cache.
- auto dtype = stmt->ret_type.data_type;
+ auto dtype = stmt->ret_type;
auto llvm_dtype = llvm_type(dtype);
auto llvm_dtype_ptr = llvm::PointerType::get(llvm_type(dtype), 0);
llvm::Intrinsic::ID intrin;
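One detail worth noting in the CUDA atomic hunks above: integer add/min/max and the bit operations lower to a single LLVM AtomicRMW instruction, and floating-point add uses the FAdd form, but floating-point min/max are routed to runtime helpers (atomic_min_f32, atomic_max_f64, and so on). Such helpers are conventionally built on a compare-and-swap loop; the sketch below shows that general technique in portable C++ and is not Taichi's actual runtime implementation.

#include <atomic>

// Illustrative CAS-loop float minimum; a generic sketch, not Taichi's runtime.
float atomic_min_f32_sketch(std::atomic<float> &dest, float val) {
  float observed = dest.load(std::memory_order_relaxed);
  // Retry until the stored value is already <= val, or we successfully swap
  // in the smaller value. On failure, compare_exchange_weak refreshes
  // 'observed' with the value currently stored in dest.
  while (observed > val && !dest.compare_exchange_weak(observed, val)) {
    // retry with the refreshed 'observed'
  }
  return observed;  // value held before this call took effect
}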
8 changes: 4 additions & 4 deletions taichi/backends/metal/codegen_metal.cpp
@@ -276,7 +276,7 @@ class KernelCodegen : public IRVisitor {
}
} else if (opty == SNodeOpType::append) {
TI_ASSERT(is_dynamic);
- TI_ASSERT(stmt->ret_type.data_type == PrimitiveType::i32);
+ TI_ASSERT(stmt->ret_type == PrimitiveType::i32);
emit("{} = {}.append({});", result_var, parent, stmt->val->raw_name());
} else if (opty == SNodeOpType::length) {
TI_ASSERT(is_dynamic);
@@ -412,7 +412,7 @@ class KernelCodegen : public IRVisitor {
const auto bin_name = bin->raw_name();
const auto op_type = bin->op_type;
if (op_type == BinaryOpType::floordiv) {
- if (is_integral(bin->ret_type.data_type)) {
+ if (is_integral(bin->ret_type)) {
emit("const {} {} = ifloordiv({}, {});", dt_name, bin_name, lhs_name,
rhs_name);
} else {
@@ -421,7 +421,7 @@
}
return;
}
- if (op_type == BinaryOpType::pow && is_integral(bin->ret_type.data_type)) {
+ if (op_type == BinaryOpType::pow && is_integral(bin->ret_type)) {
// TODO(k-ye): Make sure the type is not i64?
emit("const {} {} = pow_i32({}, {});", dt_name, bin_name, lhs_name,
rhs_name);
@@ -604,7 +604,7 @@ class KernelCodegen : public IRVisitor {

void visit(RandStmt *stmt) override {
emit("const auto {} = metal_rand_{}({});", stmt->raw_name(),
- data_type_short_name(stmt->ret_type.data_type), kRandStateVarName);
+ data_type_short_name(stmt->ret_type), kRandStateVarName);
}

void visit(PrintStmt *stmt) override {
15 changes: 7 additions & 8 deletions taichi/backends/opengl/codegen_opengl.cpp
@@ -267,8 +267,8 @@ class KernelGen : public IRVisitor {
if (std::holds_alternative<Stmt *>(content)) {
auto arg_stmt = std::get<Stmt *>(content);
emit("_msg_set_{}({}, {}, {});",
- opengl_data_type_short_name(arg_stmt->ret_type.data_type),
- msgid_name, i, arg_stmt->short_name());
+ opengl_data_type_short_name(arg_stmt->ret_type), msgid_name, i,
+ arg_stmt->short_name());

} else {
auto str = std::get<std::string>(content);
@@ -281,9 +281,8 @@

void visit(RandStmt *stmt) override {
used.random = true;
emit("{} {} = _rand_{}();", opengl_data_type_name(stmt->ret_type.data_type),
stmt->short_name(),
opengl_data_type_short_name(stmt->ret_type.data_type));
emit("{} {} = _rand_{}();", opengl_data_type_name(stmt->ret_type),
stmt->short_name(), opengl_data_type_short_name(stmt->ret_type));
}

void visit(LinearizeStmt *stmt) override {
@@ -361,7 +360,7 @@
}

} else if (stmt->op_type == SNodeOpType::is_active) {
- TI_ASSERT(stmt->ret_type.data_type == PrimitiveType::i32);
+ TI_ASSERT(stmt->ret_type == PrimitiveType::i32);
if (stmt->snode->type == SNodeType::dense ||
stmt->snode->type == SNodeType::root) {
emit("int {} = 1;", stmt->short_name());
@@ -374,7 +373,7 @@

} else if (stmt->op_type == SNodeOpType::append) {
TI_ASSERT(stmt->snode->type == SNodeType::dynamic);
- TI_ASSERT(stmt->ret_type.data_type == PrimitiveType::i32);
+ TI_ASSERT(stmt->ret_type == PrimitiveType::i32);
emit("int {} = atomicAdd(_data_i32_[{} >> 2], 1);", stmt->short_name(),
get_snode_meta_address(stmt->snode));
auto dt = stmt->val->element_type();
@@ -388,7 +387,7 @@

} else if (stmt->op_type == SNodeOpType::length) {
TI_ASSERT(stmt->snode->type == SNodeType::dynamic);
- TI_ASSERT(stmt->ret_type.data_type == PrimitiveType::i32);
+ TI_ASSERT(stmt->ret_type == PrimitiveType::i32);
emit("int {} = _data_i32_[{} >> 2];", stmt->short_name(),
get_snode_meta_address(stmt->snode));

