From fcbd8a7322b592e8a4716c5037ba5005dcd14c70 Mon Sep 17 00:00:00 2001 From: Ailing Zhang Date: Fri, 20 May 2022 12:10:55 +0800 Subject: [PATCH] Update base for Update on "[aot] Serialize built graph, deserialize and run." related: #4786 This PR demonstrates a minimal example of serializing a built graph, deserializing and running it. [ghstack-poisoned] --- taichi/program/graph.cpp | 10 +++++++--- tests/cpp/program/graph_test.cpp | 2 +- 2 files changed, 8 insertions(+), 4 deletions(-) diff --git a/taichi/program/graph.cpp b/taichi/program/graph.cpp index e527c93ee80a1..690cdef026912 100644 --- a/taichi/program/graph.cpp +++ b/taichi/program/graph.cpp @@ -76,13 +76,17 @@ void Graph::run( const aot::IValue &ival = found->second; if (ival.tag == aot::ArgKind::NDARRAY) { Ndarray *arr = reinterpret_cast(ival.val); - TI_ERROR_IF((symbolic_arg.tag != ival.tag) || - (symbolic_arg.element_shape != arr->shape), + TI_ERROR_IF(ival.tag != aot::ArgKind::NDARRAY, + "Required a ndarray for argument {}", symbolic_arg.name); + auto ndarray_elem_shape = std::vector( + arr->shape.end() - symbolic_arg.element_shape.size(), + arr->shape.end()); + TI_ERROR_IF(ndarray_elem_shape != symbolic_arg.element_shape, "Mismatched shape information for argument {}", symbolic_arg.name); set_runtime_ctx_ndarray(&ctx, i, arr); } else { - TI_ERROR_IF(symbolic_arg.tag != aot::ArgKind::SCALAR, + TI_ERROR_IF(ival.tag != aot::ArgKind::SCALAR, "Required a scalar for argument {}", symbolic_arg.name); ctx.set_arg(i, ival.val); } diff --git a/tests/cpp/program/graph_test.cpp b/tests/cpp/program/graph_test.cpp index 918f13d6a185c..5d7819714e6e6 100644 --- a/tests/cpp/program/graph_test.cpp +++ b/tests/cpp/program/graph_test.cpp @@ -30,7 +30,7 @@ TEST(GraphTest, SimpleGraphRun) { auto g = std::make_unique("test"); auto seq = g->seq(); auto arr_arg = aot::Arg{ - "arr", PrimitiveType::i32.to_string(), aot::ArgKind::NDARRAY, {size}}; + "arr", PrimitiveType::i32.to_string(), aot::ArgKind::NDARRAY, {}}; seq->emplace(ker1.get(), {arr_arg}); seq->emplace(ker2.get(), {arr_arg, aot::Arg{"x", PrimitiveType::i32.to_string(),