From 3af2d12b0a227b7150bbb26c3167273717211535 Mon Sep 17 00:00:00 2001 From: Xiyou Zhou Date: Thu, 19 Aug 2021 17:10:45 -0700 Subject: [PATCH 01/34] Add sampling function SampleInt. --- include/tvm/support/random_engine.h | 9 ++--- src/tir/schedule/concrete_schedule.h | 3 ++ src/tir/schedule/primitive.h | 12 +++++++ src/tir/schedule/primitive/sampling.cc | 39 ++++++++++++++++++++++ tests/cpp/sampling_test.cc | 46 ++++++++++++++++++++++++++ 5 files changed, 105 insertions(+), 4 deletions(-) create mode 100644 src/tir/schedule/primitive/sampling.cc create mode 100644 tests/cpp/sampling_test.cc diff --git a/include/tvm/support/random_engine.h b/include/tvm/support/random_engine.h index e73c1193f4c3..98932c65db17 100644 --- a/include/tvm/support/random_engine.h +++ b/include/tvm/support/random_engine.h @@ -19,7 +19,7 @@ /*! * \file random_engine.h - * \brief Random number generator, for Sampler and Sampling functions. + * \brief Random number generator, for Sampling functions. */ #ifndef TVM_SUPPORT_RANDOM_ENGINE_H_ @@ -41,10 +41,11 @@ namespace support { * included for simplification. For full member functions of std::minstd_rand, please check out the * following link: https://en.cppreference.com/w/cpp/numeric/random/linear_congruential_engine */ + class LinearCongruentialEngine { public: /*! - * \brief The result type is defined as int64_t here for meta_schedule sampler usage. + * \brief The result type is defined as int64_t here to avoid overflow. * \note The type name is not in Google style because it is used in STL's distribution inferface. */ using result_type = uint64_t; @@ -63,13 +64,13 @@ class LinearCongruentialEngine { * \brief The minimum possible value of random state here. * \note The function name is uncapilized because it is used in STL's distribution inferface. */ - result_type min() { return 0; } + static constexpr result_type min() { return 0; } /*! * \brief The maximum possible value of random state here. * \note The function name is uncapilized because it is used in STL's distribution inferface. */ - result_type max() { return modulus - 1; } + static constexpr result_type max() { return modulus - 1; } /*! * \brief Operator to move the random state to the next and return the new random state. According diff --git a/src/tir/schedule/concrete_schedule.h b/src/tir/schedule/concrete_schedule.h index 97819d63edb6..736414db2508 100644 --- a/src/tir/schedule/concrete_schedule.h +++ b/src/tir/schedule/concrete_schedule.h @@ -43,6 +43,8 @@ class ConcreteScheduleNode : public ScheduleNode { TSymbolTable symbol_table_; /*! \brief A persistent stateless arithmetic analyzer. */ std::unique_ptr analyzer_; + /*! \brief The value of random state for sampling. */ + TRandState rand_state_; public: void VisitAttrs(tvm::AttrVisitor* v) { @@ -50,6 +52,7 @@ class ConcreteScheduleNode : public ScheduleNode { // `error_render_level_` is not visited // `symbol_table_` is not visited // `analyzer_` is not visited + // `rand_state_` is not visited } virtual ~ConcreteScheduleNode() = default; diff --git a/src/tir/schedule/primitive.h b/src/tir/schedule/primitive.h index 2cf59f0b27c0..0b34ef7f0750 100644 --- a/src/tir/schedule/primitive.h +++ b/src/tir/schedule/primitive.h @@ -19,12 +19,24 @@ #ifndef TVM_TIR_SCHEDULE_PRIMITIVE_H_ #define TVM_TIR_SCHEDULE_PRIMITIVE_H_ +#include #include namespace tvm { namespace tir { +using RandEngine = support::LinearCongruentialEngine; +using TRandState = RandEngine::TRandState; + /******** Schedule: Sampling ********/ +/*! + * \brief Sample an integer in [min_inclusive, max_exclusive) + * \param min_inclusive The left boundary, inclusive + * \param max_exclusive The right boundary, exclusive + * \return The integer sampled + */ +int SampleInt(TRandState* rand_state, int min_inclusive, int max_exclusive); + /******** Schedule: Get blocks & loops ********/ /*! * \brief Retrieves blocks in a specific function with its name diff --git a/src/tir/schedule/primitive/sampling.cc b/src/tir/schedule/primitive/sampling.cc new file mode 100644 index 000000000000..2dc48ec76efb --- /dev/null +++ b/src/tir/schedule/primitive/sampling.cc @@ -0,0 +1,39 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +#include +#include + +#include "../primitive.h" +#include "../utils.h" + +namespace tvm { +namespace tir { + +int SampleInt(TRandState* rand_state, int min_inclusive, int max_exclusive) { + RandEngine rand_(rand_state); + + if (min_inclusive + 1 == max_exclusive) { + return min_inclusive; + } + std::uniform_int_distribution<> dist(min_inclusive, max_exclusive - 1); + return dist(rand_); +} + +} // namespace tir +} // namespace tvm \ No newline at end of file diff --git a/tests/cpp/sampling_test.cc b/tests/cpp/sampling_test.cc new file mode 100644 index 000000000000..fdf1c42f9977 --- /dev/null +++ b/tests/cpp/sampling_test.cc @@ -0,0 +1,46 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +#include +#include + +#include "../../src/tir/schedule/primitive.h" + +TEST(SamplingFunction, Coverage) { + tvm::tir::TRandState rand_state; + tvm::tir::RandEngine(&rand_state).Seed(20210819); + + bool covered[100]; + memset(covered, 0, sizeof(covered)); + + for (int i = 0; i < 10000; i++) { + int x = tvm::tir::SampleInt(&rand_state, 0, 100); + covered[x] = true; + } + + for (int i = 0; i < 100; i++) { + ICHECK(covered[i]) << "Coverage Test Failed"; + } +} + +int main(int argc, char** argv) { + testing::InitGoogleTest(&argc, argv); + testing::FLAGS_gtest_death_test_style = "threadsafe"; + return RUN_ALL_TESTS(); +} From 121e658ea504c275b83334320c9dee276ae24e65 Mon Sep 17 00:00:00 2001 From: Xiyou Zhou Date: Thu, 19 Aug 2021 17:12:53 -0700 Subject: [PATCH 02/34] Add new line. --- src/tir/schedule/primitive/sampling.cc | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/tir/schedule/primitive/sampling.cc b/src/tir/schedule/primitive/sampling.cc index 2dc48ec76efb..eac6f9c61a40 100644 --- a/src/tir/schedule/primitive/sampling.cc +++ b/src/tir/schedule/primitive/sampling.cc @@ -36,4 +36,4 @@ int SampleInt(TRandState* rand_state, int min_inclusive, int max_exclusive) { } } // namespace tir -} // namespace tvm \ No newline at end of file +} // namespace tvm From 1ca4aed8e154d01c0323344d67c1400b1bbdde48 Mon Sep 17 00:00:00 2001 From: Xiyou Zhou Date: Thu, 19 Aug 2021 17:14:35 -0700 Subject: [PATCH 03/34] Fix comment. --- include/tvm/support/random_engine.h | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/include/tvm/support/random_engine.h b/include/tvm/support/random_engine.h index 98932c65db17..9cfbf8298b74 100644 --- a/include/tvm/support/random_engine.h +++ b/include/tvm/support/random_engine.h @@ -45,7 +45,7 @@ namespace support { class LinearCongruentialEngine { public: /*! - * \brief The result type is defined as int64_t here to avoid overflow. + * \brief The result type is defined as uint64_t here to avoid overflow. * \note The type name is not in Google style because it is used in STL's distribution inferface. */ using result_type = uint64_t; From fcf1d67e23cafa2c031d8d83787b3b70953d717b Mon Sep 17 00:00:00 2001 From: Xiyou Zhou Date: Thu, 19 Aug 2021 17:15:57 -0700 Subject: [PATCH 04/34] Update include/tvm/support/random_engine.h Co-authored-by: Junru Shao --- include/tvm/support/random_engine.h | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/include/tvm/support/random_engine.h b/include/tvm/support/random_engine.h index 9cfbf8298b74..17c2485fd7b8 100644 --- a/include/tvm/support/random_engine.h +++ b/include/tvm/support/random_engine.h @@ -19,7 +19,7 @@ /*! * \file random_engine.h - * \brief Random number generator, for Sampling functions. + * \brief Random number generator. It provides a generic interface consistent with `std::uniform_random_bit_generator` */ #ifndef TVM_SUPPORT_RANDOM_ENGINE_H_ From ce6e806f01c873ce34d61e76a38cdf134f96578a Mon Sep 17 00:00:00 2001 From: Xiyou Zhou Date: Thu, 19 Aug 2021 17:28:15 -0700 Subject: [PATCH 05/34] Minor fix. --- src/tir/schedule/primitive/sampling.cc | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/src/tir/schedule/primitive/sampling.cc b/src/tir/schedule/primitive/sampling.cc index eac6f9c61a40..c3e825207862 100644 --- a/src/tir/schedule/primitive/sampling.cc +++ b/src/tir/schedule/primitive/sampling.cc @@ -26,11 +26,12 @@ namespace tvm { namespace tir { int SampleInt(TRandState* rand_state, int min_inclusive, int max_exclusive) { - RandEngine rand_(rand_state); - if (min_inclusive + 1 == max_exclusive) { return min_inclusive; } + + RandEngine rand_(rand_state); + std::uniform_int_distribution<> dist(min_inclusive, max_exclusive - 1); return dist(rand_); } From d546c387d6e05f85a8b1e55c3e876ecabbc102a3 Mon Sep 17 00:00:00 2001 From: Xiyou Zhou Date: Fri, 20 Aug 2021 16:01:22 -0700 Subject: [PATCH 06/34] Fix schedules & add SampleCategorical func. --- include/tvm/tir/schedule/schedule.h | 13 +++-- python/tvm/tir/schedule/schedule.py | 13 +++++ src/tir/schedule/concrete_schedule.cc | 15 ++++- src/tir/schedule/concrete_schedule.h | 11 +++- src/tir/schedule/primitive.h | 17 +++--- src/tir/schedule/primitive/sampling.cc | 79 +++++++++++++++++++++++--- src/tir/schedule/schedule.cc | 12 ++-- src/tir/schedule/traced_schedule.cc | 19 ++++++- src/tir/schedule/traced_schedule.h | 9 +++ src/tir/schedule/utils.h | 68 ++++++++++++++++++++++ tests/cpp/sampling_test.cc | 46 --------------- 11 files changed, 229 insertions(+), 73 deletions(-) delete mode 100644 tests/cpp/sampling_test.cc diff --git a/include/tvm/tir/schedule/schedule.h b/include/tvm/tir/schedule/schedule.h index 5e223c98d74d..ed44c9cdd36a 100644 --- a/include/tvm/tir/schedule/schedule.h +++ b/include/tvm/tir/schedule/schedule.h @@ -19,6 +19,7 @@ #ifndef TVM_TIR_SCHEDULE_SCHEDULE_H_ #define TVM_TIR_SCHEDULE_SCHEDULE_H_ +#include #include #include @@ -184,6 +185,8 @@ class ScheduleNode : public runtime::Object { public: /******** Schedule: Sampling ********/ + virtual ExprRV SampleCategorical(const Array& candidates, const Array& probs, + Optional decision = NullOpt) = 0; /******** Schedule: Get blocks & loops ********/ /*! * \brief Retrieve a block in a specific function with its name @@ -356,6 +359,7 @@ class Schedule : public runtime::ObjectRef { /*! * \brief Construct a concrete TensorIR schedule from an IRModule * \param mod The IRModule to be scheduled + * \param seed The seed value for schedule's random state * \param debug_mask Do extra correctness checking after the class creation * and each time after calling the Replace method. * \param error_render_level The level of error rendering @@ -365,11 +369,12 @@ class Schedule : public runtime::ObjectRef { * 1) VerifySRefTree * 2) VerifyCachedFlags */ - TVM_DLL static Schedule Concrete(IRModule mod, int debug_mask, - ScheduleErrorRenderLevel error_render_level); + TVM_DLL static Schedule Concrete(IRModule mod, support::LinearCongruentialEngine::TRandState seed, + int debug_mask, ScheduleErrorRenderLevel error_render_level); /*! * \brief Construct a traced concrete TensorIR schedule from an IRModule * \param mod The IRModule to be scheduled + * \param seed The seed value for schedule's random state * \param debug_mask Do extra correctness checking after the class creation * and each time after calling the Replace method. * \param error_render_level The level of error rendering @@ -379,8 +384,8 @@ class Schedule : public runtime::ObjectRef { * 1) VerifySRefTree * 2) VerifyCachedFlags */ - TVM_DLL static Schedule Traced(IRModule mod, int debug_mask, - ScheduleErrorRenderLevel error_render_level); + TVM_DLL static Schedule Traced(IRModule mod, support::LinearCongruentialEngine::TRandState seed, + int debug_mask, ScheduleErrorRenderLevel error_render_level); TVM_DEFINE_MUTABLE_OBJECT_REF_METHODS(Schedule, runtime::ObjectRef, ScheduleNode); }; diff --git a/python/tvm/tir/schedule/schedule.py b/python/tvm/tir/schedule/schedule.py index c9cbf45b9055..48f5ef242505 100644 --- a/python/tvm/tir/schedule/schedule.py +++ b/python/tvm/tir/schedule/schedule.py @@ -268,6 +268,19 @@ def remove_rv(self, rand_var: RAND_VAR_TYPE) -> None: ########## Schedule: Sampling ########## + def sample_categorical( + self, + candidates: List[int], + probs: List[float], + decision: Optional[int] = None, + ) -> ExprRV: + return _ffi_api.ScheduleSampleCategorical( # type: ignore # pylint: disable=no-member + self, + candidates, + probs, + decision, + ) + ########## Schedule: Get blocks & loops ########## def get_block( self, diff --git a/src/tir/schedule/concrete_schedule.cc b/src/tir/schedule/concrete_schedule.cc index 084d0b0eec6a..e5dda33cbe30 100644 --- a/src/tir/schedule/concrete_schedule.cc +++ b/src/tir/schedule/concrete_schedule.cc @@ -21,13 +21,14 @@ namespace tvm { namespace tir { -Schedule Schedule::Concrete(IRModule mod, int debug_mask, - ScheduleErrorRenderLevel error_render_level) { +Schedule Schedule::Concrete(IRModule mod, support::LinearCongruentialEngine::TRandState seed, + int debug_mask, ScheduleErrorRenderLevel error_render_level) { ObjectPtr n = make_object(); n->state_ = ScheduleState(mod, debug_mask); n->error_render_level_ = error_render_level; n->symbol_table_ = {}; n->analyzer_ = std::make_unique(); + support::LinearCongruentialEngine(&n->rand_state_).Seed(seed); return Schedule(std::move(n)); } @@ -208,6 +209,16 @@ Schedule ConcreteScheduleNode::Copy() const { } /******** Schedule: Schedule: Sampling ********/ +ExprRV ConcreteScheduleNode::SampleCategorical(const Array& candidates, + const Array& probs, + Optional decision) { + TVM_TIR_SCHEDULE_BEGIN(); + return CreateRV(static_cast( + tir::SampleCategorical(state_, &this->rand_state_, candidates, probs, &decision))); + TVM_TIR_SCHEDULE_END("sample-categorical", this->error_render_level_); + throw; +} + /******** Schedule: Get blocks & loops ********/ BlockRV ConcreteScheduleNode::GetBlock(const String& name, const String& func_name) { diff --git a/src/tir/schedule/concrete_schedule.h b/src/tir/schedule/concrete_schedule.h index 736414db2508..047878d26f94 100644 --- a/src/tir/schedule/concrete_schedule.h +++ b/src/tir/schedule/concrete_schedule.h @@ -44,7 +44,7 @@ class ConcreteScheduleNode : public ScheduleNode { /*! \brief A persistent stateless arithmetic analyzer. */ std::unique_ptr analyzer_; /*! \brief The value of random state for sampling. */ - TRandState rand_state_; + support::LinearCongruentialEngine::TRandState rand_state_; public: void VisitAttrs(tvm::AttrVisitor* v) { @@ -78,6 +78,15 @@ class ConcreteScheduleNode : public ScheduleNode { public: /******** Schedule: Sampling ********/ + /*! + * \brief Sample an integer given the probability distribution + * \param candidates The candidates + * \param probs The probability distribution of the candidates + * \param decision The sampling decision + * \return The random variable sampled from candidates + */ + ExprRV SampleCategorical(const Array& candidates, const Array& probs, + Optional decision = NullOpt) override; /******** Schedule: Get blocks & loops ********/ BlockRV GetBlock(const String& name, const String& func_name = "main") override; Array GetLoops(const BlockRV& block_rv) override; diff --git a/src/tir/schedule/primitive.h b/src/tir/schedule/primitive.h index 0b34ef7f0750..102151402029 100644 --- a/src/tir/schedule/primitive.h +++ b/src/tir/schedule/primitive.h @@ -25,17 +25,20 @@ namespace tvm { namespace tir { -using RandEngine = support::LinearCongruentialEngine; -using TRandState = RandEngine::TRandState; - /******** Schedule: Sampling ********/ /*! * \brief Sample an integer in [min_inclusive, max_exclusive) - * \param min_inclusive The left boundary, inclusive - * \param max_exclusive The right boundary, exclusive - * \return The integer sampled + * \param self The schedule to update + * \param rand_state The pointer to schedule's random state + * \param candidates The candidates + * \param probs The probability distribution of the candidates + * \param decision The sampling decision + * \return The random variable sampled from candidates */ -int SampleInt(TRandState* rand_state, int min_inclusive, int max_exclusive); +TVM_DLL int64_t SampleCategorical(tir::ScheduleState self, + support::LinearCongruentialEngine::TRandState* rand_state, + const Array& candidates, const Array& probs, + Optional* decision); /******** Schedule: Get blocks & loops ********/ /*! diff --git a/src/tir/schedule/primitive/sampling.cc b/src/tir/schedule/primitive/sampling.cc index c3e825207862..5fb0d19ab348 100644 --- a/src/tir/schedule/primitive/sampling.cc +++ b/src/tir/schedule/primitive/sampling.cc @@ -19,22 +19,87 @@ #include #include +#include + #include "../primitive.h" #include "../utils.h" namespace tvm { namespace tir { -int SampleInt(TRandState* rand_state, int min_inclusive, int max_exclusive) { - if (min_inclusive + 1 == max_exclusive) { - return min_inclusive; +std::function MakeMultinomial(support::LinearCongruentialEngine::TRandState* rand_state, + const std::vector& weights) { + support::LinearCongruentialEngine rand_(rand_state); + std::vector sums; + sums.reserve(weights.size()); + double sum = 0.0; + for (double w : weights) { + sums.push_back(sum += w); } + std::uniform_real_distribution dist(0.0, sum); + auto sampler = [rand_state, dist = std::move(dist), sums = std::move(sums)]() mutable -> int { + support::LinearCongruentialEngine rand_(rand_state); + double p = dist(rand_); + int idx = std::lower_bound(sums.begin(), sums.end(), p) - sums.begin(); + int n = sums.size(); + CHECK_LE(0, idx); + CHECK_LE(idx, n); + return (idx == n) ? (n - 1) : idx; + }; + return sampler; +} - RandEngine rand_(rand_state); - - std::uniform_int_distribution<> dist(min_inclusive, max_exclusive - 1); - return dist(rand_); +int64_t SampleCategorical(tir::ScheduleState self, + support::LinearCongruentialEngine::TRandState* rand_state, + const Array& candidates, const Array& probs, + Optional* decision) { + int i = -1; + int n = candidates.size(); + if (decision->defined()) { + const auto* int_imm = decision->as(); + i = int_imm->value; + CHECK(0 <= i && i < n) << "ValueError: Wrong decision value, where n = " << n + << ", but decision is: " << i; + } else { + i = MakeMultinomial(rand_state, AsVector(probs))(); + ICHECK(0 <= i && i < n); + } + *decision = Integer(i); + return candidates[i]; } +struct SampleCategoricalTraits : public UnpackedInstTraits { + static constexpr const char* kName = "SampleCategorical"; + static constexpr bool kIsPure = true; + + private: + static constexpr size_t kNumInputs = 0; + static constexpr size_t kNumAttrs = 2; + static constexpr size_t kNumDecisions = 1; + + static ExprRV UnpackedApplyToSchedule(Schedule sch, // + Array candidates, // + Array probs, // + Optional decision) { + return sch->SampleCategorical(candidates, probs, decision); + } + + static String UnpackedAsPython(Array outputs, // + Array candidates, // + Array probs, // + Optional decision) { + PythonAPICall py("sample_categorical"); + py.Input("candidates", candidates); + py.Input("probs", probs); + py.Decision(decision); + py.SingleOutput(outputs); + return py.Str(); + } + + friend struct UnpackedInstTraits; +}; + +TVM_REGISTER_INST_KIND_TRAITS(SampleCategoricalTraits); + } // namespace tir } // namespace tvm diff --git a/src/tir/schedule/schedule.cc b/src/tir/schedule/schedule.cc index 29681fdf0926..7bd36bee0c09 100644 --- a/src/tir/schedule/schedule.cc +++ b/src/tir/schedule/schedule.cc @@ -60,13 +60,15 @@ TVM_REGISTER_GLOBAL("tir.schedule.ScheduleCopy") // TVM_REGISTER_GLOBAL("tir.schedule.BlockRV").set_body_typed([]() { return BlockRV(); }); TVM_REGISTER_GLOBAL("tir.schedule.LoopRV").set_body_typed([]() { return LoopRV(); }); TVM_REGISTER_GLOBAL("tir.schedule.ConcreteSchedule") - .set_body_typed([](IRModule mod, int debug_mask, int error_render_level) -> Schedule { - return Schedule::Concrete(mod, debug_mask, + .set_body_typed([](IRModule mod, support::LinearCongruentialEngine::TRandState seed, + int debug_mask, int error_render_level) -> Schedule { + return Schedule::Concrete(mod, debug_mask, seed, static_cast(error_render_level)); }); TVM_REGISTER_GLOBAL("tir.schedule.TracedSchedule") - .set_body_typed([](IRModule mod, int debug_mask, int error_render_level) -> Schedule { - return Schedule::Traced(mod, debug_mask, + .set_body_typed([](IRModule mod, support::LinearCongruentialEngine::TRandState seed, + int debug_mask, int error_render_level) -> Schedule { + return Schedule::Traced(mod, seed, debug_mask, static_cast(error_render_level)); }); @@ -117,6 +119,8 @@ TVM_REGISTER_GLOBAL("tir.schedule.ScheduleRemoveRV") }); /******** (FFI) Sampling ********/ +TVM_REGISTER_GLOBAL("tir.schedule.ScheduleSampleCategorical") + .set_body_method(&ScheduleNode::SampleCategorical); /******** (FFI) Get blocks & loops ********/ TVM_REGISTER_GLOBAL("tir.schedule.ScheduleGetBlock") .set_body_method(&ScheduleNode::GetBlock); diff --git a/src/tir/schedule/traced_schedule.cc b/src/tir/schedule/traced_schedule.cc index ae6a194b9888..f2a224f72ad3 100644 --- a/src/tir/schedule/traced_schedule.cc +++ b/src/tir/schedule/traced_schedule.cc @@ -21,14 +21,15 @@ namespace tvm { namespace tir { -Schedule Schedule::Traced(IRModule mod, int debug_mask, - ScheduleErrorRenderLevel error_render_level) { +Schedule Schedule::Traced(IRModule mod, support::LinearCongruentialEngine::TRandState seed, + int debug_mask, ScheduleErrorRenderLevel error_render_level) { ObjectPtr n = make_object(); n->state_ = ScheduleState(mod, debug_mask); n->error_render_level_ = error_render_level; n->symbol_table_ = {}; n->analyzer_ = std::make_unique(); n->trace_ = Trace(); + support::LinearCongruentialEngine(&n->rand_state_).Seed(seed); return Schedule(std::move(n)); } @@ -42,6 +43,20 @@ Schedule TracedScheduleNode::Copy() const { } /******** Schedule: Sampling ********/ +ExprRV TracedScheduleNode::SampleCategorical(const Array& candidates, + const Array& probs, + Optional decision) { + ExprRV result = CreateRV(static_cast( + tir::SampleCategorical(this->state_, &this->rand_state_, candidates, probs, &decision))); + + static const InstructionKind& kind = InstructionKind::Get("SampleCategorical"); + trace_->Append(/*inst=*/Instruction(/*kind=*/kind, // + /*inputs=*/{}, + /*attrs=*/{candidates, probs}, + /*outputs=*/{result}), + /*decision=*/decision); + return result; +} /******** Schedule: Get blocks & loops ********/ diff --git a/src/tir/schedule/traced_schedule.h b/src/tir/schedule/traced_schedule.h index 11128ba32fad..90b4a40d912d 100644 --- a/src/tir/schedule/traced_schedule.h +++ b/src/tir/schedule/traced_schedule.h @@ -47,6 +47,15 @@ class TracedScheduleNode : public ConcreteScheduleNode { public: /******** Schedule: Sampling ********/ + /*! + * \brief Sample an integer given the probability distribution + * \param candidates The candidates + * \param probs The probability distribution of the candidates + * \param decision The sampling decision + * \return The random variable sampled from candidates + */ + ExprRV SampleCategorical(const Array& candidates, const Array& probs, + Optional decision = NullOpt) final; /******** Schedule: Get blocks & loops ********/ BlockRV GetBlock(const String& name, const String& func_name = "main") final; diff --git a/src/tir/schedule/utils.h b/src/tir/schedule/utils.h index 8ccf8da731b5..e1d0b734c6a3 100644 --- a/src/tir/schedule/utils.h +++ b/src/tir/schedule/utils.h @@ -117,6 +117,15 @@ inline Array LoopSRefs2Loops(const Array& loop_srefs) { return loops; } +/*! + * \brief Convert a tvm::runtime::Array to std::vector + * \tparam TSrc The type of elements in the source Array + * \tparam TDst The type of elements in the result vector + * \return The result vector + */ +template +std::vector AsVector(const Array& vec); + /******** Storage scope ********/ /*! @@ -192,6 +201,65 @@ inline Map AsIntSet(const Map& var_dom) { return {result.begin(), result.end()}; } +/**************** AsVector ****************/ + +namespace details { + +template +struct AsVectorImpl {}; + +template +struct AsVectorImpl { + inline std::vector operator()(const Array& vec) const { + return std::vector(vec.begin(), vec.end()); + } +}; + +template +struct AsVectorImpl { + inline std::vector operator()(const Array& vec) const { + std::vector results; + for (const TSrcObjectRef& x : vec) { + const auto* n = x.template as(); + ICHECK(n) << "TypeError: Expects IntImm, but gets: " << x->GetTypeKey(); + results.push_back(n->value); + } + return results; + } +}; + +template +struct AsVectorImpl { + inline std::vector operator()(const Array& vec) const { + std::vector results; + for (const TSrcObjectRef& x : vec) { + const auto* n = x.template as(); + ICHECK(n) << "TypeError: Expects IntImm, but gets: " << x->GetTypeKey(); + results.push_back(n->value); + } + return results; + } +}; + +template +struct AsVectorImpl { + inline std::vector operator()(const Array& array) const { + std::vector results; + for (const TSrcObjectRef& x : array) { + const auto* n = x.template as(); + ICHECK(n) << "TypeError: Expects FloatImm, but gets: " << x->GetTypeKey(); + results.push_back(n->value); + } + return results; + } +}; +} // namespace details + +template +inline std::vector AsVector(const Array& vec) { + return details::AsVectorImpl()(vec); +} + } // namespace tir } // namespace tvm diff --git a/tests/cpp/sampling_test.cc b/tests/cpp/sampling_test.cc deleted file mode 100644 index fdf1c42f9977..000000000000 --- a/tests/cpp/sampling_test.cc +++ /dev/null @@ -1,46 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, - * software distributed under the License is distributed on an - * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY - * KIND, either express or implied. See the License for the - * specific language governing permissions and limitations - * under the License. - */ - -#include -#include - -#include "../../src/tir/schedule/primitive.h" - -TEST(SamplingFunction, Coverage) { - tvm::tir::TRandState rand_state; - tvm::tir::RandEngine(&rand_state).Seed(20210819); - - bool covered[100]; - memset(covered, 0, sizeof(covered)); - - for (int i = 0; i < 10000; i++) { - int x = tvm::tir::SampleInt(&rand_state, 0, 100); - covered[x] = true; - } - - for (int i = 0; i < 100; i++) { - ICHECK(covered[i]) << "Coverage Test Failed"; - } -} - -int main(int argc, char** argv) { - testing::InitGoogleTest(&argc, argv); - testing::FLAGS_gtest_death_test_style = "threadsafe"; - return RUN_ALL_TESTS(); -} From b7542f1ac1e7dfb63606c2f568ffdb231974df2d Mon Sep 17 00:00:00 2001 From: Xiyou Zhou Date: Fri, 20 Aug 2021 16:10:28 -0700 Subject: [PATCH 07/34] Fix SampleCategorical brief. --- src/tir/schedule/primitive.h | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/tir/schedule/primitive.h b/src/tir/schedule/primitive.h index 102151402029..b3968d863b55 100644 --- a/src/tir/schedule/primitive.h +++ b/src/tir/schedule/primitive.h @@ -27,12 +27,12 @@ namespace tir { /******** Schedule: Sampling ********/ /*! - * \brief Sample an integer in [min_inclusive, max_exclusive) + * \brief Sample once category from candidates according to the probability weights. * \param self The schedule to update * \param rand_state The pointer to schedule's random state * \param candidates The candidates * \param probs The probability distribution of the candidates - * \param decision The sampling decision + * \param decision The sampling decision, if any * \return The random variable sampled from candidates */ TVM_DLL int64_t SampleCategorical(tir::ScheduleState self, From cb0d96ec311f44181767c0f233fbfa37bf01e2fa Mon Sep 17 00:00:00 2001 From: Xiyou Zhou Date: Fri, 20 Aug 2021 17:23:30 -0700 Subject: [PATCH 08/34] Fix issues and make FFI work. --- include/tvm/tir/schedule/schedule.h | 24 +++++++++++++++++------- python/tvm/tir/schedule/schedule.py | 16 ++++++++++++++++ src/tir/schedule/concrete_schedule.cc | 10 ++++++++++ src/tir/schedule/concrete_schedule.h | 7 +++++++ src/tir/schedule/schedule.cc | 4 ++-- src/tir/schedule/traced_schedule.cc | 1 - src/tir/schedule/utils.h | 1 + 7 files changed, 53 insertions(+), 10 deletions(-) diff --git a/include/tvm/tir/schedule/schedule.h b/include/tvm/tir/schedule/schedule.h index ed44c9cdd36a..2ab16330d0fc 100644 --- a/include/tvm/tir/schedule/schedule.h +++ b/include/tvm/tir/schedule/schedule.h @@ -115,13 +115,6 @@ class ScheduleNode : public runtime::Object { * reconstructed */ virtual Schedule Copy() const = 0; - /*! - * \brief Seed the randomness - * \param seed The new random seed, -1 if use device random, otherwise non-negative - */ - virtual void Seed(int64_t seed = -1) { - LOG(FATAL) << "ValueError: The schedule cannot be seeded because no randomness is allowed"; - } public: /******** Lookup/Remove random variables ********/ @@ -185,8 +178,25 @@ class ScheduleNode : public runtime::Object { public: /******** Schedule: Sampling ********/ + /*! + * \brief Seed the randomness + * \param seed The new random seed, -1 if use device random, otherwise non-negative + */ + virtual void Seed(support::LinearCongruentialEngine::TRandState seed = -1) { + LOG(FATAL) << "ValueError: The schedule cannot be seeded because no randomness is allowed"; + } + /*! \brief Fork the random state */ + virtual support::LinearCongruentialEngine::TRandState ForkSeed() = 0; + /*! + * \brief Sample an integer given the probability distribution + * \param candidates The candidates + * \param probs The probability distribution of the candidates + * \param decision The sampling decision + * \return The random variable sampled from candidates + */ virtual ExprRV SampleCategorical(const Array& candidates, const Array& probs, Optional decision = NullOpt) = 0; + /******** Schedule: Get blocks & loops ********/ /*! * \brief Retrieve a block in a specific function with its name diff --git a/python/tvm/tir/schedule/schedule.py b/python/tvm/tir/schedule/schedule.py index 48f5ef242505..e5351087fd2a 100644 --- a/python/tvm/tir/schedule/schedule.py +++ b/python/tvm/tir/schedule/schedule.py @@ -98,6 +98,7 @@ def __init__( self, mod: Union[PrimFunc, IRModule], *, + seed: int = -1, debug_mask: Union[str, int] = "none", error_render_level: str = "detail", ) -> None: @@ -107,6 +108,8 @@ def __init__( ---------- mod : Union[PrimFunc, IRModule] The IRModule or PrimFunc to be scheduled + seed: int + The seed value for schedule's random state debug_mask : Union[str, int] Do extra correctness checking after the class creation and each time after calling the Replace method. @@ -130,6 +133,7 @@ def __init__( self.__init_handle_by_constructor__( _ffi_api.TracedSchedule, # type: ignore # pylint: disable=no-member _parse_mod(mod), + seed, _parse_debug_mask(debug_mask), _parse_error_render_level(error_render_level), ) @@ -138,12 +142,14 @@ def __init__( def _create_non_traced( mod: Union[PrimFunc, IRModule], *, + seed: int = -1, debug_mask: Union[str, int] = "none", error_render_level: str = "detail", ) -> "Schedule": """Construct a non-traced TensorIR schedule class from an IRModule.""" return _ffi_api.ConcreteSchedule( # type: ignore # pylint: disable=no-member _parse_mod(mod), + seed, _parse_debug_mask(debug_mask), _parse_error_render_level(error_render_level), ) @@ -190,6 +196,16 @@ def seed(self, seed: int) -> None: """ return _ffi_api.ScheduleSeed(self, seed) # type: ignore # pylint: disable=no-member + def fork_seed(self) -> int: + """Returns a forked random state as seed for new schedules + + Returns + ------- + seed : int + The forked random state, not the same as the current random state + """ + return _ffi_api.ScheduleForkSeed(self) # type: ignore # pylint: disable=no-member + def show(self, rand_var: RAND_VAR_TYPE) -> str: """Returns a string representation of the value that the random variable evaluates to diff --git a/src/tir/schedule/concrete_schedule.cc b/src/tir/schedule/concrete_schedule.cc index e5dda33cbe30..c163988cf400 100644 --- a/src/tir/schedule/concrete_schedule.cc +++ b/src/tir/schedule/concrete_schedule.cc @@ -209,6 +209,16 @@ Schedule ConcreteScheduleNode::Copy() const { } /******** Schedule: Schedule: Sampling ********/ + +void ConcreteScheduleNode::Seed(support::LinearCongruentialEngine::TRandState seed) { + support::LinearCongruentialEngine(&rand_state_).Seed(seed == -1 ? std::random_device()() : seed); +} +support::LinearCongruentialEngine::TRandState ConcreteScheduleNode::ForkSeed() { + // In order for reproducibility, we computer the new seed using RNG's random state and a different + // set of parameters. Note that both 32767 and 1999999973 are prime numbers. + return (rand_state_ * 32767) % 1999999973; +} + ExprRV ConcreteScheduleNode::SampleCategorical(const Array& candidates, const Array& probs, Optional decision) { diff --git a/src/tir/schedule/concrete_schedule.h b/src/tir/schedule/concrete_schedule.h index 047878d26f94..4cf5f3f957f4 100644 --- a/src/tir/schedule/concrete_schedule.h +++ b/src/tir/schedule/concrete_schedule.h @@ -78,6 +78,13 @@ class ConcreteScheduleNode : public ScheduleNode { public: /******** Schedule: Sampling ********/ + /*! + * \brief Seed the randomness + * \param seed The new random seed, -1 if use device random, otherwise non-negative + */ + void Seed(support::LinearCongruentialEngine::TRandState seed = -1) final; + /*! \brief Fork the random state */ + support::LinearCongruentialEngine::TRandState ForkSeed() final; /*! * \brief Sample an integer given the probability distribution * \param candidates The candidates diff --git a/src/tir/schedule/schedule.cc b/src/tir/schedule/schedule.cc index 7bd36bee0c09..21408723420c 100644 --- a/src/tir/schedule/schedule.cc +++ b/src/tir/schedule/schedule.cc @@ -50,8 +50,6 @@ TVM_REGISTER_GLOBAL("tir.schedule.ScheduleGetState") // .set_body_method(&ScheduleNode::state); TVM_REGISTER_GLOBAL("tir.schedule.ScheduleGetTrace") // .set_body_method(&ScheduleNode::trace); -TVM_REGISTER_GLOBAL("tir.schedule.ScheduleSeed") // - .set_body_method(&ScheduleNode::Seed); TVM_REGISTER_GLOBAL("tir.schedule.ScheduleCopy") // .set_body_method(&ScheduleNode::Copy); @@ -119,6 +117,8 @@ TVM_REGISTER_GLOBAL("tir.schedule.ScheduleRemoveRV") }); /******** (FFI) Sampling ********/ +TVM_REGISTER_GLOBAL("tir.schedule.ScheduleSeed").set_body_method(&ScheduleNode::Seed); +TVM_REGISTER_GLOBAL("tir.schedule.ForkSeed").set_body_method(&ScheduleNode::ForkSeed); TVM_REGISTER_GLOBAL("tir.schedule.ScheduleSampleCategorical") .set_body_method(&ScheduleNode::SampleCategorical); /******** (FFI) Get blocks & loops ********/ diff --git a/src/tir/schedule/traced_schedule.cc b/src/tir/schedule/traced_schedule.cc index f2a224f72ad3..794da3058ac9 100644 --- a/src/tir/schedule/traced_schedule.cc +++ b/src/tir/schedule/traced_schedule.cc @@ -48,7 +48,6 @@ ExprRV TracedScheduleNode::SampleCategorical(const Array& candidates, Optional decision) { ExprRV result = CreateRV(static_cast( tir::SampleCategorical(this->state_, &this->rand_state_, candidates, probs, &decision))); - static const InstructionKind& kind = InstructionKind::Get("SampleCategorical"); trace_->Append(/*inst=*/Instruction(/*kind=*/kind, // /*inputs=*/{}, diff --git a/src/tir/schedule/utils.h b/src/tir/schedule/utils.h index e1d0b734c6a3..fe3ab15a5f85 100644 --- a/src/tir/schedule/utils.h +++ b/src/tir/schedule/utils.h @@ -31,6 +31,7 @@ #include #include +#include #include #include From d735664bbb264cd2b2f9427189cc55fe594a2f6d Mon Sep 17 00:00:00 2001 From: Xiyou Zhou Date: Fri, 20 Aug 2021 17:26:49 -0700 Subject: [PATCH 09/34] Fix ffi name. --- src/tir/schedule/schedule.cc | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/src/tir/schedule/schedule.cc b/src/tir/schedule/schedule.cc index 21408723420c..13b84621935d 100644 --- a/src/tir/schedule/schedule.cc +++ b/src/tir/schedule/schedule.cc @@ -118,7 +118,8 @@ TVM_REGISTER_GLOBAL("tir.schedule.ScheduleRemoveRV") /******** (FFI) Sampling ********/ TVM_REGISTER_GLOBAL("tir.schedule.ScheduleSeed").set_body_method(&ScheduleNode::Seed); -TVM_REGISTER_GLOBAL("tir.schedule.ForkSeed").set_body_method(&ScheduleNode::ForkSeed); +TVM_REGISTER_GLOBAL("tir.schedule.ScheduleForkSeed") + .set_body_method(&ScheduleNode::ForkSeed); TVM_REGISTER_GLOBAL("tir.schedule.ScheduleSampleCategorical") .set_body_method(&ScheduleNode::SampleCategorical); /******** (FFI) Get blocks & loops ********/ From 35aa00c684011d72b88b363cb04ff24f9513e49c Mon Sep 17 00:00:00 2001 From: Xiyou Zhou Date: Fri, 20 Aug 2021 17:33:43 -0700 Subject: [PATCH 10/34] Test sample categorical. --- .../unittest/test_tir_schedule_sampling.py | 54 +++++++++++++++++++ 1 file changed, 54 insertions(+) create mode 100644 tests/python/unittest/test_tir_schedule_sampling.py diff --git a/tests/python/unittest/test_tir_schedule_sampling.py b/tests/python/unittest/test_tir_schedule_sampling.py new file mode 100644 index 000000000000..2a377e254c77 --- /dev/null +++ b/tests/python/unittest/test_tir_schedule_sampling.py @@ -0,0 +1,54 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +import sys +from collections import defaultdict +import pytest +import tvm +from tvm import tir +from tvm.script import ty + +# pylint: disable=no-member,invalid-name,unused-variable + + +@tvm.script.tir +def elementwise(a: ty.handle, b: ty.handle) -> None: + A = tir.match_buffer(a, (128, 128, 128)) + B = tir.match_buffer(b, (128, 128, 128)) + with tir.block([128, 128, 128], "B") as [vi, vj, vk]: + B[vi, vj, vk] = A[vi, vj, vk] * 2.0 + + +# pylint: disable=no-member,invalid-name,unused-variable + + +def test_fuse_sample_categorical(): + sch = tir.Schedule(elementwise, seed=42) + + n = 100 + candidates = [5, 2, 7, 1] + probs = [0.15, 0.55, 0.05, 0.25] + sch.get(sch.sample_categorical(candidates, probs)) + counter = defaultdict(int) + + for _ in range(n): + v = sch.get(sch.sample_categorical(candidates, probs)) + counter[v] += 1 + + +if __name__ == "__main__": + # sys.exit(pytest.main([__file__] + sys.argv[1:])) + test_fuse_sample_categorical() From 91454b1267c21c9708cd327b98925d16e31d2d4a Mon Sep 17 00:00:00 2001 From: Xiyou Zhou Date: Fri, 20 Aug 2021 18:42:59 -0700 Subject: [PATCH 11/34] Add Integer type change. --- src/tir/schedule/concrete_schedule.cc | 4 ++-- src/tir/schedule/traced_schedule.cc | 2 +- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/src/tir/schedule/concrete_schedule.cc b/src/tir/schedule/concrete_schedule.cc index c163988cf400..e40708b85090 100644 --- a/src/tir/schedule/concrete_schedule.cc +++ b/src/tir/schedule/concrete_schedule.cc @@ -223,8 +223,8 @@ ExprRV ConcreteScheduleNode::SampleCategorical(const Array& candidates, const Array& probs, Optional decision) { TVM_TIR_SCHEDULE_BEGIN(); - return CreateRV(static_cast( - tir::SampleCategorical(state_, &this->rand_state_, candidates, probs, &decision))); + return CreateRV( + Integer(tir::SampleCategorical(state_, &this->rand_state_, candidates, probs, &decision))); TVM_TIR_SCHEDULE_END("sample-categorical", this->error_render_level_); throw; } diff --git a/src/tir/schedule/traced_schedule.cc b/src/tir/schedule/traced_schedule.cc index 794da3058ac9..d54e55d29085 100644 --- a/src/tir/schedule/traced_schedule.cc +++ b/src/tir/schedule/traced_schedule.cc @@ -46,7 +46,7 @@ Schedule TracedScheduleNode::Copy() const { ExprRV TracedScheduleNode::SampleCategorical(const Array& candidates, const Array& probs, Optional decision) { - ExprRV result = CreateRV(static_cast( + ExprRV result = CreateRV(Integer( tir::SampleCategorical(this->state_, &this->rand_state_, candidates, probs, &decision))); static const InstructionKind& kind = InstructionKind::Get("SampleCategorical"); trace_->Append(/*inst=*/Instruction(/*kind=*/kind, // From 3b30c933dedcca22e62ede57818e0b8e7d6c9329 Mon Sep 17 00:00:00 2001 From: Junru Shao Date: Sat, 21 Aug 2021 03:06:18 +0000 Subject: [PATCH 12/34] bugfix for xiyou --- src/tir/schedule/concrete_schedule.h | 29 ++++++---------------------- 1 file changed, 6 insertions(+), 23 deletions(-) diff --git a/src/tir/schedule/concrete_schedule.h b/src/tir/schedule/concrete_schedule.h index 4cf5f3f957f4..352e6c9e7bb6 100644 --- a/src/tir/schedule/concrete_schedule.h +++ b/src/tir/schedule/concrete_schedule.h @@ -145,17 +145,11 @@ class ConcreteScheduleNode : public ScheduleNode { template inline T CreateRV(const StmtSRef& sref); /*! - * \brief Add an expr as a random variable into the symbol table - * \param expr The expr to be added to the symbol table + * \brief Add an integer as a random variable into the symbol table + * \param value The integer to be added to the symbol table * \return The new random variable created */ - inline ExprRV CreateRV(const PrimExpr& expr); - /*! - * \brief Add expr as random variables into the symbol table - * \param exprs The expr to be added to the symbol table - * \return The new random variables created - */ - inline Array CreateRV(const Array& exprs); + inline ExprRV CreateRV(int64_t value); /*! \brief Remove a random variable from the symbol table */ inline void RemoveFromSymbolTable(const ObjectRef& rv); }; @@ -270,23 +264,12 @@ inline T ConcreteScheduleNode::CreateRV(const StmtSRef& sref) { return std::move(rv); } -inline ExprRV ConcreteScheduleNode::CreateRV(const PrimExpr& expr) { - ExprRV rv; - this->symbol_table_.Set(rv, expr); +inline ExprRV ConcreteScheduleNode::CreateRV(int64_t value) { + Var rv("v", DataType::Int(32)); + this->symbol_table_.Set(rv, Integer(static_cast(value))); return std::move(rv); } -inline Array ConcreteScheduleNode::CreateRV(const Array& exprs) { - Array result; - result.reserve(exprs.size()); - for (const PrimExpr& expr : exprs) { - ExprRV rv; - this->symbol_table_.Set(rv, expr); - result.push_back(rv); - } - return result; -} - inline void ConcreteScheduleNode::RemoveFromSymbolTable(const ObjectRef& obj) { auto it = this->symbol_table_.find(obj); if (it != this->symbol_table_.end()) { From ef4fda1c35b2518152701fb7f0aba99abf1c3b6c Mon Sep 17 00:00:00 2001 From: Xiyou Zhou Date: Fri, 20 Aug 2021 21:48:23 -0700 Subject: [PATCH 13/34] Make tests work. --- src/tir/schedule/concrete_schedule.cc | 3 +-- src/tir/schedule/traced_schedule.cc | 4 ++-- 2 files changed, 3 insertions(+), 4 deletions(-) diff --git a/src/tir/schedule/concrete_schedule.cc b/src/tir/schedule/concrete_schedule.cc index e40708b85090..999f74bbf70c 100644 --- a/src/tir/schedule/concrete_schedule.cc +++ b/src/tir/schedule/concrete_schedule.cc @@ -223,8 +223,7 @@ ExprRV ConcreteScheduleNode::SampleCategorical(const Array& candidates, const Array& probs, Optional decision) { TVM_TIR_SCHEDULE_BEGIN(); - return CreateRV( - Integer(tir::SampleCategorical(state_, &this->rand_state_, candidates, probs, &decision))); + return CreateRV(tir::SampleCategorical(state_, &this->rand_state_, candidates, probs, &decision)); TVM_TIR_SCHEDULE_END("sample-categorical", this->error_render_level_); throw; } diff --git a/src/tir/schedule/traced_schedule.cc b/src/tir/schedule/traced_schedule.cc index d54e55d29085..248f1a5a562f 100644 --- a/src/tir/schedule/traced_schedule.cc +++ b/src/tir/schedule/traced_schedule.cc @@ -46,8 +46,8 @@ Schedule TracedScheduleNode::Copy() const { ExprRV TracedScheduleNode::SampleCategorical(const Array& candidates, const Array& probs, Optional decision) { - ExprRV result = CreateRV(Integer( - tir::SampleCategorical(this->state_, &this->rand_state_, candidates, probs, &decision))); + ExprRV result = CreateRV( + tir::SampleCategorical(this->state_, &this->rand_state_, candidates, probs, &decision)); static const InstructionKind& kind = InstructionKind::Get("SampleCategorical"); trace_->Append(/*inst=*/Instruction(/*kind=*/kind, // /*inputs=*/{}, From 6e740b4756c66eeee7246b6c7fb8680dbe11cebf Mon Sep 17 00:00:00 2001 From: Xiyou Zhou Date: Fri, 20 Aug 2021 22:17:25 -0700 Subject: [PATCH 14/34] Fixed sampling test. --- .../unittest/test_tir_schedule_sampling.py | 48 +++++++++++++++---- 1 file changed, 39 insertions(+), 9 deletions(-) diff --git a/tests/python/unittest/test_tir_schedule_sampling.py b/tests/python/unittest/test_tir_schedule_sampling.py index 2a377e254c77..f0c3d9f70639 100644 --- a/tests/python/unittest/test_tir_schedule_sampling.py +++ b/tests/python/unittest/test_tir_schedule_sampling.py @@ -15,15 +15,33 @@ # specific language governing permissions and limitations # under the License. import sys +from typing import Union from collections import defaultdict + import pytest import tvm from tvm import tir from tvm.script import ty +from tvm.ir import IRModule +from tvm.tir import PrimFunc +from tvm.tir.schedule import Trace + # pylint: disable=no-member,invalid-name,unused-variable +def _check_serialization(sch: tir.Schedule, mod: Union[PrimFunc, IRModule]) -> tir.Schedule: + record = sch.trace.as_json() + new_sch = tir.Schedule(mod=mod) + Trace.apply_json_to_schedule(record, sch=new_sch) + assert tvm.ir.structural_equal(new_sch.mod, sch.mod) + py_repr = "\n".join(sch.trace.as_python()) + new_py_repr = "\n".join(new_sch.trace.as_python()) + assert py_repr == new_py_repr + # print(py_repr) + return new_sch + + @tvm.script.tir def elementwise(a: ty.handle, b: ty.handle) -> None: A = tir.match_buffer(a, (128, 128, 128)) @@ -32,23 +50,35 @@ def elementwise(a: ty.handle, b: ty.handle) -> None: B[vi, vj, vk] = A[vi, vj, vk] * 2.0 -# pylint: disable=no-member,invalid-name,unused-variable - - def test_fuse_sample_categorical(): + """Test sample categprical sampling function""" + n = 1000 sch = tir.Schedule(elementwise, seed=42) - - n = 100 + counter = defaultdict(int) candidates = [5, 2, 7, 1] probs = [0.15, 0.55, 0.05, 0.25] - sch.get(sch.sample_categorical(candidates, probs)) - counter = defaultdict(int) + for _ in range(n): + v = sch.get(sch.sample_categorical(candidates, probs)) + counter[v] += 1 + for i, prob in enumerate(probs): + assert (prob - 0.07) * n <= counter[candidates[i]] <= (prob + 0.07) * n + _check_serialization(sch, mod=elementwise) + +@pytest.mark.xfail +def test_fuse_sample_categorical_out_of_range(): + """Test sample categprical sampling function""" + n = 1000 + sch = tir.Schedule(elementwise, seed=42) + counter = defaultdict(int) + candidates = [5, 2, 7, 1] + probs = [0.2, 0.2, 0.2, 0.2, 0.2] for _ in range(n): v = sch.get(sch.sample_categorical(candidates, probs)) counter[v] += 1 + for i, prob in enumerate(probs): + assert (prob - 0.07) * n <= counter[candidates[i]] <= (prob + 0.07) * n if __name__ == "__main__": - # sys.exit(pytest.main([__file__] + sys.argv[1:])) - test_fuse_sample_categorical() + sys.exit(pytest.main([__file__] + sys.argv[1:])) From 73f57e87abbea58ea6339776e8c303a7bbbb4a38 Mon Sep 17 00:00:00 2001 From: Xiyou Zhou Date: Sat, 21 Aug 2021 22:38:19 -0700 Subject: [PATCH 15/34] Add seed value guard in python side. --- python/tvm/tir/schedule/schedule.py | 22 +++++++++++++++++----- 1 file changed, 17 insertions(+), 5 deletions(-) diff --git a/python/tvm/tir/schedule/schedule.py b/python/tvm/tir/schedule/schedule.py index e5351087fd2a..04e2a5ae00c8 100644 --- a/python/tvm/tir/schedule/schedule.py +++ b/python/tvm/tir/schedule/schedule.py @@ -79,6 +79,16 @@ def _parse_error_render_level(error_render_level: str) -> int: return _ERROR_RENDER_LEVEL.get(error_render_level) +def _parse_seed(seed: Optional[int]) -> int: + if seed is None: + return -1 + if not isinstance(seed, int): + raise TypeError(f"Expected `seed` to be int or None, but gets: {seed}") + if seed < 1 or seed > 2147483647: + raise ValueError(f"seed must be in the range [1, 2147483647], but gets: {seed}") + return seed + + @_register_object("tir.Schedule") class Schedule(Object): """The user-facing schedule class @@ -98,7 +108,7 @@ def __init__( self, mod: Union[PrimFunc, IRModule], *, - seed: int = -1, + seed: Optional[int] = None, debug_mask: Union[str, int] = "none", error_render_level: str = "detail", ) -> None: @@ -108,8 +118,10 @@ def __init__( ---------- mod : Union[PrimFunc, IRModule] The IRModule or PrimFunc to be scheduled - seed: int + seed: Optional[int] The seed value for schedule's random state + Note that None and -1 means use device random, otherwise only integer between 1 and + 2147483647 is allowed. debug_mask : Union[str, int] Do extra correctness checking after the class creation and each time after calling the Replace method. @@ -133,7 +145,7 @@ def __init__( self.__init_handle_by_constructor__( _ffi_api.TracedSchedule, # type: ignore # pylint: disable=no-member _parse_mod(mod), - seed, + _parse_seed(seed), _parse_debug_mask(debug_mask), _parse_error_render_level(error_render_level), ) @@ -142,14 +154,14 @@ def __init__( def _create_non_traced( mod: Union[PrimFunc, IRModule], *, - seed: int = -1, + seed: Optional[int] = None, debug_mask: Union[str, int] = "none", error_render_level: str = "detail", ) -> "Schedule": """Construct a non-traced TensorIR schedule class from an IRModule.""" return _ffi_api.ConcreteSchedule( # type: ignore # pylint: disable=no-member _parse_mod(mod), - seed, + _parse_seed(seed), _parse_debug_mask(debug_mask), _parse_error_render_level(error_render_level), ) From e937881c49e1e8845daa6a9f3834b06ba4aeca1d Mon Sep 17 00:00:00 2001 From: Xiyou Zhou Date: Sat, 21 Aug 2021 22:45:21 -0700 Subject: [PATCH 16/34] Fix ForkSeed func. --- recomp.sh | 8 ++++++++ src/tir/schedule/concrete_schedule.cc | 2 +- src/tir/schedule/primitive/sampling.cc | 2 -- 3 files changed, 9 insertions(+), 3 deletions(-) create mode 100644 recomp.sh diff --git a/recomp.sh b/recomp.sh new file mode 100644 index 000000000000..b45f0ae84a88 --- /dev/null +++ b/recomp.sh @@ -0,0 +1,8 @@ +mv build/config.cmake ./; +rm -rf build; +mkdir build; +mv config.cmake build/; +cd build; +cmake ..; +make -j40 + diff --git a/src/tir/schedule/concrete_schedule.cc b/src/tir/schedule/concrete_schedule.cc index 999f74bbf70c..99d6c5276127 100644 --- a/src/tir/schedule/concrete_schedule.cc +++ b/src/tir/schedule/concrete_schedule.cc @@ -216,7 +216,7 @@ void ConcreteScheduleNode::Seed(support::LinearCongruentialEngine::TRandState se support::LinearCongruentialEngine::TRandState ConcreteScheduleNode::ForkSeed() { // In order for reproducibility, we computer the new seed using RNG's random state and a different // set of parameters. Note that both 32767 and 1999999973 are prime numbers. - return (rand_state_ * 32767) % 1999999973; + return (support::LinearCongruentialEngine(&rand_state_)() * 32767) % 1999999973; } ExprRV ConcreteScheduleNode::SampleCategorical(const Array& candidates, diff --git a/src/tir/schedule/primitive/sampling.cc b/src/tir/schedule/primitive/sampling.cc index 5fb0d19ab348..4dba797c7d9a 100644 --- a/src/tir/schedule/primitive/sampling.cc +++ b/src/tir/schedule/primitive/sampling.cc @@ -16,8 +16,6 @@ * specific language governing permissions and limitations * under the License. */ -#include -#include #include From 0da789435d03218e79d9677e1a59f27e552f0be9 Mon Sep 17 00:00:00 2001 From: Xiyou Zhou Date: Sat, 21 Aug 2021 22:52:06 -0700 Subject: [PATCH 17/34] Remove extra files, add size check and update test. --- recomp.sh | 8 ----- src/tir/schedule/concrete_schedule.cc | 2 ++ .../unittest/test_tir_schedule_sampling.py | 30 ++----------------- 3 files changed, 4 insertions(+), 36 deletions(-) delete mode 100644 recomp.sh diff --git a/recomp.sh b/recomp.sh deleted file mode 100644 index b45f0ae84a88..000000000000 --- a/recomp.sh +++ /dev/null @@ -1,8 +0,0 @@ -mv build/config.cmake ./; -rm -rf build; -mkdir build; -mv config.cmake build/; -cd build; -cmake ..; -make -j40 - diff --git a/src/tir/schedule/concrete_schedule.cc b/src/tir/schedule/concrete_schedule.cc index 99d6c5276127..521405967dea 100644 --- a/src/tir/schedule/concrete_schedule.cc +++ b/src/tir/schedule/concrete_schedule.cc @@ -222,6 +222,8 @@ support::LinearCongruentialEngine::TRandState ConcreteScheduleNode::ForkSeed() { ExprRV ConcreteScheduleNode::SampleCategorical(const Array& candidates, const Array& probs, Optional decision) { + CHECK(candidates.size() == probs.size()) + << "ValueError: number of candidates does not match number of probabilities."; TVM_TIR_SCHEDULE_BEGIN(); return CreateRV(tir::SampleCategorical(state_, &this->rand_state_, candidates, probs, &decision)); TVM_TIR_SCHEDULE_END("sample-categorical", this->error_render_level_); diff --git a/tests/python/unittest/test_tir_schedule_sampling.py b/tests/python/unittest/test_tir_schedule_sampling.py index f0c3d9f70639..ab4345f137ce 100644 --- a/tests/python/unittest/test_tir_schedule_sampling.py +++ b/tests/python/unittest/test_tir_schedule_sampling.py @@ -25,23 +25,12 @@ from tvm.ir import IRModule from tvm.tir import PrimFunc from tvm.tir.schedule import Trace +from tvm.tir.schedule.testing import verify_trace_roundtrip # pylint: disable=no-member,invalid-name,unused-variable -def _check_serialization(sch: tir.Schedule, mod: Union[PrimFunc, IRModule]) -> tir.Schedule: - record = sch.trace.as_json() - new_sch = tir.Schedule(mod=mod) - Trace.apply_json_to_schedule(record, sch=new_sch) - assert tvm.ir.structural_equal(new_sch.mod, sch.mod) - py_repr = "\n".join(sch.trace.as_python()) - new_py_repr = "\n".join(new_sch.trace.as_python()) - assert py_repr == new_py_repr - # print(py_repr) - return new_sch - - @tvm.script.tir def elementwise(a: ty.handle, b: ty.handle) -> None: A = tir.match_buffer(a, (128, 128, 128)) @@ -62,22 +51,7 @@ def test_fuse_sample_categorical(): counter[v] += 1 for i, prob in enumerate(probs): assert (prob - 0.07) * n <= counter[candidates[i]] <= (prob + 0.07) * n - _check_serialization(sch, mod=elementwise) - - -@pytest.mark.xfail -def test_fuse_sample_categorical_out_of_range(): - """Test sample categprical sampling function""" - n = 1000 - sch = tir.Schedule(elementwise, seed=42) - counter = defaultdict(int) - candidates = [5, 2, 7, 1] - probs = [0.2, 0.2, 0.2, 0.2, 0.2] - for _ in range(n): - v = sch.get(sch.sample_categorical(candidates, probs)) - counter[v] += 1 - for i, prob in enumerate(probs): - assert (prob - 0.07) * n <= counter[candidates[i]] <= (prob + 0.07) * n + verify_trace_roundtrip(sch, mod=elementwise) if __name__ == "__main__": From e75ed2b47a6984168fc05adb896948dbd1134f95 Mon Sep 17 00:00:00 2001 From: Xiyou Zhou Date: Sat, 21 Aug 2021 23:14:29 -0700 Subject: [PATCH 18/34] Update header and imports. --- src/tir/schedule/concrete_schedule.cc | 2 ++ src/tir/schedule/utils.h | 1 - tests/python/unittest/test_tir_schedule_sampling.py | 4 ---- 3 files changed, 2 insertions(+), 5 deletions(-) diff --git a/src/tir/schedule/concrete_schedule.cc b/src/tir/schedule/concrete_schedule.cc index 521405967dea..98ee291d2675 100644 --- a/src/tir/schedule/concrete_schedule.cc +++ b/src/tir/schedule/concrete_schedule.cc @@ -18,6 +18,8 @@ */ #include "./concrete_schedule.h" +#include + namespace tvm { namespace tir { diff --git a/src/tir/schedule/utils.h b/src/tir/schedule/utils.h index fe3ab15a5f85..e1d0b734c6a3 100644 --- a/src/tir/schedule/utils.h +++ b/src/tir/schedule/utils.h @@ -31,7 +31,6 @@ #include #include -#include #include #include diff --git a/tests/python/unittest/test_tir_schedule_sampling.py b/tests/python/unittest/test_tir_schedule_sampling.py index ab4345f137ce..aee3088d5b41 100644 --- a/tests/python/unittest/test_tir_schedule_sampling.py +++ b/tests/python/unittest/test_tir_schedule_sampling.py @@ -15,16 +15,12 @@ # specific language governing permissions and limitations # under the License. import sys -from typing import Union from collections import defaultdict import pytest import tvm from tvm import tir from tvm.script import ty -from tvm.ir import IRModule -from tvm.tir import PrimFunc -from tvm.tir.schedule import Trace from tvm.tir.schedule.testing import verify_trace_roundtrip From 9f40d16a615a9103a7f6c4688d7a2dd9baf849f9 Mon Sep 17 00:00:00 2001 From: Xiyou Zhou Date: Sun, 22 Aug 2021 02:20:00 -0700 Subject: [PATCH 19/34] Move AsVector func. --- src/support/array.h | 68 ++++++++++++++++++++++++++ src/tir/schedule/primitive/sampling.cc | 3 +- src/tir/schedule/utils.h | 68 -------------------------- 3 files changed, 70 insertions(+), 69 deletions(-) diff --git a/src/support/array.h b/src/support/array.h index 2cf416c471ec..3fd70b980151 100644 --- a/src/support/array.h +++ b/src/support/array.h @@ -18,6 +18,7 @@ */ #ifndef TVM_SUPPORT_ARRAY_H_ #define TVM_SUPPORT_ARRAY_H_ +#include #include #include @@ -67,6 +68,73 @@ inline bool ArrayWithSameContent(const std::vector& a, const std::vector return true; } +/*! + * \brief Convert a tvm::runtime::Array to std::vector + * \tparam TSrc The type of elements in the source Array + * \tparam TDst The type of elements in the result vector + * \return The result vector + */ +template +std::vector AsVector(const Array& vec); +/**************** AsVector ****************/ + +namespace details { + +template +struct AsVectorImpl {}; + +template +struct AsVectorImpl { + inline std::vector operator()(const Array& vec) const { + return std::vector(vec.begin(), vec.end()); + } +}; + +template +struct AsVectorImpl { + inline std::vector operator()(const Array& vec) const { + std::vector results; + for (const TSrcObjectRef& x : vec) { + const auto* n = x.template as(); + ICHECK(n) << "TypeError: Expects IntImm, but gets: " << x->GetTypeKey(); + results.push_back(n->value); + } + return results; + } +}; + +template +struct AsVectorImpl { + inline std::vector operator()(const Array& vec) const { + std::vector results; + for (const TSrcObjectRef& x : vec) { + const auto* n = x.template as(); + ICHECK(n) << "TypeError: Expects IntImm, but gets: " << x->GetTypeKey(); + results.push_back(n->value); + } + return results; + } +}; + +template +struct AsVectorImpl { + inline std::vector operator()(const Array& array) const { + std::vector results; + for (const TSrcObjectRef& x : array) { + const auto* n = x.template as(); + ICHECK(n) << "TypeError: Expects FloatImm, but gets: " << x->GetTypeKey(); + results.push_back(n->value); + } + return results; + } +}; +} // namespace details + +template +inline std::vector AsVector(const Array& vec) { + return details::AsVectorImpl()(vec); +} + } // namespace support } // namespace tvm #endif // TVM_SUPPORT_ARRAY_H_ diff --git a/src/tir/schedule/primitive/sampling.cc b/src/tir/schedule/primitive/sampling.cc index 4dba797c7d9a..530c247f6a7c 100644 --- a/src/tir/schedule/primitive/sampling.cc +++ b/src/tir/schedule/primitive/sampling.cc @@ -19,6 +19,7 @@ #include +#include "../../../support/array.h" #include "../primitive.h" #include "../utils.h" @@ -59,7 +60,7 @@ int64_t SampleCategorical(tir::ScheduleState self, CHECK(0 <= i && i < n) << "ValueError: Wrong decision value, where n = " << n << ", but decision is: " << i; } else { - i = MakeMultinomial(rand_state, AsVector(probs))(); + i = MakeMultinomial(rand_state, support::AsVector(probs))(); ICHECK(0 <= i && i < n); } *decision = Integer(i); diff --git a/src/tir/schedule/utils.h b/src/tir/schedule/utils.h index e1d0b734c6a3..8ccf8da731b5 100644 --- a/src/tir/schedule/utils.h +++ b/src/tir/schedule/utils.h @@ -117,15 +117,6 @@ inline Array LoopSRefs2Loops(const Array& loop_srefs) { return loops; } -/*! - * \brief Convert a tvm::runtime::Array to std::vector - * \tparam TSrc The type of elements in the source Array - * \tparam TDst The type of elements in the result vector - * \return The result vector - */ -template -std::vector AsVector(const Array& vec); - /******** Storage scope ********/ /*! @@ -201,65 +192,6 @@ inline Map AsIntSet(const Map& var_dom) { return {result.begin(), result.end()}; } -/**************** AsVector ****************/ - -namespace details { - -template -struct AsVectorImpl {}; - -template -struct AsVectorImpl { - inline std::vector operator()(const Array& vec) const { - return std::vector(vec.begin(), vec.end()); - } -}; - -template -struct AsVectorImpl { - inline std::vector operator()(const Array& vec) const { - std::vector results; - for (const TSrcObjectRef& x : vec) { - const auto* n = x.template as(); - ICHECK(n) << "TypeError: Expects IntImm, but gets: " << x->GetTypeKey(); - results.push_back(n->value); - } - return results; - } -}; - -template -struct AsVectorImpl { - inline std::vector operator()(const Array& vec) const { - std::vector results; - for (const TSrcObjectRef& x : vec) { - const auto* n = x.template as(); - ICHECK(n) << "TypeError: Expects IntImm, but gets: " << x->GetTypeKey(); - results.push_back(n->value); - } - return results; - } -}; - -template -struct AsVectorImpl { - inline std::vector operator()(const Array& array) const { - std::vector results; - for (const TSrcObjectRef& x : array) { - const auto* n = x.template as(); - ICHECK(n) << "TypeError: Expects FloatImm, but gets: " << x->GetTypeKey(); - results.push_back(n->value); - } - return results; - } -}; -} // namespace details - -template -inline std::vector AsVector(const Array& vec) { - return details::AsVectorImpl()(vec); -} - } // namespace tir } // namespace tvm From 83cff2fdfa68ff41f4fa46c5bec7cf2de6d96eba Mon Sep 17 00:00:00 2001 From: Xiyou Zhou Date: Sun, 22 Aug 2021 02:29:08 -0700 Subject: [PATCH 20/34] Move Seed and ForkSeed definition & ffi position. --- include/tvm/tir/schedule/schedule.h | 16 +++++++--------- src/tir/schedule/concrete_schedule.h | 9 ++------- src/tir/schedule/schedule.cc | 7 ++++--- 3 files changed, 13 insertions(+), 19 deletions(-) diff --git a/include/tvm/tir/schedule/schedule.h b/include/tvm/tir/schedule/schedule.h index 2ab16330d0fc..79fed09c3e36 100644 --- a/include/tvm/tir/schedule/schedule.h +++ b/include/tvm/tir/schedule/schedule.h @@ -115,6 +115,13 @@ class ScheduleNode : public runtime::Object { * reconstructed */ virtual Schedule Copy() const = 0; + /*! + * \brief Seed the randomness + * \param seed The new random seed, -1 if use device random, otherwise non-negative + */ + virtual void Seed(support::LinearCongruentialEngine::TRandState seed = -1) = 0; + /*! \brief Fork the random state */ + virtual support::LinearCongruentialEngine::TRandState ForkSeed() = 0; public: /******** Lookup/Remove random variables ********/ @@ -178,15 +185,6 @@ class ScheduleNode : public runtime::Object { public: /******** Schedule: Sampling ********/ - /*! - * \brief Seed the randomness - * \param seed The new random seed, -1 if use device random, otherwise non-negative - */ - virtual void Seed(support::LinearCongruentialEngine::TRandState seed = -1) { - LOG(FATAL) << "ValueError: The schedule cannot be seeded because no randomness is allowed"; - } - /*! \brief Fork the random state */ - virtual support::LinearCongruentialEngine::TRandState ForkSeed() = 0; /*! * \brief Sample an integer given the probability distribution * \param candidates The candidates diff --git a/src/tir/schedule/concrete_schedule.h b/src/tir/schedule/concrete_schedule.h index 352e6c9e7bb6..b1c1cab93a93 100644 --- a/src/tir/schedule/concrete_schedule.h +++ b/src/tir/schedule/concrete_schedule.h @@ -61,6 +61,8 @@ class ConcreteScheduleNode : public ScheduleNode { ScheduleState state() const final { return state_; } Optional trace() const override { return NullOpt; } Schedule Copy() const override; + void Seed(support::LinearCongruentialEngine::TRandState seed = -1) final; + support::LinearCongruentialEngine::TRandState ForkSeed() final; public: /******** Lookup random variables ********/ @@ -78,13 +80,6 @@ class ConcreteScheduleNode : public ScheduleNode { public: /******** Schedule: Sampling ********/ - /*! - * \brief Seed the randomness - * \param seed The new random seed, -1 if use device random, otherwise non-negative - */ - void Seed(support::LinearCongruentialEngine::TRandState seed = -1) final; - /*! \brief Fork the random state */ - support::LinearCongruentialEngine::TRandState ForkSeed() final; /*! * \brief Sample an integer given the probability distribution * \param candidates The candidates diff --git a/src/tir/schedule/schedule.cc b/src/tir/schedule/schedule.cc index 13b84621935d..d24cdc625912 100644 --- a/src/tir/schedule/schedule.cc +++ b/src/tir/schedule/schedule.cc @@ -52,6 +52,10 @@ TVM_REGISTER_GLOBAL("tir.schedule.ScheduleGetTrace") // .set_body_method(&ScheduleNode::trace); TVM_REGISTER_GLOBAL("tir.schedule.ScheduleCopy") // .set_body_method(&ScheduleNode::Copy); +TVM_REGISTER_GLOBAL("tir.schedule.ScheduleSeed") // + .set_body_method(&ScheduleNode::Seed); +TVM_REGISTER_GLOBAL("tir.schedule.ScheduleForkSeed") // + .set_body_method(&ScheduleNode::ForkSeed); /**************** (FFI) Constructor ****************/ @@ -117,9 +121,6 @@ TVM_REGISTER_GLOBAL("tir.schedule.ScheduleRemoveRV") }); /******** (FFI) Sampling ********/ -TVM_REGISTER_GLOBAL("tir.schedule.ScheduleSeed").set_body_method(&ScheduleNode::Seed); -TVM_REGISTER_GLOBAL("tir.schedule.ScheduleForkSeed") - .set_body_method(&ScheduleNode::ForkSeed); TVM_REGISTER_GLOBAL("tir.schedule.ScheduleSampleCategorical") .set_body_method(&ScheduleNode::SampleCategorical); /******** (FFI) Get blocks & loops ********/ From 3d2d5d2d7952657e9eb392e92dd4788402c92214 Mon Sep 17 00:00:00 2001 From: Xiyou Zhou Date: Sun, 22 Aug 2021 20:56:19 -0700 Subject: [PATCH 21/34] Modify Sample Categorical func to work with discrete distribution. --- src/tir/schedule/concrete_schedule.cc | 2 -- src/tir/schedule/primitive/sampling.cc | 34 ++++++++------------------ 2 files changed, 10 insertions(+), 26 deletions(-) diff --git a/src/tir/schedule/concrete_schedule.cc b/src/tir/schedule/concrete_schedule.cc index 98ee291d2675..44758fba52d9 100644 --- a/src/tir/schedule/concrete_schedule.cc +++ b/src/tir/schedule/concrete_schedule.cc @@ -224,8 +224,6 @@ support::LinearCongruentialEngine::TRandState ConcreteScheduleNode::ForkSeed() { ExprRV ConcreteScheduleNode::SampleCategorical(const Array& candidates, const Array& probs, Optional decision) { - CHECK(candidates.size() == probs.size()) - << "ValueError: number of candidates does not match number of probabilities."; TVM_TIR_SCHEDULE_BEGIN(); return CreateRV(tir::SampleCategorical(state_, &this->rand_state_, candidates, probs, &decision)); TVM_TIR_SCHEDULE_END("sample-categorical", this->error_render_level_); diff --git a/src/tir/schedule/primitive/sampling.cc b/src/tir/schedule/primitive/sampling.cc index 530c247f6a7c..74a31206273f 100644 --- a/src/tir/schedule/primitive/sampling.cc +++ b/src/tir/schedule/primitive/sampling.cc @@ -26,43 +26,29 @@ namespace tvm { namespace tir { -std::function MakeMultinomial(support::LinearCongruentialEngine::TRandState* rand_state, - const std::vector& weights) { - support::LinearCongruentialEngine rand_(rand_state); - std::vector sums; - sums.reserve(weights.size()); - double sum = 0.0; - for (double w : weights) { - sums.push_back(sum += w); - } - std::uniform_real_distribution dist(0.0, sum); - auto sampler = [rand_state, dist = std::move(dist), sums = std::move(sums)]() mutable -> int { - support::LinearCongruentialEngine rand_(rand_state); - double p = dist(rand_); - int idx = std::lower_bound(sums.begin(), sums.end(), p) - sums.begin(); - int n = sums.size(); - CHECK_LE(0, idx); - CHECK_LE(idx, n); - return (idx == n) ? (n - 1) : idx; - }; - return sampler; -} - int64_t SampleCategorical(tir::ScheduleState self, support::LinearCongruentialEngine::TRandState* rand_state, const Array& candidates, const Array& probs, Optional* decision) { + CHECK(candidates.size() == probs.size()) + << "ValueError: number of candidates does not match number of probabilities."; int i = -1; int n = candidates.size(); + if (decision->defined()) { const auto* int_imm = decision->as(); i = int_imm->value; CHECK(0 <= i && i < n) << "ValueError: Wrong decision value, where n = " << n << ", but decision is: " << i; } else { - i = MakeMultinomial(rand_state, support::AsVector(probs))(); - ICHECK(0 <= i && i < n); + std::vector weights = support::AsVector(probs); + std::discrete_distribution dist(weights.begin(), weights.end()); + support::LinearCongruentialEngine rand_(rand_state); + i = dist(rand_); + ICHECK(0 <= i && i < n) << "ValueError: Wrong decision value, where n = " << n + << ", but decision is: " << i; } + *decision = Integer(i); return candidates[i]; } From 8efe5a3bfb60de6e7a331474033de3d30738940b Mon Sep 17 00:00:00 2001 From: Xiyou Zhou Date: Sun, 22 Aug 2021 21:38:25 -0700 Subject: [PATCH 22/34] Remote unused argument ScheduleState self in function signature. --- src/tir/schedule/concrete_schedule.cc | 2 +- src/tir/schedule/primitive.h | 3 +-- src/tir/schedule/primitive/sampling.cc | 3 +-- src/tir/schedule/traced_schedule.cc | 4 ++-- 4 files changed, 5 insertions(+), 7 deletions(-) diff --git a/src/tir/schedule/concrete_schedule.cc b/src/tir/schedule/concrete_schedule.cc index 44758fba52d9..227b1f4042b5 100644 --- a/src/tir/schedule/concrete_schedule.cc +++ b/src/tir/schedule/concrete_schedule.cc @@ -225,7 +225,7 @@ ExprRV ConcreteScheduleNode::SampleCategorical(const Array& candidates, const Array& probs, Optional decision) { TVM_TIR_SCHEDULE_BEGIN(); - return CreateRV(tir::SampleCategorical(state_, &this->rand_state_, candidates, probs, &decision)); + return CreateRV(tir::SampleCategorical(&this->rand_state_, candidates, probs, &decision)); TVM_TIR_SCHEDULE_END("sample-categorical", this->error_render_level_); throw; } diff --git a/src/tir/schedule/primitive.h b/src/tir/schedule/primitive.h index b3968d863b55..be33c2acca10 100644 --- a/src/tir/schedule/primitive.h +++ b/src/tir/schedule/primitive.h @@ -35,8 +35,7 @@ namespace tir { * \param decision The sampling decision, if any * \return The random variable sampled from candidates */ -TVM_DLL int64_t SampleCategorical(tir::ScheduleState self, - support::LinearCongruentialEngine::TRandState* rand_state, +TVM_DLL int64_t SampleCategorical(support::LinearCongruentialEngine::TRandState* rand_state, const Array& candidates, const Array& probs, Optional* decision); diff --git a/src/tir/schedule/primitive/sampling.cc b/src/tir/schedule/primitive/sampling.cc index 74a31206273f..4682b2f0b872 100644 --- a/src/tir/schedule/primitive/sampling.cc +++ b/src/tir/schedule/primitive/sampling.cc @@ -26,8 +26,7 @@ namespace tvm { namespace tir { -int64_t SampleCategorical(tir::ScheduleState self, - support::LinearCongruentialEngine::TRandState* rand_state, +int64_t SampleCategorical(support::LinearCongruentialEngine::TRandState* rand_state, const Array& candidates, const Array& probs, Optional* decision) { CHECK(candidates.size() == probs.size()) diff --git a/src/tir/schedule/traced_schedule.cc b/src/tir/schedule/traced_schedule.cc index 248f1a5a562f..af4a6588f064 100644 --- a/src/tir/schedule/traced_schedule.cc +++ b/src/tir/schedule/traced_schedule.cc @@ -46,8 +46,8 @@ Schedule TracedScheduleNode::Copy() const { ExprRV TracedScheduleNode::SampleCategorical(const Array& candidates, const Array& probs, Optional decision) { - ExprRV result = CreateRV( - tir::SampleCategorical(this->state_, &this->rand_state_, candidates, probs, &decision)); + ExprRV result = + CreateRV(tir::SampleCategorical(&this->rand_state_, candidates, probs, &decision)); static const InstructionKind& kind = InstructionKind::Get("SampleCategorical"); trace_->Append(/*inst=*/Instruction(/*kind=*/kind, // /*inputs=*/{}, From 93878a2184627f1e69ed9c249ab99ff944da1cde Mon Sep 17 00:00:00 2001 From: Xiyou Zhou Date: Sun, 22 Aug 2021 21:52:15 -0700 Subject: [PATCH 23/34] Renable pylint. --- tests/python/unittest/test_tir_schedule_sampling.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/tests/python/unittest/test_tir_schedule_sampling.py b/tests/python/unittest/test_tir_schedule_sampling.py index aee3088d5b41..c841ee0a81f8 100644 --- a/tests/python/unittest/test_tir_schedule_sampling.py +++ b/tests/python/unittest/test_tir_schedule_sampling.py @@ -35,6 +35,9 @@ def elementwise(a: ty.handle, b: ty.handle) -> None: B[vi, vj, vk] = A[vi, vj, vk] * 2.0 +# pylint: enable=no-member,invalid-name,unused-variable + + def test_fuse_sample_categorical(): """Test sample categprical sampling function""" n = 1000 From b67f14adf1e55057bccab26a4e61ac18d21b0b30 Mon Sep 17 00:00:00 2001 From: Xiyou Zhou Date: Sun, 22 Aug 2021 21:56:10 -0700 Subject: [PATCH 24/34] Fix test name and debug mask. --- tests/python/unittest/test_tir_schedule_sampling.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/tests/python/unittest/test_tir_schedule_sampling.py b/tests/python/unittest/test_tir_schedule_sampling.py index c841ee0a81f8..825453993cf5 100644 --- a/tests/python/unittest/test_tir_schedule_sampling.py +++ b/tests/python/unittest/test_tir_schedule_sampling.py @@ -38,10 +38,10 @@ def elementwise(a: ty.handle, b: ty.handle) -> None: # pylint: enable=no-member,invalid-name,unused-variable -def test_fuse_sample_categorical(): +def test_sample_categorical(): """Test sample categprical sampling function""" n = 1000 - sch = tir.Schedule(elementwise, seed=42) + sch = tir.Schedule(elementwise, seed=42, debug_mask="all") counter = defaultdict(int) candidates = [5, 2, 7, 1] probs = [0.15, 0.55, 0.05, 0.25] From 72e245602b62f05ddd09022d506c21290f77cf61 Mon Sep 17 00:00:00 2001 From: Xiyou Zhou Date: Mon, 23 Aug 2021 11:25:42 -0700 Subject: [PATCH 25/34] Fix format clang. --- include/tvm/support/random_engine.h | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/include/tvm/support/random_engine.h b/include/tvm/support/random_engine.h index 17c2485fd7b8..6b733d074f6a 100644 --- a/include/tvm/support/random_engine.h +++ b/include/tvm/support/random_engine.h @@ -19,7 +19,8 @@ /*! * \file random_engine.h - * \brief Random number generator. It provides a generic interface consistent with `std::uniform_random_bit_generator` + * \brief Random number generator. It provides a generic interface consistent with + * `std::uniform_random_bit_generator` */ #ifndef TVM_SUPPORT_RANDOM_ENGINE_H_ From ce8e6bb56e715622de206ab24ae9a28a55473a1d Mon Sep 17 00:00:00 2001 From: Xiyou Zhou Date: Mon, 23 Aug 2021 14:00:34 -0700 Subject: [PATCH 26/34] Fix non-class template compilation problem. --- src/tir/schedule/primitive/sampling.cc | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/tir/schedule/primitive/sampling.cc b/src/tir/schedule/primitive/sampling.cc index 4682b2f0b872..eda6dbbd094b 100644 --- a/src/tir/schedule/primitive/sampling.cc +++ b/src/tir/schedule/primitive/sampling.cc @@ -80,7 +80,7 @@ struct SampleCategoricalTraits : public UnpackedInstTraits; }; TVM_REGISTER_INST_KIND_TRAITS(SampleCategoricalTraits); From b2fffa022e7074d5bfea509d035ec5ced714b522 Mon Sep 17 00:00:00 2001 From: Xiyou Zhou Date: Mon, 23 Aug 2021 15:01:19 -0700 Subject: [PATCH 27/34] Add copy & serialization test. --- .../unittest/test_tir_schedule_sampling.py | 38 +++++++++++++++++++ 1 file changed, 38 insertions(+) diff --git a/tests/python/unittest/test_tir_schedule_sampling.py b/tests/python/unittest/test_tir_schedule_sampling.py index 825453993cf5..49bef9cfec24 100644 --- a/tests/python/unittest/test_tir_schedule_sampling.py +++ b/tests/python/unittest/test_tir_schedule_sampling.py @@ -22,6 +22,7 @@ from tvm import tir from tvm.script import ty from tvm.tir.schedule.testing import verify_trace_roundtrip +from tvm.tir.schedule import Trace # pylint: disable=no-member,invalid-name,unused-variable @@ -53,5 +54,42 @@ def test_sample_categorical(): verify_trace_roundtrip(sch, mod=elementwise) +def test_sample_categorical_copy(): + """Check the random variable sampling results after schedule copy""" + n = 100 + sch = tir.Schedule(elementwise, seed=42, debug_mask="all") + candidates = [1, 2, 3, 4] + probs = [0.1, 0.2, 0.3, 0.4] + rvs = [] + decisions = [] + for _ in range(n): + rv = sch.sample_categorical(candidates, probs) # pylint: disable=invalid-name + decision = sch.get(rv) + rvs.append(rv) + decisions.append(decision) + sch_copy = sch.copy() + for rv, decision in zip(rvs, decisions): # pylint: disable=invalid-name + decision_copy = sch_copy.get(rv) + assert int(decision) == int(decision_copy) + + +def test_sample_categorical_serialize(): + """Check the random variable sampling results after schedule serialization""" + n = 100 + sch = tir.Schedule(elementwise, seed=42, debug_mask="all") + candidates = [5, 6, 7, 8] + probs = [0.23, 0.19, 0.37, 0.21] + for _ in range(n): + sch.get(sch.sample_categorical(candidates, probs)) + trace = sch.trace + json_obj = trace.as_json() + new_sch = tir.Schedule(mod=elementwise, debug_mask="all") + Trace.apply_json_to_schedule(json_obj=json_obj, sch=new_sch) + assert len(sch.trace.insts) == len(new_sch.trace.insts) + for i, inst in enumerate(sch.trace.insts): + new_inst = new_sch.trace.insts[i] + assert sch.trace.decisions[inst] == new_sch.trace.decisions[new_inst] + + if __name__ == "__main__": sys.exit(pytest.main([__file__] + sys.argv[1:])) From 414f440bfece3450f060532a2f6cea88ee8c2678 Mon Sep 17 00:00:00 2001 From: Xiyou Zhou Date: Mon, 23 Aug 2021 22:51:38 -0700 Subject: [PATCH 28/34] Add docs on python-side function. --- python/tvm/tir/schedule/schedule.py | 13 +++++++++++++ 1 file changed, 13 insertions(+) diff --git a/python/tvm/tir/schedule/schedule.py b/python/tvm/tir/schedule/schedule.py index 04e2a5ae00c8..c531eaefab84 100644 --- a/python/tvm/tir/schedule/schedule.py +++ b/python/tvm/tir/schedule/schedule.py @@ -302,6 +302,19 @@ def sample_categorical( probs: List[float], decision: Optional[int] = None, ) -> ExprRV: + """Sample an integer given the probability distribution + + Parameters + ---------- + candidates : The candidates to be sampled from + probs : The probability of each candidate + decision : The sampling decision, if any + + Returns + ------- + result : ExprRV + The random variable sampled from candidates + """ return _ffi_api.ScheduleSampleCategorical( # type: ignore # pylint: disable=no-member self, candidates, From d7a545e5a37565801d51d98a88963b766b670c28 Mon Sep 17 00:00:00 2001 From: Xiyou Zhou Date: Mon, 23 Aug 2021 22:55:36 -0700 Subject: [PATCH 29/34] Minor fix. --- src/tir/schedule/primitive/sampling.cc | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/src/tir/schedule/primitive/sampling.cc b/src/tir/schedule/primitive/sampling.cc index eda6dbbd094b..159b8800809e 100644 --- a/src/tir/schedule/primitive/sampling.cc +++ b/src/tir/schedule/primitive/sampling.cc @@ -19,7 +19,6 @@ #include -#include "../../../support/array.h" #include "../primitive.h" #include "../utils.h" @@ -44,7 +43,7 @@ int64_t SampleCategorical(support::LinearCongruentialEngine::TRandState* rand_st std::discrete_distribution dist(weights.begin(), weights.end()); support::LinearCongruentialEngine rand_(rand_state); i = dist(rand_); - ICHECK(0 <= i && i < n) << "ValueError: Wrong decision value, where n = " << n + ICHECK(0 <= i && i < n) << "ValueError: Unexpected decision generated, where n = " << n << ", but decision is: " << i; } From 58de4a94dc9f0d64d66139c0ec11accc9034f337 Mon Sep 17 00:00:00 2001 From: Xiyou Zhou Date: Mon, 23 Aug 2021 23:02:00 -0700 Subject: [PATCH 30/34] Minor fix. --- src/support/array.h | 2 +- src/tir/schedule/concrete_schedule.cc | 6 +++++- 2 files changed, 6 insertions(+), 2 deletions(-) diff --git a/src/support/array.h b/src/support/array.h index 3fd70b980151..89e17433344b 100644 --- a/src/support/array.h +++ b/src/support/array.h @@ -76,8 +76,8 @@ inline bool ArrayWithSameContent(const std::vector& a, const std::vector */ template std::vector AsVector(const Array& vec); -/**************** AsVector ****************/ +/********** Implementation details of AsVector **********/ namespace details { template diff --git a/src/tir/schedule/concrete_schedule.cc b/src/tir/schedule/concrete_schedule.cc index 227b1f4042b5..cd9aad8ae512 100644 --- a/src/tir/schedule/concrete_schedule.cc +++ b/src/tir/schedule/concrete_schedule.cc @@ -213,8 +213,12 @@ Schedule ConcreteScheduleNode::Copy() const { /******** Schedule: Schedule: Sampling ********/ void ConcreteScheduleNode::Seed(support::LinearCongruentialEngine::TRandState seed) { - support::LinearCongruentialEngine(&rand_state_).Seed(seed == -1 ? std::random_device()() : seed); + if (seed == -1) { + seed = std::random_device()(); + } + support::LinearCongruentialEngine(&rand_state_).Seed(seed); } + support::LinearCongruentialEngine::TRandState ConcreteScheduleNode::ForkSeed() { // In order for reproducibility, we computer the new seed using RNG's random state and a different // set of parameters. Note that both 32767 and 1999999973 are prime numbers. From cb25711090a12c05f0de90f800088319d15f528e Mon Sep 17 00:00:00 2001 From: Xiyou Zhou Date: Mon, 23 Aug 2021 23:09:50 -0700 Subject: [PATCH 31/34] Modify ExprRV constructor from int64_t. --- src/tir/schedule/concrete_schedule.h | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/tir/schedule/concrete_schedule.h b/src/tir/schedule/concrete_schedule.h index b1c1cab93a93..217a352609e5 100644 --- a/src/tir/schedule/concrete_schedule.h +++ b/src/tir/schedule/concrete_schedule.h @@ -260,7 +260,7 @@ inline T ConcreteScheduleNode::CreateRV(const StmtSRef& sref) { } inline ExprRV ConcreteScheduleNode::CreateRV(int64_t value) { - Var rv("v", DataType::Int(32)); + Var rv("v" + std::to_string(this->symbol_table_.size() + 1), DataType::Int(32)); this->symbol_table_.Set(rv, Integer(static_cast(value))); return std::move(rv); } From f9c5458edbac09296b709146494935ac9e68296b Mon Sep 17 00:00:00 2001 From: Xiyou Zhou Date: Mon, 23 Aug 2021 23:14:22 -0700 Subject: [PATCH 32/34] Fix docs. --- python/tvm/tir/schedule/schedule.py | 9 ++++++--- 1 file changed, 6 insertions(+), 3 deletions(-) diff --git a/python/tvm/tir/schedule/schedule.py b/python/tvm/tir/schedule/schedule.py index c531eaefab84..9433d019f9a5 100644 --- a/python/tvm/tir/schedule/schedule.py +++ b/python/tvm/tir/schedule/schedule.py @@ -306,9 +306,12 @@ def sample_categorical( Parameters ---------- - candidates : The candidates to be sampled from - probs : The probability of each candidate - decision : The sampling decision, if any + candidates : List[int] + The candidates to be sampled from + probs : List[float] + The probability of each candidate + decision : Optional[int] + The sampling decision, if any Returns ------- From b541d49b97861243b552e75ea5081d92c8ad591f Mon Sep 17 00:00:00 2001 From: Xiyou Zhou Date: Mon, 23 Aug 2021 23:38:59 -0700 Subject: [PATCH 33/34] Modify docs. --- src/tir/schedule/concrete_schedule.h | 3 ++- src/tir/schedule/primitive/sampling.cc | 2 +- src/tir/schedule/traced_schedule.h | 3 ++- 3 files changed, 5 insertions(+), 3 deletions(-) diff --git a/src/tir/schedule/concrete_schedule.h b/src/tir/schedule/concrete_schedule.h index 217a352609e5..0bd902d183bf 100644 --- a/src/tir/schedule/concrete_schedule.h +++ b/src/tir/schedule/concrete_schedule.h @@ -84,7 +84,8 @@ class ConcreteScheduleNode : public ScheduleNode { * \brief Sample an integer given the probability distribution * \param candidates The candidates * \param probs The probability distribution of the candidates - * \param decision The sampling decision + * \param decision The sampling decision, if it's given we would validate the decision, otherwise + * we would sample a decision from the distribution and set the decision accordingly. * \return The random variable sampled from candidates */ ExprRV SampleCategorical(const Array& candidates, const Array& probs, diff --git a/src/tir/schedule/primitive/sampling.cc b/src/tir/schedule/primitive/sampling.cc index 159b8800809e..ac40d27c4bf3 100644 --- a/src/tir/schedule/primitive/sampling.cc +++ b/src/tir/schedule/primitive/sampling.cc @@ -47,7 +47,7 @@ int64_t SampleCategorical(support::LinearCongruentialEngine::TRandState* rand_st << ", but decision is: " << i; } - *decision = Integer(i); + *decision = Integer(i); // decision is guaranteed not to be nullptr. return candidates[i]; } diff --git a/src/tir/schedule/traced_schedule.h b/src/tir/schedule/traced_schedule.h index 90b4a40d912d..48dadbc03b3b 100644 --- a/src/tir/schedule/traced_schedule.h +++ b/src/tir/schedule/traced_schedule.h @@ -51,7 +51,8 @@ class TracedScheduleNode : public ConcreteScheduleNode { * \brief Sample an integer given the probability distribution * \param candidates The candidates * \param probs The probability distribution of the candidates - * \param decision The sampling decision + * \param decision The sampling decision, if it's given we would validate the decision, otherwise + * we would sample a decision from the distribution and set the decision accordingly. * \return The random variable sampled from candidates */ ExprRV SampleCategorical(const Array& candidates, const Array& probs, From 5a6b2d3246f953fca18da2ea8d08b1d1962dc077 Mon Sep 17 00:00:00 2001 From: Xiyou Zhou Date: Mon, 23 Aug 2021 23:48:06 -0700 Subject: [PATCH 34/34] Modify tests. --- .../unittest/test_tir_schedule_sampling.py | 24 +++++++------------ 1 file changed, 9 insertions(+), 15 deletions(-) diff --git a/tests/python/unittest/test_tir_schedule_sampling.py b/tests/python/unittest/test_tir_schedule_sampling.py index 49bef9cfec24..2bfd68663c99 100644 --- a/tests/python/unittest/test_tir_schedule_sampling.py +++ b/tests/python/unittest/test_tir_schedule_sampling.py @@ -60,15 +60,12 @@ def test_sample_categorical_copy(): sch = tir.Schedule(elementwise, seed=42, debug_mask="all") candidates = [1, 2, 3, 4] probs = [0.1, 0.2, 0.3, 0.4] - rvs = [] - decisions = [] + rv_decisions = [] for _ in range(n): rv = sch.sample_categorical(candidates, probs) # pylint: disable=invalid-name - decision = sch.get(rv) - rvs.append(rv) - decisions.append(decision) + rv_decisions.append((rv, sch.get(rv))) sch_copy = sch.copy() - for rv, decision in zip(rvs, decisions): # pylint: disable=invalid-name + for rv, decision in rv_decisions: # pylint: disable=invalid-name decision_copy = sch_copy.get(rv) assert int(decision) == int(decision_copy) @@ -79,16 +76,13 @@ def test_sample_categorical_serialize(): sch = tir.Schedule(elementwise, seed=42, debug_mask="all") candidates = [5, 6, 7, 8] probs = [0.23, 0.19, 0.37, 0.21] + decisions = [] for _ in range(n): - sch.get(sch.sample_categorical(candidates, probs)) - trace = sch.trace - json_obj = trace.as_json() - new_sch = tir.Schedule(mod=elementwise, debug_mask="all") - Trace.apply_json_to_schedule(json_obj=json_obj, sch=new_sch) - assert len(sch.trace.insts) == len(new_sch.trace.insts) - for i, inst in enumerate(sch.trace.insts): - new_inst = new_sch.trace.insts[i] - assert sch.trace.decisions[inst] == new_sch.trace.decisions[new_inst] + rv = sch.get(sch.sample_categorical(candidates, probs)) # pylint: disable=invalid-name + decisions.append(rv) + new_sch = verify_trace_roundtrip(sch, mod=elementwise) + for i, new_inst in enumerate(new_sch.trace.insts): + assert decisions[i] == candidates[new_sch.trace.decisions[new_inst].value] if __name__ == "__main__":