From 76a0245f4314b6fef523ea0fc0d5ddd1b7a47b4e Mon Sep 17 00:00:00 2001 From: Yingcheng <135535812+yingchen21@users.noreply.github.com> Date: Sun, 27 Oct 2024 22:51:14 +0800 Subject: [PATCH] impl wrapper for json args (#3) * impl wrapper for json args * fix * fix --------- Co-authored-by: Hongyi Jin --- src/runtime/contrib/nvshmem/init.cc | 33 +++++++++++++++++++++++++++++ 1 file changed, 33 insertions(+) diff --git a/src/runtime/contrib/nvshmem/init.cc b/src/runtime/contrib/nvshmem/init.cc index 2733c595720a5..051022ced8f1a 100644 --- a/src/runtime/contrib/nvshmem/init.cc +++ b/src/runtime/contrib/nvshmem/init.cc @@ -24,6 +24,8 @@ #include "../../cuda/cuda_common.h" +#include + namespace tvm { namespace runtime { @@ -79,9 +81,40 @@ void InitNVSHMEM(ShapeTuple uid_64, int num_workers, int worker_id_start) { << ", npes=" << nvshmem_n_pes(); } +void InitNVSHMEMWrapper(String args) { + picojson::value v; + std::string err = picojson::parse(v, args); + if (!err.empty()) { + LOG(FATAL) << "JSON parse error: " << err; + } + + if (!v.is()) { + LOG(FATAL) << "JSON is not an object"; + } + + picojson::object& obj = v.get(); + + picojson::array uid_array = obj["uid"].get(); + std::vector uid_vector; + for (const auto& elem : uid_array) { + uid_vector.push_back(elem.get()); + } + + ShapeTuple uid_64(uid_vector); + + int num_workers = static_cast(obj["npes"].get()); + int worker_id_start = static_cast(obj["pe_start"].get()); + + InitNVSHMEM(uid_64, num_workers, worker_id_start); + +} + TVM_REGISTER_GLOBAL("runtime.disco.nvshmem.init_nvshmem_uid").set_body_typed(InitNVSHMEMUID); TVM_REGISTER_GLOBAL("runtime.disco.nvshmem.init_nvshmem").set_body_typed(InitNVSHMEM); +TVM_REGISTER_GLOBAL("runtime.disco.nvshmem.init_nvshmem_wrapper").set_body_typed(InitNVSHMEMWrapper); + + } // namespace runtime } // namespace tvm