Skip to content

Commit

Permalink
[Lang] Add config.force_scalarize_matrix to avoid perf-regression in …
Browse files Browse the repository at this point in the history
…certain scenario (taichi-dev#8509)

Issue: #

### Brief Summary

copilot:summary

### Walkthrough

copilot:walkthrough
  • Loading branch information
jim19930609 authored Apr 18, 2024
1 parent 52b24f3 commit 0da6846
Show file tree
Hide file tree
Showing 8 changed files with 60 additions and 16 deletions.
1 change: 1 addition & 0 deletions taichi/analysis/offline_cache_util.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -78,6 +78,7 @@ static std::vector<std::uint8_t> get_offline_cache_key_of_compile_config(
serializer(config.experimental_auto_mesh_local);
serializer(config.auto_mesh_local_default_occupacy);
serializer(config.real_matrix_scalarize);
serializer(config.force_scalarize_matrix);
serializer(config.half2_vectorization);
serializer.finalize();

Expand Down
2 changes: 1 addition & 1 deletion taichi/codegen/codegen_utils.h
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
namespace taichi::lang {

inline bool codegen_vector_type(const CompileConfig &config) {
return !config.real_matrix_scalarize;
return !(config.real_matrix_scalarize || config.force_scalarize_matrix);
}

// Parses a C-style printf format string specifier into its constituent parts.
Expand Down
1 change: 1 addition & 0 deletions taichi/program/compile_config.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,7 @@ CompileConfig::CompileConfig() {
make_block_local = true;
detect_read_only = true;
real_matrix_scalarize = true;
force_scalarize_matrix = false;
half2_vectorization = false;
make_cpu_multithreading_loop = true;

Expand Down
1 change: 1 addition & 0 deletions taichi/program/compile_config.h
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,7 @@ struct CompileConfig {
bool make_block_local;
bool detect_read_only;
bool real_matrix_scalarize;
bool force_scalarize_matrix;
bool half2_vectorization;
bool make_cpu_multithreading_loop;
DataType default_fp;
Expand Down
2 changes: 2 additions & 0 deletions taichi/python/export_lang.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -218,6 +218,8 @@ void export_lang(py::module &m) {
.def_readwrite("detect_read_only", &CompileConfig::detect_read_only)
.def_readwrite("real_matrix_scalarize",
&CompileConfig::real_matrix_scalarize)
.def_readwrite("force_scalarize_matrix",
&CompileConfig::force_scalarize_matrix)
.def_readwrite("half2_vectorization", &CompileConfig::half2_vectorization)
.def_readwrite("make_cpu_multithreading_loop",
&CompileConfig::make_cpu_multithreading_loop)
Expand Down
51 changes: 40 additions & 11 deletions taichi/transforms/auto_diff.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -82,8 +82,13 @@ class IndependentBlocksJudger : public BasicStmtVisitor {
if (is_inside_loop_)
return;

if (stmt->dest->is<ExternalPtrStmt>()) {
if (stmt->dest->as<ExternalPtrStmt>()
Stmt *dest = stmt->dest;
if (dest->is<MatrixPtrStmt>()) {
dest = dest->as<MatrixPtrStmt>()->origin;
}

if (dest->is<ExternalPtrStmt>()) {
if (dest->as<ExternalPtrStmt>()
->base_ptr->as<ArgLoadStmt>()
->ret_type.ptr_removed()
->as<StructType>()
Expand All @@ -92,8 +97,8 @@ class IndependentBlocksJudger : public BasicStmtVisitor {
qualified_glb_operations_ = true;
}
} else {
TI_ASSERT(stmt->dest->is<GlobalPtrStmt>());
if (stmt->dest->as<GlobalPtrStmt>()->snode->has_adjoint()) {
TI_ASSERT(dest->is<GlobalPtrStmt>());
if (dest->as<GlobalPtrStmt>()->snode->has_adjoint()) {
qualified_glb_operations_ = true;
}
}
Expand All @@ -108,15 +113,21 @@ class IndependentBlocksJudger : public BasicStmtVisitor {
// another IndependentBlocksJudger
if (is_inside_loop_)
return;
if ((stmt->src->is<ExternalPtrStmt>() &&
stmt->src->as<ExternalPtrStmt>()

Stmt *src = stmt->src;
if (src->is<MatrixPtrStmt>()) {
src = src->as<MatrixPtrStmt>()->origin;
}

if ((src->is<ExternalPtrStmt>() &&
src->as<ExternalPtrStmt>()
->base_ptr->as<ArgLoadStmt>()
->ret_type.ptr_removed()
->as<StructType>()
->elements()
.size() > TypeFactory::GRAD_PTR_POS_IN_NDARRAY) ||
(stmt->src->is<GlobalPtrStmt>() &&
stmt->src->as<GlobalPtrStmt>()->snode->has_adjoint())) {
(src->is<GlobalPtrStmt>() &&
src->as<GlobalPtrStmt>()->snode->has_adjoint())) {
qualified_glb_operations_ = true;
}
}
Expand Down Expand Up @@ -2425,7 +2436,13 @@ class GloablDataAccessRuleChecker : public BasicStmtVisitor {
using BasicStmtVisitor::visit;

void visit(GlobalLoadStmt *stmt) override {
GlobalPtrStmt *src = stmt->src->as<GlobalPtrStmt>();
GlobalPtrStmt *src = nullptr;
if (stmt->src->is<GlobalPtrStmt>()) {
src = stmt->src->as<GlobalPtrStmt>();
} else {
TI_ASSERT(stmt->src->is<MatrixPtrStmt>());
src = stmt->src->as<MatrixPtrStmt>()->origin->as<GlobalPtrStmt>();
}
auto snode = src->snode;
if (!snode->has_adjoint_checkbit()) {
return;
Expand Down Expand Up @@ -2466,12 +2483,24 @@ class GloablDataAccessRuleChecker : public BasicStmtVisitor {
}

void visit(GlobalStoreStmt *stmt) override {
GlobalPtrStmt *dest = stmt->dest->as<GlobalPtrStmt>();
GlobalPtrStmt *dest = nullptr;
if (stmt->dest->is<GlobalPtrStmt>()) {
dest = stmt->dest->as<GlobalPtrStmt>();
} else {
TI_ASSERT(stmt->dest->is<MatrixPtrStmt>());
dest = stmt->dest->as<MatrixPtrStmt>()->origin->as<GlobalPtrStmt>();
}
visit_gloabl_store_stmt_and_atomic_add(stmt, dest);
}

void visit(AtomicOpStmt *stmt) override {
GlobalPtrStmt *dest = stmt->dest->as<GlobalPtrStmt>();
GlobalPtrStmt *dest = nullptr;
if (stmt->dest->is<GlobalPtrStmt>()) {
dest = stmt->dest->as<GlobalPtrStmt>();
} else {
TI_ASSERT(stmt->dest->is<MatrixPtrStmt>());
dest = stmt->dest->as<MatrixPtrStmt>()->origin->as<GlobalPtrStmt>();
}
visit_gloabl_store_stmt_and_atomic_add(stmt, dest);
}

Expand Down
13 changes: 11 additions & 2 deletions taichi/transforms/compile_to_offloads.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -86,6 +86,10 @@ void compile_to_offloads(IRNode *ir,
irpass::analysis::gather_meshfor_relation_types(ir);
}

if (config.force_scalarize_matrix) {
irpass::scalarize(ir, false /*half2_optimization_enabled*/);
}

if (config.debug && autodiff_mode == AutodiffMode::kCheckAutodiffValid) {
// Check whether the kernel obeys the autodiff limitation e.g., gloabl data
// access rule
Expand Down Expand Up @@ -136,8 +140,9 @@ void compile_to_offloads(IRNode *ir,
// TODO: This pass may be redundant as cfg_optimization() is already called
// in full_simplify().
if (config.opt_level > 0 && config.cfg_optimization) {
irpass::cfg_optimization(ir, false, /*autodiff_enabled*/ false,
!config.real_matrix_scalarize);
irpass::cfg_optimization(
ir, false, /*autodiff_enabled*/ false,
!config.real_matrix_scalarize && !config.force_scalarize_matrix);
print("Optimized by CFG");
irpass::analysis::verify(ir);
}
Expand Down Expand Up @@ -371,6 +376,10 @@ void compile_function(IRNode *ir,
func->set_ir_stage(Function::IRStage::BeforeLowerAccess);
}

if (config.force_scalarize_matrix) {
irpass::scalarize(ir, false /*half2_optimization_enabled*/);
}

if (target_stage >= Function::IRStage::OptimizedIR &&
current_stage < Function::IRStage::OptimizedIR) {
irpass::lower_access(ir, config, {{}, true});
Expand Down
5 changes: 3 additions & 2 deletions taichi/transforms/simplify.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -564,8 +564,9 @@ void full_simplify(IRNode *root,
// Don't do this time-consuming optimization pass again if the IR is
// not modified.
if (config.opt_level > 0 && first_iteration && config.cfg_optimization &&
cfg_optimization(root, args.after_lower_access, args.autodiff_enabled,
!config.real_matrix_scalarize))
cfg_optimization(
root, args.after_lower_access, args.autodiff_enabled,
!config.real_matrix_scalarize && !config.force_scalarize_matrix))
modified = true;
print("cfg_optimization");
first_iteration = false;
Expand Down

0 comments on commit 0da6846

Please sign in to comment.