Skip to content

Commit

Permalink
[bug] Fix that the cgraph doesn't respect the caps set by ti.aot.Modu…
Browse files Browse the repository at this point in the history
…le() (#6520)

Issue: #6508 
This PR fixes #6508, but it causes the bug reported in #5699 when (and only
when) the offline cache is enabled.
Using `CacheManager` to manage the compilation in `AotModuleBuilder` can
fix that bug. I will fix it in the next PR.

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
  • Loading branch information
PGZXB and pre-commit-ci[bot] authored Nov 4, 2022
1 parent 83a41d8 commit 38b8ef7
Show file tree
Hide file tree
Showing 25 changed files with 39 additions and 129 deletions.
14 changes: 12 additions & 2 deletions taichi/aot/graph_data.cpp
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
#include "taichi/aot/graph_data.h"
#include "taichi/program/ndarray.h"
#include "taichi/program/texture.h"
#include "taichi/program/kernel.h"

#include <numeric>

Expand All @@ -12,7 +13,7 @@ void CompiledGraph::run(
for (const auto &dispatch : dispatches) {
RuntimeContext ctx = ctx_;

TI_ASSERT(dispatch.compiled_kernel);
TI_ASSERT(dispatch.ti_kernel || dispatch.compiled_kernel);

// Populate args metadata into RuntimeContext
const auto &symbolic_args_ = dispatch.symbolic_args;
Expand Down Expand Up @@ -77,7 +78,16 @@ void CompiledGraph::run(
TI_ERROR("Error in compiled graph: unknown tag {}", ival.tag);
}
}
dispatch.compiled_kernel->launch(&ctx);

if (dispatch.compiled_kernel) {
// Run cgraph loaded from AOT module
dispatch.compiled_kernel->launch(&ctx);
} else {
// JIT & Run
TI_ASSERT(dispatch.ti_kernel);
lang::Kernel::LaunchContextBuilder launch_ctx(dispatch.ti_kernel, &ctx);
dispatch.ti_kernel->operator()(launch_ctx);
}
}
}
} // namespace aot
Expand Down
2 changes: 2 additions & 0 deletions taichi/aot/graph_data.h
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@ namespace taichi::lang {
class AotModuleBuilder;
class Ndarray;
class Texture;
class Kernel;

namespace aot {
// Currently only scalar, matrix and ndarray are supported.
Expand Down Expand Up @@ -149,6 +150,7 @@ struct CompiledDispatch {
std::string kernel_name;
std::vector<Arg> symbolic_args;
Kernel *compiled_kernel{nullptr};
taichi::lang::Kernel *ti_kernel{nullptr};

TI_IO_DEF(kernel_name, symbolic_args);
};
Expand Down
6 changes: 5 additions & 1 deletion taichi/aot/module_builder.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -56,8 +56,12 @@ void AotModuleBuilder::add_graph(const std::string &name,
TI_ERROR("Graph {} already exists", name);
}
// Handle adding kernels separately.
std::unordered_map<std::string, lang::Kernel *> kernels;
for (const auto &dispatch : graph.dispatches) {
add_compiled_kernel(dispatch.kernel_name, dispatch.compiled_kernel);
kernels[dispatch.kernel_name] = dispatch.ti_kernel;
}
for (auto &e : kernels) {
add(e.first, e.second);
}
graphs_[name] = graph;
}
Expand Down
5 changes: 0 additions & 5 deletions taichi/aot/module_builder.h
Original file line number Diff line number Diff line change
Expand Up @@ -64,11 +64,6 @@ class AotModuleBuilder {
TI_NOT_IMPLEMENTED;
}

virtual void add_compiled_kernel(const std::string &identifier,
aot::Kernel *kernel) {
TI_NOT_IMPLEMENTED;
}

virtual void add_per_backend_tmpl(const std::string &identifier,
const std::string &key,
Kernel *kernel) {
Expand Down
17 changes: 10 additions & 7 deletions taichi/codegen/spirv/spirv_codegen.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2566,13 +2566,16 @@ void KernelCodegen::run(TaichiKernelAttributes &kernel_attribs,
}

void lower(Kernel *kernel) {
auto &config = kernel->program->this_thread_config();
config.demote_dense_struct_fors = true;
irpass::compile_to_executable(kernel->ir.get(), config, kernel,
kernel->autodiff_mode,
/*ad_use_stack=*/false, config.print_ir,
/*lower_global_access=*/true,
/*make_thread_local=*/false);
if (!kernel->lowered()) {
auto &config = kernel->program->this_thread_config();
config.demote_dense_struct_fors = true;
irpass::compile_to_executable(kernel->ir.get(), config, kernel,
kernel->autodiff_mode,
/*ad_use_stack=*/false, config.print_ir,
/*lower_global_access=*/true,
/*make_thread_local=*/false);
kernel->set_lowered(true);
}
}

} // namespace spirv
Expand Down
10 changes: 5 additions & 5 deletions taichi/program/graph_builder.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -5,11 +5,11 @@
namespace taichi::lang {
void Dispatch::compile(
std::vector<aot::CompiledDispatch> &compiled_dispatches) {
if (kernel_->compiled_aot_kernel() == nullptr) {
kernel_->compile_to_aot_kernel();
}
aot::CompiledDispatch dispatch{kernel_->get_name(), symbolic_args_,
kernel_->compiled_aot_kernel()};
aot::CompiledDispatch dispatch;
dispatch.kernel_name = kernel_->get_name();
dispatch.symbolic_args = symbolic_args_;
dispatch.ti_kernel = kernel_;
dispatch.compiled_kernel = nullptr;
compiled_dispatches.push_back(std::move(dispatch));
}

Expand Down
4 changes: 0 additions & 4 deletions taichi/program/kernel.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -67,10 +67,6 @@ void Kernel::compile() {
compiled_ = program->compile(*this);
}

void Kernel::compile_to_aot_kernel() {
compiled_aot_kernel_ = program->make_aot_kernel(*this);
}

void Kernel::lower(bool to_executable) {
TI_ASSERT(!lowered_);
TI_ASSERT(supports_lowering(arch));
Expand Down
14 changes: 4 additions & 10 deletions taichi/program/kernel.h
Original file line number Diff line number Diff line change
Expand Up @@ -86,14 +86,12 @@ class TI_DLL_EXPORT Kernel : public Callable {
return lowered_;
}

void compile();

void compile_to_aot_kernel();

aot::Kernel *compiled_aot_kernel() {
return compiled_aot_kernel_.get();
void set_lowered(bool lowered) {
lowered_ = lowered;
}

void compile();

/**
* Lowers |ir| to CHI IR level
*
Expand Down Expand Up @@ -159,10 +157,6 @@ class TI_DLL_EXPORT Kernel : public Callable {
bool ir_is_ast_{false};
// The closure that, if invoked, launches the backend kernel (shader)
FunctionType compiled_{nullptr};
// TODO[#5114]: It's kinda redundant to keep both compiled_ (used for JIT
// execution) as well as compiled_aot_kernel_. In fact we'd better unify
// everything around compiled_aot_kernel and rename it.
std::unique_ptr<aot::Kernel> compiled_aot_kernel_{nullptr};
// A flag to record whether |ir| has been fully lowered.
// lower initial AST all the way down to a bunch of
// OffloadedStmt for async execution TODO(Lin): Check this comment
Expand Down
4 changes: 0 additions & 4 deletions taichi/program/program.h
Original file line number Diff line number Diff line change
Expand Up @@ -201,10 +201,6 @@ class TI_DLL_EXPORT Program {
// TODO(Lin): remove the offloaded parameter
FunctionType compile(Kernel &kernel, OffloadedStmt *offloaded = nullptr);

std::unique_ptr<aot::Kernel> make_aot_kernel(Kernel &kernel) {
return program_impl_->make_aot_kernel(kernel);
}

void check_runtime_error();

Kernel &get_snode_reader(SNode *snode);
Expand Down
7 changes: 0 additions & 7 deletions taichi/program/program_impl.h
Original file line number Diff line number Diff line change
Expand Up @@ -78,13 +78,6 @@ class ProgramImpl {
virtual std::unique_ptr<AotModuleBuilder> make_aot_module_builder(
const DeviceCapabilityConfig &caps) = 0;

/**
* Compile a taichi::lang::Kernel to taichi::lang::aot::Kernel.
*/
virtual std::unique_ptr<aot::Kernel> make_aot_kernel(Kernel &kernel) {
TI_NOT_IMPLEMENTED;
}

/**
* Dump Offline-cache data to disk
*/
Expand Down
6 changes: 0 additions & 6 deletions taichi/runtime/dx12/aot_module_builder_impl.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -28,12 +28,6 @@ void AotModuleBuilderImpl::add_per_backend(const std::string &identifier,
compiled_kernel.tasks = compiled_data.tasks;
}

void AotModuleBuilderImpl::add_compiled_kernel(const std::string &identifier,
aot::Kernel *kernel) {
// FIXME: implement add_compiled_kernel.
TI_NOT_IMPLEMENTED;
}

void AotModuleBuilderImpl::add_field_per_backend(const std::string &identifier,
const SNode *rep_snode,
bool is_scalar,
Expand Down
3 changes: 0 additions & 3 deletions taichi/runtime/dx12/aot_module_builder_impl.h
Original file line number Diff line number Diff line change
Expand Up @@ -34,9 +34,6 @@ class AotModuleBuilderImpl : public AotModuleBuilder {
const std::string &key,
Kernel *kernel) override;

void add_compiled_kernel(const std::string &identifier,
aot::Kernel *kernel) override;

LlvmProgramImpl *prog;
ModuleDataDX12 module_data;
};
Expand Down
10 changes: 0 additions & 10 deletions taichi/runtime/gfx/aot_module_builder_impl.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -205,16 +205,6 @@ void AotModuleBuilderImpl::add_per_backend(const std::string &identifier,
ti_aot_data_.spirv_codes.push_back(compiled.task_spirv_source_codes);
}

void AotModuleBuilderImpl::add_compiled_kernel(const std::string &identifier,
aot::Kernel *kernel) {
GfxRuntime::RegisterParams register_params =
static_cast<KernelImpl *>(kernel)->params();
register_params.kernel_attribs.name = identifier;
ti_aot_data_.kernels.push_back(std::move(register_params.kernel_attribs));
ti_aot_data_.spirv_codes.push_back(
std::move(register_params.task_spirv_source_codes));
}

void AotModuleBuilderImpl::add_field_per_backend(const std::string &identifier,
const SNode *rep_snode,
bool is_scalar,
Expand Down
3 changes: 0 additions & 3 deletions taichi/runtime/gfx/aot_module_builder_impl.h
Original file line number Diff line number Diff line change
Expand Up @@ -42,9 +42,6 @@ class AotModuleBuilderImpl : public AotModuleBuilder {
const std::string &key,
Kernel *kernel) override;

void add_compiled_kernel(const std::string &identifier,
aot::Kernel *kernel) override;

std::string write_spv_file(const std::string &output_dir,
const TaskAttributes &k,
const std::vector<uint32_t> &source_code) const;
Expand Down
14 changes: 0 additions & 14 deletions taichi/runtime/llvm/llvm_aot_module_builder.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -60,18 +60,4 @@ void LlvmAotModuleBuilder::add_field_per_backend(const std::string &identifier,
cache_.fields[snode_tree_id] = std::move(field_cache);
}

void LlvmAotModuleBuilder::add_compiled_kernel(const std::string &identifier,
aot::Kernel *kernel) {
auto *kernel_impl = dynamic_cast<llvm_aot::KernelImpl *>(kernel);
TI_ASSERT(kernel_impl);
if (!kernel_impl->kernel_data_.created_at) {
kernel_impl->kernel_data_.last_used_at = std::time(nullptr);
kernel_impl->kernel_data_.created_at = std::time(nullptr);
}
const std::string &kernel_name = identifier;
if (cache_.kernels.find(kernel_name) == cache_.kernels.end()) {
cache_.kernels[kernel_name] = std::move(kernel_impl->kernel_data_);
}
}

} // namespace taichi::lang
3 changes: 0 additions & 3 deletions taichi/runtime/llvm/llvm_aot_module_builder.h
Original file line number Diff line number Diff line change
Expand Up @@ -26,9 +26,6 @@ class LlvmAotModuleBuilder : public AotModuleBuilder {
int row_num,
int column_num) override;

void add_compiled_kernel(const std::string &identifier,
aot::Kernel *kernel) override;

const LlvmOfflineCache &get_cache() {
return cache_;
}
Expand Down
9 changes: 0 additions & 9 deletions taichi/runtime/program_impls/dx/dx_program.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -82,15 +82,6 @@ DeviceAllocation Dx11ProgramImpl::allocate_memory_ndarray(
/*export_sharing=*/false});
}

std::unique_ptr<aot::Kernel> Dx11ProgramImpl::make_aot_kernel(Kernel &kernel) {
spirv::lower(&kernel);
std::vector<gfx::CompiledSNodeStructs> compiled_structs;
gfx::GfxRuntime::RegisterParams kparams = gfx::run_codegen(
&kernel, Arch::dx11, get_compute_device()->get_current_caps(),
compiled_structs);
return std::make_unique<gfx::KernelImpl>(runtime_.get(), std::move(kparams));
}

} // namespace taichi::lang

#endif
2 changes: 0 additions & 2 deletions taichi/runtime/program_impls/dx/dx_program.h
Original file line number Diff line number Diff line change
Expand Up @@ -62,8 +62,6 @@ class Dx11ProgramImpl : public ProgramImpl {
return snode_tree_mgr_->get_snode_tree_device_ptr(tree_id);
}

std::unique_ptr<aot::Kernel> make_aot_kernel(Kernel &kernel) override;

private:
std::shared_ptr<Device> device_{nullptr};
std::unique_ptr<gfx::GfxRuntime> runtime_{nullptr};
Expand Down
14 changes: 0 additions & 14 deletions taichi/runtime/program_impls/llvm/llvm_program.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -109,20 +109,6 @@ std::unique_ptr<AotModuleBuilder> LlvmProgramImpl::make_aot_module_builder(
return nullptr;
}

std::unique_ptr<aot::Kernel> LlvmProgramImpl::make_aot_kernel(Kernel &kernel) {
auto compiled_fn =
this->compile(&kernel, nullptr); // Offloaded used in async mode only

const std::string &kernel_key = kernel.get_cached_kernel_key();
TI_ASSERT(cache_data_->kernels.count(kernel_key));
const LlvmOfflineCache::KernelCacheData &kernel_data =
cache_data_->kernels[kernel_key];
LlvmOfflineCache::KernelCacheData compiled_kernel = kernel_data.clone();
compiled_kernel.kernel_key = kernel.get_name();
return std::make_unique<llvm_aot::KernelImpl>(compiled_fn,
std::move(compiled_kernel));
}

void LlvmProgramImpl::cache_kernel(const std::string &kernel_key,
const LLVMCompiledKernel &data,
std::vector<LlvmLaunchArgInfo> &&args) {
Expand Down
2 changes: 0 additions & 2 deletions taichi/runtime/program_impls/llvm/llvm_program.h
Original file line number Diff line number Diff line change
Expand Up @@ -68,8 +68,6 @@ class LlvmProgramImpl : public ProgramImpl {
std::unique_ptr<StructCompiler> compile_snode_tree_types_impl(
SNodeTree *tree);

std::unique_ptr<aot::Kernel> make_aot_kernel(Kernel &kernel) override;

std::unique_ptr<AotModuleBuilder> make_aot_module_builder(
const DeviceCapabilityConfig &caps) override;

Expand Down
6 changes: 0 additions & 6 deletions taichi/runtime/program_impls/opengl/opengl_program.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -83,12 +83,6 @@ DeviceAllocation OpenglProgramImpl::allocate_texture(
return runtime_->create_image(params);
}

std::unique_ptr<aot::Kernel> OpenglProgramImpl::make_aot_kernel(
Kernel &kernel) {
auto params = get_cache_manager()->load_or_compile(config, &kernel);
return std::make_unique<gfx::KernelImpl>(runtime_.get(), std::move(params));
}

void OpenglProgramImpl::dump_cache_data_to_disk() {
const auto &mgr = get_cache_manager();
mgr->clean_offline_cache(offline_cache::string_to_clean_cache_policy(
Expand Down
2 changes: 0 additions & 2 deletions taichi/runtime/program_impls/opengl/opengl_program.h
Original file line number Diff line number Diff line change
Expand Up @@ -62,8 +62,6 @@ class OpenglProgramImpl : public ProgramImpl {
return snode_tree_mgr_->get_snode_tree_device_ptr(tree_id);
}

std::unique_ptr<aot::Kernel> make_aot_kernel(Kernel &kernel) override;

void dump_cache_data_to_disk() override;

const std::unique_ptr<gfx::CacheManager> &get_cache_manager();
Expand Down
7 changes: 0 additions & 7 deletions taichi/runtime/program_impls/vulkan/vulkan_program.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -205,13 +205,6 @@ DeviceAllocation VulkanProgramImpl::allocate_texture(
return vulkan_runtime_->create_image(params);
}

std::unique_ptr<aot::Kernel> VulkanProgramImpl::make_aot_kernel(
Kernel &kernel) {
auto params = get_cache_manager()->load_or_compile(config, &kernel);
return std::make_unique<gfx::KernelImpl>(vulkan_runtime_.get(),
std::move(params));
}

void VulkanProgramImpl::enqueue_compute_op_lambda(
std::function<void(Device *device, CommandList *cmdlist)> op,
const std::vector<ComputeOpImageRef> &image_refs) {
Expand Down
2 changes: 0 additions & 2 deletions taichi/runtime/program_impls/vulkan/vulkan_program.h
Original file line number Diff line number Diff line change
Expand Up @@ -91,8 +91,6 @@ class VulkanProgramImpl : public ProgramImpl {
return snode_tree_mgr_->get_snode_tree_device_ptr(tree_id);
}

std::unique_ptr<aot::Kernel> make_aot_kernel(Kernel &kernel) override;

void enqueue_compute_op_lambda(
std::function<void(Device *device, CommandList *cmdlist)> op,
const std::vector<ComputeOpImageRef> &image_refs) override;
Expand Down
2 changes: 1 addition & 1 deletion tests/python/test_aot.py
Original file line number Diff line number Diff line change
Expand Up @@ -593,7 +593,7 @@ def test(a: ti.types.ndarray(), c: ti.u8):

g.run({'a': a, 'c': c})

m = ti.aot.Module(ti.lang.impl.current_cfg().arch)
m = ti.aot.Module(ti.lang.impl.current_cfg().arch, caps=['spirv_has_int8'])
m.add_graph('g_init', g)
with tempfile.TemporaryDirectory() as tmpdir:
m.save(tmpdir, '')

0 comments on commit 38b8ef7

Please sign in to comment.