Skip to content

Commit

Permalink
[Aot] Deprecate element shape and field dim for AOT symbolic args (ta…
Browse files Browse the repository at this point in the history
…ichi-dev#7100)

Issue: taichi-dev#6572 

* Fix using dtype for symbolic args in AOT usages
* Cleanup all the usages of `field_dim` and `element_shape` in examples and tests
  • Loading branch information
turbo0628 authored and lin-hitonami committed Jan 11, 2023
1 parent 26cc34b commit 0b136ea
Show file tree
Hide file tree
Showing 17 changed files with 255 additions and 209 deletions.
19 changes: 15 additions & 4 deletions c_api/tests/c_api_cgraph_test.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -19,23 +19,34 @@ void graph_aot_test(TiArch arch) {
ti::AotModule aot_mod = runtime.load_aot_module(aot_mod_ss.str().c_str());
ti::ComputeGraph run_graph = aot_mod.get_compute_graph("run_graph");

ti::NdArray<int32_t> arr_array =
ti::NdArray<int32_t> arr_array_0 =
runtime.allocate_ndarray<int32_t>({kArrLen}, {}, true);
ti::NdArray<int32_t> arr_array_1 =
runtime.allocate_ndarray<int32_t>({kArrLen}, {1}, true);

run_graph["base0"] = base0_val;
run_graph["base1"] = base1_val;
run_graph["base2"] = base2_val;
run_graph["arr"] = arr_array;
run_graph["arr0"] = arr_array_0;
run_graph["arr1"] = arr_array_1;
run_graph.launch();
runtime.wait();

// Check Results
auto *data = reinterpret_cast<int32_t *>(arr_array.map());
auto *data = reinterpret_cast<int32_t *>(arr_array_0.map());

for (int i = 0; i < kArrLen; i++) {
EXPECT_EQ(data[i], 3 * i + base0_val + base1_val + base2_val);
}
arr_array.unmap();

data = reinterpret_cast<int32_t *>(arr_array_1.map());

for (int i = 0; i < kArrLen; i++) {
EXPECT_EQ(data[i], 3 * i + base0_val + base1_val + base2_val);
}

arr_array_0.unmap();
arr_array_1.unmap();
}

void texture_aot_test(TiArch arch) {
Expand Down
29 changes: 11 additions & 18 deletions python/taichi/examples/graph/mpm88_graph.py
Original file line number Diff line number Diff line change
Expand Up @@ -114,32 +114,25 @@ def main():
# Build graph
sym_x = ti.graph.Arg(ti.graph.ArgKind.NDARRAY,
'x',
ti.f32,
field_dim=1,
element_shape=(2, ))
dtype=ti.math.vec2,
ndim=1)
sym_v = ti.graph.Arg(ti.graph.ArgKind.NDARRAY,
'v',
ti.f32,
field_dim=1,
element_shape=(2, ))
dtype=ti.math.vec2,
ndim=1)
sym_C = ti.graph.Arg(ti.graph.ArgKind.NDARRAY,
'C',
ti.f32,
field_dim=1,
element_shape=(2, 2))
sym_J = ti.graph.Arg(ti.graph.ArgKind.NDARRAY,
'J',
ti.f32,
field_dim=1)
dtype=ti.math.mat2,
ndim=1)
sym_J = ti.graph.Arg(ti.graph.ArgKind.NDARRAY, 'J', ti.f32, ndim=1)
sym_grid_v = ti.graph.Arg(ti.graph.ArgKind.NDARRAY,
'grid_v',
ti.f32,
field_dim=2,
element_shape=(2, ))
dtype=ti.math.vec2,
ndim=2)
sym_grid_m = ti.graph.Arg(ti.graph.ArgKind.NDARRAY,
'grid_m',
ti.f32,
field_dim=2)
dtype=ti.f32,
ndim=2)
g_init_builder = ti.graph.GraphBuilder()
g_init_builder.dispatch(init_particles, sym_x, sym_v, sym_J)

Expand Down
36 changes: 16 additions & 20 deletions python/taichi/examples/graph/stable_fluid_graph.py
Original file line number Diff line number Diff line change
Expand Up @@ -240,40 +240,36 @@ def main():
print('running in graph mode')
velocities_pair_cur = ti.graph.Arg(ti.graph.ArgKind.NDARRAY,
'velocities_pair_cur',
ti.f32,
field_dim=2,
element_shape=(2, ))
dtype=ti.math.vec2,
ndim=2)
velocities_pair_nxt = ti.graph.Arg(ti.graph.ArgKind.NDARRAY,
'velocities_pair_nxt',
ti.f32,
field_dim=2,
element_shape=(2, ))
dtype=ti.math.vec2,
ndim=2)
dyes_pair_cur = ti.graph.Arg(ti.graph.ArgKind.NDARRAY,
'dyes_pair_cur',
ti.f32,
field_dim=2,
element_shape=(3, ))
dtype=ti.math.vec3,
ndim=2)
dyes_pair_nxt = ti.graph.Arg(ti.graph.ArgKind.NDARRAY,
'dyes_pair_nxt',
ti.f32,
field_dim=2,
element_shape=(3, ))
dtype=ti.math.vec3,
ndim=2)
pressures_pair_cur = ti.graph.Arg(ti.graph.ArgKind.NDARRAY,
'pressures_pair_cur',
ti.f32,
field_dim=2)
dtype=ti.f32,
ndim=2)
pressures_pair_nxt = ti.graph.Arg(ti.graph.ArgKind.NDARRAY,
'pressures_pair_nxt',
ti.f32,
field_dim=2)
dtype=ti.f32,
ndim=2)
velocity_divs = ti.graph.Arg(ti.graph.ArgKind.NDARRAY,
'velocity_divs',
ti.f32,
field_dim=2)
dtype=ti.f32,
ndim=2)
mouse_data = ti.graph.Arg(ti.graph.ArgKind.NDARRAY,
'mouse_data',
ti.f32,
field_dim=1)
dtype=ti.f32,
ndim=1)

g1_builder = ti.graph.GraphBuilder()
g1_builder.dispatch(advect, velocities_pair_cur, velocities_pair_cur,
Expand Down
5 changes: 2 additions & 3 deletions python/taichi/examples/graph/texture_graph.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,9 +46,8 @@ def main():
_t = ti.graph.Arg(ti.graph.ArgKind.SCALAR, 't', ti.f32)
_pixels_arr = ti.graph.Arg(ti.graph.ArgKind.NDARRAY,
'pixels_arr',
ti.f32,
field_dim=2,
element_shape=(4, ))
dtype=ti.math.vec4,
ndim=2)

_rw_tex = ti.graph.Arg(ti.graph.ArgKind.RWTEXTURE,
'rw_tex',
Expand Down
46 changes: 41 additions & 5 deletions python/taichi/graph/_graph.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
import warnings

from taichi._lib import core as _ti_core
from taichi.aot.utils import produce_injected_args
from taichi.lang import kernel_impl
Expand Down Expand Up @@ -97,12 +99,46 @@ def run(self, args):
def Arg(tag,
name,
dtype=None,
field_dim=0,
ndim=0,
field_dim=None,
element_shape=(),
channel_format=None,
shape=(),
num_channels=None):
if isinstance(dtype, MatrixType):
if field_dim is not None:
if ndim != 0:
raise TaichiRuntimeError(
'field_dim is deprecated, please do not specify field_dim and ndim at the same time.'
)
warnings.warn(
"The field_dim argument for ndarray will be deprecated in v1.5.0, use ndim instead.",
DeprecationWarning)
ndim = field_dim

if tag == ArgKind.SCALAR:
# The scalar tag should never work with array-like parameters
if ndim > 0 or isinstance(dtype, MatrixType) or len(element_shape) > 0:
raise TaichiRuntimeError(
f'Illegal Arg parameter (dtype={dtype}, ndim={ndim}, element_shape={element_shape}) for Scalar tag.'
)
return _ti_core.Arg(tag, name, dtype, ndim, element_shape)

if tag == ArgKind.NDARRAY:
# Ndarray with matrix data type
if isinstance(dtype, MatrixType):
return _ti_core.Arg(tag, name, dtype.dtype, ndim,
dtype.get_shape())
# Ndarray with scalar data type
if len(element_shape) > 0:
warnings.warn(
"The element_shape argument for ndarray will be deprecated in v1.5.0, use vector or matrix data type instead.",
DeprecationWarning)
return _ti_core.Arg(tag, name, dtype, ndim, element_shape)

if tag == ArgKind.MATRIX:
if not isinstance(dtype, MatrixType):
raise TaichiRuntimeError(
f'Tag {tag} must specify matrix data type, but got {dtype}.')
if len(element_shape) > 0:
raise TaichiRuntimeError(
f'Element shape for MatrixType argument "{name}" is not supported.'
Expand All @@ -114,8 +150,8 @@ def Arg(tag,
arg_sublist = []
for _ in range(mat_type.m):
arg_sublist.append(
_ti_core.Arg(tag, f'{name}_mat_arg_{i}', dtype.dtype,
field_dim, element_shape))
_ti_core.Arg(tag, f'{name}_mat_arg_{i}', dtype.dtype, ndim,
element_shape))
i += 1
arg_list.append(arg_sublist)
return arg_list
Expand All @@ -130,7 +166,7 @@ def Arg(tag,
channel_format=channel_format,
num_channels=num_channels,
shape=shape)
return _ti_core.Arg(tag, name, dtype, field_dim, element_shape)
raise TaichiRuntimeError(f'Unknowm tag {tag} for graph Arg {name}.')


__all__ = ['GraphBuilder', 'Graph', 'Arg', 'ArgKind']
18 changes: 13 additions & 5 deletions tests/cpp/aot/gfx_utils.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -279,15 +279,18 @@ void run_cgraph1(Arch arch, taichi::lang::Device *device_) {
alloc_params.host_read = true;
alloc_params.size = size * sizeof(int);
alloc_params.usage = taichi::lang::AllocUsage::Storage;
DeviceAllocation devalloc_arr_ = device_->allocate_memory(alloc_params);
Ndarray arr = Ndarray(devalloc_arr_, PrimitiveType::i32, {size}, {1});
DeviceAllocation devalloc_arr_0 = device_->allocate_memory(alloc_params);
DeviceAllocation devalloc_arr_1 = device_->allocate_memory(alloc_params);
Ndarray arr0 = Ndarray(devalloc_arr_0, PrimitiveType::i32, {size});
Ndarray arr1 = Ndarray(devalloc_arr_1, PrimitiveType::i32, {size}, {1});

int base0 = 10;
int base1 = 20;
int base2 = 30;

std::unordered_map<std::string, taichi::lang::aot::IValue> args;
args.insert({"arr", taichi::lang::aot::IValue::create(arr)});
args.insert({"arr0", taichi::lang::aot::IValue::create(arr0)});
args.insert({"arr1", taichi::lang::aot::IValue::create(arr1)});
args.insert({"base0", taichi::lang::aot::IValue::create(base0)});
args.insert({"base1", taichi::lang::aot::IValue::create(base1)});
args.insert({"base2", taichi::lang::aot::IValue::create(base2)});
Expand All @@ -298,13 +301,18 @@ void run_cgraph1(Arch arch, taichi::lang::Device *device_) {
gfx_runtime->synchronize();

int dst[size] = {0};
load_devalloc(devalloc_arr_, dst, sizeof(dst));
load_devalloc(devalloc_arr_0, dst, sizeof(dst));
for (int i = 0; i < size; i++) {
EXPECT_EQ(dst[i], 3 * i + base0 + base1 + base2);
}
load_devalloc(devalloc_arr_1, dst, sizeof(dst));
for (int i = 0; i < size; i++) {
EXPECT_EQ(dst[i], 3 * i + base0 + base1 + base2);
}

// Deallocate
device_->dealloc_memory(devalloc_arr_);
device_->dealloc_memory(devalloc_arr_0);
device_->dealloc_memory(devalloc_arr_1);
}

void run_cgraph2(Arch arch, taichi::lang::Device *device_) {
Expand Down
61 changes: 46 additions & 15 deletions tests/cpp/aot/llvm/graph_aot_test.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -45,32 +45,45 @@ TEST(LlvmCGraph, RunGraphCpu) {

constexpr int ArrLength = 100;
constexpr int kArrBytes_arr = ArrLength * 1 * sizeof(int32_t);
auto devalloc_arr =
auto devalloc_arr_0 =
exec.allocate_memory_ndarray(kArrBytes_arr, result_buffer);
auto devalloc_arr_1 =
exec.allocate_memory_ndarray(kArrBytes_arr, result_buffer);

/* Test with Graph */
// Prepare & Run "init" Graph
auto run_graph = mod->get_graph("run_graph");

auto arr = taichi::lang::Ndarray(
devalloc_arr, taichi::lang::PrimitiveType::i32, {ArrLength}, {1});
auto arr0 = taichi::lang::Ndarray(
devalloc_arr_0, taichi::lang::PrimitiveType::i32, {ArrLength});
auto arr1 = taichi::lang::Ndarray(
devalloc_arr_1, taichi::lang::PrimitiveType::i32, {ArrLength},
{
1,
});

int base0 = 10;
int base1 = 20;
int base2 = 30;
std::unordered_map<std::string, taichi::lang::aot::IValue> args;
args.insert({"arr", taichi::lang::aot::IValue::create(arr)});
args.insert({"arr0", taichi::lang::aot::IValue::create(arr0)});
args.insert({"arr1", taichi::lang::aot::IValue::create(arr1)});
args.insert({"base0", taichi::lang::aot::IValue::create(base0)});
args.insert({"base1", taichi::lang::aot::IValue::create(base1)});
args.insert({"base2", taichi::lang::aot::IValue::create(base2)});

run_graph->run(args);
exec.synchronize();

auto *data = reinterpret_cast<int32_t *>(
exec.get_ndarray_alloc_info_ptr(devalloc_arr));
auto *data_0 = reinterpret_cast<int32_t *>(
exec.get_ndarray_alloc_info_ptr(devalloc_arr_0));
auto *data_1 = reinterpret_cast<int32_t *>(
exec.get_ndarray_alloc_info_ptr(devalloc_arr_1));
for (int i = 0; i < ArrLength; i++) {
EXPECT_EQ(data_0[i], 3 * i + base0 + base1 + base2);
}
for (int i = 0; i < ArrLength; i++) {
EXPECT_EQ(data[i], 3 * i + base0 + base1 + base2);
EXPECT_EQ(data_1[i], 3 * i + base0 + base1 + base2);
}
}

Expand Down Expand Up @@ -99,34 +112,52 @@ TEST(LlvmCGraph, RunGraphCuda) {

constexpr int ArrLength = 100;
constexpr int kArrBytes_arr = ArrLength * 1 * sizeof(int32_t);
auto devalloc_arr =
auto devalloc_arr_0 =
exec.allocate_memory_ndarray(kArrBytes_arr, result_buffer);

auto devalloc_arr_1 =
exec.allocate_memory_ndarray(kArrBytes_arr, result_buffer);

/* Test with Graph */
// Prepare & Run "init" Graph
auto run_graph = mod->get_graph("run_graph");

auto arr = taichi::lang::Ndarray(
devalloc_arr, taichi::lang::PrimitiveType::i32, {ArrLength}, {1});
auto arr0 = taichi::lang::Ndarray(
devalloc_arr_0, taichi::lang::PrimitiveType::i32, {ArrLength});

auto arr1 = taichi::lang::Ndarray(
devalloc_arr_1, taichi::lang::PrimitiveType::i32, {ArrLength}, {1});

int base0 = 10;
int base1 = 20;
int base2 = 30;
std::unordered_map<std::string, taichi::lang::aot::IValue> args;
args.insert({"arr", taichi::lang::aot::IValue::create(arr)});
args.insert({"arr0", taichi::lang::aot::IValue::create(arr0)});
args.insert({"arr1", taichi::lang::aot::IValue::create(arr1)});
args.insert({"base0", taichi::lang::aot::IValue::create(base0)});
args.insert({"base1", taichi::lang::aot::IValue::create(base1)});
args.insert({"base2", taichi::lang::aot::IValue::create(base2)});

run_graph->run(args);
exec.synchronize();

auto *data = reinterpret_cast<int32_t *>(
exec.get_ndarray_alloc_info_ptr(devalloc_arr));

std::vector<int32_t> cpu_data(ArrLength);

auto *data_0 = reinterpret_cast<int32_t *>(
exec.get_ndarray_alloc_info_ptr(devalloc_arr_0));

CUDADriver::get_instance().memcpy_device_to_host(
(void *)cpu_data.data(), (void *)data_0, ArrLength * sizeof(int32_t));

for (int i = 0; i < ArrLength; ++i) {
EXPECT_EQ(cpu_data[i], 3 * i + base0 + base1 + base2);
}

auto *data_1 = reinterpret_cast<int32_t *>(
exec.get_ndarray_alloc_info_ptr(devalloc_arr_1));

CUDADriver::get_instance().memcpy_device_to_host(
(void *)cpu_data.data(), (void *)data, ArrLength * sizeof(int32_t));
(void *)cpu_data.data(), (void *)data_1, ArrLength * sizeof(int32_t));

for (int i = 0; i < ArrLength; ++i) {
EXPECT_EQ(cpu_data[i], 3 * i + base0 + base1 + base2);
Expand Down
6 changes: 1 addition & 5 deletions tests/cpp/aot/python_scripts/comet_aot.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,11 +40,7 @@
count = ti.field(ti.i32, ())
img = ti.field(ti.f32, (res, res))

sym_arr = ti.graph.Arg(ti.graph.ArgKind.NDARRAY,
'arr',
ti.f32,
field_dim=3,
element_shape=())
sym_arr = ti.graph.Arg(ti.graph.ArgKind.NDARRAY, 'arr', dtype=ti.f32, ndim=3)
img_c = 4


Expand Down
Loading

0 comments on commit 0b136ea

Please sign in to comment.