Skip to content

Commit

Permalink
impl wrapper for json args (apache#3)
Browse files Browse the repository at this point in the history
* impl wrapper for json args

* fix

* fix

---------

Co-authored-by: Hongyi Jin <jinhongyi02@gmail.com>
  • Loading branch information
yingchen21 and jinhongyii authored Oct 27, 2024
1 parent 60caa16 commit 76a0245
Showing 1 changed file with 33 additions and 0 deletions.
33 changes: 33 additions & 0 deletions src/runtime/contrib/nvshmem/init.cc
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,8 @@

#include "../../cuda/cuda_common.h"

#include <picojson.h>

namespace tvm {
namespace runtime {

Expand Down Expand Up @@ -79,9 +81,40 @@ void InitNVSHMEM(ShapeTuple uid_64, int num_workers, int worker_id_start) {
<< ", npes=" << nvshmem_n_pes();
}

void InitNVSHMEMWrapper(String args) {
picojson::value v;
std::string err = picojson::parse(v, args);
if (!err.empty()) {
LOG(FATAL) << "JSON parse error: " << err;
}

if (!v.is<picojson::object>()) {
LOG(FATAL) << "JSON is not an object";
}

picojson::object& obj = v.get<picojson::object>();

picojson::array uid_array = obj["uid"].get<picojson::array>();
std::vector<int64_t> uid_vector;
for (const auto& elem : uid_array) {
uid_vector.push_back(elem.get<int64_t>());
}

ShapeTuple uid_64(uid_vector);

int num_workers = static_cast<int>(obj["npes"].get<int64_t>());
int worker_id_start = static_cast<int>(obj["pe_start"].get<int64_t>());

InitNVSHMEM(uid_64, num_workers, worker_id_start);

}

TVM_REGISTER_GLOBAL("runtime.disco.nvshmem.init_nvshmem_uid").set_body_typed(InitNVSHMEMUID);

TVM_REGISTER_GLOBAL("runtime.disco.nvshmem.init_nvshmem").set_body_typed(InitNVSHMEM);

TVM_REGISTER_GLOBAL("runtime.disco.nvshmem.init_nvshmem_wrapper").set_body_typed(InitNVSHMEMWrapper);


} // namespace runtime
} // namespace tvm

0 comments on commit 76a0245

Please sign in to comment.