diff --git a/src/auto_scheduler/cost_model.cc b/src/auto_scheduler/cost_model.cc index 03e778e9129bc..e753c08a6a739 100644 --- a/src/auto_scheduler/cost_model.cc +++ b/src/auto_scheduler/cost_model.cc @@ -31,9 +31,11 @@ TVM_REGISTER_OBJECT_TYPE(CostModelNode); TVM_REGISTER_OBJECT_TYPE(RandomModelNode); TVM_REGISTER_OBJECT_TYPE(PythonBasedModelNode); -void RandomNumber(size_t n, void* data) { +void RandomNumber(TVMArgs args, TVMRetValue* rv) { + int n = args[0]; + void* data = args[1]; float* fdata = reinterpret_cast(data); - for (size_t i = 0; i < n; i++) { + for (int i = 0; i < n; i++) { fdata[i] = static_cast(rand_r(nullptr)) / (static_cast(RAND_MAX)); } } @@ -44,11 +46,10 @@ RandomModel::RandomModel() { if (node->random_number_func == nullptr) { LOG(WARNING) << "auto_scheduler.cost_model.random_fill_float is not registered, " << "use C++ default random_number func instead."; - static TypedPackedFunc cost_model_random_number(RandomNumber); - node->random_number_func = &cost_model_random_number; - } else { - node->random_number_func = reinterpret_cast*>(f); + const auto pf = PackedFunc(RandomNumber); + f = &pf; } + node->random_number_func = reinterpret_cast*>(f); data_ = std::move(node); }