Skip to content

Commit

Permalink
[test] Unify kernel setup for ndarray related tests
Browse files Browse the repository at this point in the history
We'll reuse these two kernels for cgraph tests as well so let's clean it
up first.
  • Loading branch information
Ailing Zhang committed May 19, 2022
1 parent 7157b13 commit b9d8c50
Show file tree
Hide file tree
Showing 4 changed files with 68 additions and 70 deletions.
39 changes: 3 additions & 36 deletions tests/cpp/aot/aot_save_load_test.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
#include "taichi/ir/statements.h"
#include "taichi/inc/constants.h"
#include "taichi/program/program.h"
#include "tests/cpp/ir/ndarray_kernel.h"
#include "tests/cpp/program/test_program.h"
#ifdef TI_WITH_VULKAN
#include "taichi/backends/vulkan/aot_module_loader_impl.h"
Expand Down Expand Up @@ -109,42 +110,8 @@ using namespace lang;
TestProgram test_prog;
test_prog.setup(arch);
auto aot_builder = test_prog.prog()->make_aot_module_builder(arch);
IRBuilder builder1, builder2;

{
auto *arg = builder1.create_arg_load(/*arg_id=*/0, get_data_type<int>(),
/*is_ptr=*/true);
auto *zero = builder1.get_int32(0);
auto *one = builder1.get_int32(1);
auto *two = builder1.get_int32(2);
auto *a1ptr = builder1.create_external_ptr(arg, {one});
builder1.create_global_store(a1ptr, one); // a[1] = 1
auto *a0 =
builder1.create_global_load(builder1.create_external_ptr(arg, {zero}));
auto *a2ptr = builder1.create_external_ptr(arg, {two});
auto *a2 = builder1.create_global_load(a2ptr);
auto *a0plusa2 = builder1.create_add(a0, a2);
builder1.create_global_store(a2ptr, a0plusa2); // a[2] = a[0] + a[2]
}
auto block = builder1.extract_ir();
auto ker1 =
std::make_unique<Kernel>(*test_prog.prog(), std::move(block), "ker1");
ker1->insert_arg(get_data_type<int>(), /*is_array=*/true);
{
auto *arg0 = builder2.create_arg_load(/*arg_id=*/0, get_data_type<int>(),
/*is_ptr=*/true);
auto *arg1 = builder2.create_arg_load(/*arg_id=*/1, get_data_type<int>(),
/*is_ptr=*/false);
auto *one = builder2.get_int32(1);
auto *a1ptr = builder2.create_external_ptr(arg0, {one});
builder2.create_global_store(a1ptr, arg1); // a[1] = arg1
}
auto block2 = builder2.extract_ir();
auto ker2 =
std::make_unique<Kernel>(*test_prog.prog(), std::move(block2), "ker2");
ker2->insert_arg(get_data_type<int>(), /*is_array=*/true);
ker2->insert_arg(get_data_type<int>(), /*is_array=*/false);

auto ker1 = setup_kernel1(test_prog.prog());
auto ker2 = setup_kernel2(test_prog.prog());
aot_builder->add("ker1", ker1.get());
aot_builder->add("ker2", ker2.get());
aot_builder->dump(".", "");
Expand Down
37 changes: 3 additions & 34 deletions tests/cpp/ir/ir_builder_test.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
#include "taichi/ir/ir_builder.h"
#include "taichi/ir/statements.h"
#include "tests/cpp/program/test_program.h"
#include "tests/cpp/ir/ndarray_kernel.h"
#ifdef TI_WITH_VULKAN
#include "taichi/backends/vulkan/vulkan_loader.h"
#endif
Expand Down Expand Up @@ -136,25 +137,7 @@ TEST(IRBuilder, Ndarray) {
auto array = Ndarray(test_prog.prog(), PrimitiveType::i32, {size});
array.write_int({0}, 2);
array.write_int({2}, 40);
{
auto *arg = builder1.create_arg_load(/*arg_id=*/0, get_data_type<int>(),
/*is_ptr=*/true);
auto *zero = builder1.get_int32(0);
auto *one = builder1.get_int32(1);
auto *two = builder1.get_int32(2);
auto *a1ptr = builder1.create_external_ptr(arg, {one});
builder1.create_global_store(a1ptr, one); // a[1] = 1
auto *a0 =
builder1.create_global_load(builder1.create_external_ptr(arg, {zero}));
auto *a2ptr = builder1.create_external_ptr(arg, {two});
auto *a2 = builder1.create_global_load(a2ptr);
auto *a0plusa2 = builder1.create_add(a0, a2);
builder1.create_global_store(a2ptr, a0plusa2); // a[2] = a[0] + a[2]
}
auto block1 = builder1.extract_ir();
auto ker1 =
std::make_unique<Kernel>(*test_prog.prog(), std::move(block1), "ker1");
ker1->insert_arg(get_data_type<int>(), /*is_array=*/true);
auto ker1 = setup_kernel1(test_prog.prog());
auto launch_ctx1 = ker1->make_launch_context();
launch_ctx1.set_arg_external_array(
/*arg_id=*/0, array.get_device_allocation_ptr_as_int(), size,
Expand All @@ -164,21 +147,7 @@ TEST(IRBuilder, Ndarray) {
EXPECT_EQ(array.read_int({1}), 1);
EXPECT_EQ(array.read_int({2}), 42);

IRBuilder builder2;
{
auto *arg0 = builder2.create_arg_load(/*arg_id=*/0, get_data_type<int>(),
/*is_ptr=*/true);
auto *arg1 = builder2.create_arg_load(/*arg_id=*/1, PrimitiveType::i32,
/*is_ptr=*/false);
auto *one = builder2.get_int32(1);
auto *a1ptr = builder2.create_external_ptr(arg0, {one});
builder2.create_global_store(a1ptr, arg1); // a[1] = arg1
}
auto block2 = builder2.extract_ir();
auto ker2 =
std::make_unique<Kernel>(*test_prog.prog(), std::move(block2), "ker2");
ker2->insert_arg(get_data_type<int>(), /*is_array=*/true);
ker2->insert_arg(get_data_type<int>(), /*is_array=*/false);
auto ker2 = setup_kernel2(test_prog.prog());
auto launch_ctx2 = ker2->make_launch_context();
launch_ctx2.set_arg_external_array(
/*arg_id=*/0, array.get_device_allocation_ptr_as_int(), size,
Expand Down
48 changes: 48 additions & 0 deletions tests/cpp/ir/ndarray_kernel.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,48 @@
#include "tests/cpp/ir/ndarray_kernel.h"

namespace taichi {
namespace lang {

std::unique_ptr<Kernel> setup_kernel1(Program *prog) {
IRBuilder builder1;
{
auto *arg = builder1.create_arg_load(/*arg_id=*/0, get_data_type<int>(),
/*is_ptr=*/true);
auto *zero = builder1.get_int32(0);
auto *one = builder1.get_int32(1);
auto *two = builder1.get_int32(2);
auto *a1ptr = builder1.create_external_ptr(arg, {one});
builder1.create_global_store(a1ptr, one); // a[1] = 1
auto *a0 =
builder1.create_global_load(builder1.create_external_ptr(arg, {zero}));
auto *a2ptr = builder1.create_external_ptr(arg, {two});
auto *a2 = builder1.create_global_load(a2ptr);
auto *a0plusa2 = builder1.create_add(a0, a2);
builder1.create_global_store(a2ptr, a0plusa2); // a[2] = a[0] + a[2]
}
auto block = builder1.extract_ir();
auto ker1 = std::make_unique<Kernel>(*prog, std::move(block), "ker1");
ker1->insert_arg(get_data_type<int>(), /*is_array=*/true);
return ker1;
}

std::unique_ptr<Kernel> setup_kernel2(Program *prog) {
IRBuilder builder2;

{
auto *arg0 = builder2.create_arg_load(/*arg_id=*/0, get_data_type<int>(),
/*is_ptr=*/true);
auto *arg1 = builder2.create_arg_load(/*arg_id=*/1, get_data_type<int>(),
/*is_ptr=*/false);
auto *one = builder2.get_int32(1);
auto *a1ptr = builder2.create_external_ptr(arg0, {one});
builder2.create_global_store(a1ptr, arg1); // a[1] = arg1
}
auto block2 = builder2.extract_ir();
auto ker2 = std::make_unique<Kernel>(*prog, std::move(block2), "ker2");
ker2->insert_arg(get_data_type<int>(), /*is_array=*/true);
ker2->insert_arg(get_data_type<int>(), /*is_array=*/false);
return ker2;
}
} // namespace lang
} // namespace taichi
14 changes: 14 additions & 0 deletions tests/cpp/ir/ndarray_kernel.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,14 @@
#pragma once
#include "taichi/ir/ir_builder.h"
#include "taichi/ir/statements.h"
#include "taichi/inc/constants.h"
#include "taichi/program/program.h"

namespace taichi {
namespace lang {

std::unique_ptr<Kernel> setup_kernel1(Program *prog);

std::unique_ptr<Kernel> setup_kernel2(Program *prog);
} // namespace lang
} // namespace taichi

0 comments on commit b9d8c50

Please sign in to comment.