Skip to content

Commit

Permalink
[M3a][Meta Schedule] Add Sampling Primitive SampleCategorical. (apach…
Browse files Browse the repository at this point in the history
…e#8817)

Co-authored-by: Junru Shao <junrushao1994@gmail.com>
Co-authored-by: Wuwei Lin <wuwei@apache.org>
Co-authored-by: Ruihang Lai <lairuihangdongdong@qq.com>
Co-authored-by: Hongyi Jin <3231950289@qq.com>
Co-authored-by: Siyuan Feng <Hzfengsy@sjtu.edu.cn>
Co-authored-by: Bohan Hou <32121147+spectrometerHBH@users.noreply.github.com>
  • Loading branch information
7 people authored and ylc committed Jan 13, 2022
1 parent 178ed98 commit a20f09d
Show file tree
Hide file tree
Showing 12 changed files with 429 additions and 44 deletions.
10 changes: 6 additions & 4 deletions include/tvm/support/random_engine.h
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,8 @@

/*!
* \file random_engine.h
* \brief Random number generator, for Sampler and Sampling functions.
* \brief Random number generator. It provides a generic interface consistent with
* `std::uniform_random_bit_generator`
*/

#ifndef TVM_SUPPORT_RANDOM_ENGINE_H_
Expand All @@ -41,10 +42,11 @@ namespace support {
* included for simplification. For full member functions of std::minstd_rand, please check out the
* following link: https://en.cppreference.com/w/cpp/numeric/random/linear_congruential_engine
*/

class LinearCongruentialEngine {
public:
/*!
* \brief The result type is defined as int64_t here for meta_schedule sampler usage.
* \brief The result type is defined as uint64_t here to avoid overflow.
* \note The type name is not in Google style because it is used in STL's distribution inferface.
*/
using result_type = uint64_t;
Expand All @@ -63,13 +65,13 @@ class LinearCongruentialEngine {
* \brief The minimum possible value of random state here.
* \note The function name is uncapilized because it is used in STL's distribution inferface.
*/
result_type min() { return 0; }
static constexpr result_type min() { return 0; }

/*!
* \brief The maximum possible value of random state here.
* \note The function name is uncapilized because it is used in STL's distribution inferface.
*/
result_type max() { return modulus - 1; }
static constexpr result_type max() { return modulus - 1; }

/*!
* \brief Operator to move the random state to the next and return the new random state. According
Expand Down
27 changes: 20 additions & 7 deletions include/tvm/tir/schedule/schedule.h
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
#ifndef TVM_TIR_SCHEDULE_SCHEDULE_H_
#define TVM_TIR_SCHEDULE_SCHEDULE_H_

#include <tvm/support/random_engine.h>
#include <tvm/tir/schedule/state.h>
#include <tvm/tir/schedule/trace.h>

Expand Down Expand Up @@ -118,9 +119,9 @@ class ScheduleNode : public runtime::Object {
* \brief Seed the randomness
* \param seed The new random seed, -1 if use device random, otherwise non-negative
*/
virtual void Seed(int64_t seed = -1) {
LOG(FATAL) << "ValueError: The schedule cannot be seeded because no randomness is allowed";
}
virtual void Seed(support::LinearCongruentialEngine::TRandState seed = -1) = 0;
/*! \brief Fork the random state */
virtual support::LinearCongruentialEngine::TRandState ForkSeed() = 0;

public:
/******** Lookup/Remove random variables ********/
Expand Down Expand Up @@ -184,6 +185,16 @@ class ScheduleNode : public runtime::Object {

public:
/******** Schedule: Sampling ********/
/*!
* \brief Sample an integer given the probability distribution
* \param candidates The candidates
* \param probs The probability distribution of the candidates
* \param decision The sampling decision
* \return The random variable sampled from candidates
*/
virtual ExprRV SampleCategorical(const Array<Integer>& candidates, const Array<FloatImm>& probs,
Optional<Integer> decision = NullOpt) = 0;

/******** Schedule: Get blocks & loops ********/
/*!
* \brief Retrieve a block in a specific function with its name
Expand Down Expand Up @@ -356,6 +367,7 @@ class Schedule : public runtime::ObjectRef {
/*!
* \brief Construct a concrete TensorIR schedule from an IRModule
* \param mod The IRModule to be scheduled
* \param seed The seed value for schedule's random state
* \param debug_mask Do extra correctness checking after the class creation
* and each time after calling the Replace method.
* \param error_render_level The level of error rendering
Expand All @@ -365,11 +377,12 @@ class Schedule : public runtime::ObjectRef {
* 1) VerifySRefTree
* 2) VerifyCachedFlags
*/
TVM_DLL static Schedule Concrete(IRModule mod, int debug_mask,
ScheduleErrorRenderLevel error_render_level);
TVM_DLL static Schedule Concrete(IRModule mod, support::LinearCongruentialEngine::TRandState seed,
int debug_mask, ScheduleErrorRenderLevel error_render_level);
/*!
* \brief Construct a traced concrete TensorIR schedule from an IRModule
* \param mod The IRModule to be scheduled
* \param seed The seed value for schedule's random state
* \param debug_mask Do extra correctness checking after the class creation
* and each time after calling the Replace method.
* \param error_render_level The level of error rendering
Expand All @@ -379,8 +392,8 @@ class Schedule : public runtime::ObjectRef {
* 1) VerifySRefTree
* 2) VerifyCachedFlags
*/
TVM_DLL static Schedule Traced(IRModule mod, int debug_mask,
ScheduleErrorRenderLevel error_render_level);
TVM_DLL static Schedule Traced(IRModule mod, support::LinearCongruentialEngine::TRandState seed,
int debug_mask, ScheduleErrorRenderLevel error_render_level);
TVM_DEFINE_MUTABLE_OBJECT_REF_METHODS(Schedule, runtime::ObjectRef, ScheduleNode);
};

Expand Down
57 changes: 57 additions & 0 deletions python/tvm/tir/schedule/schedule.py
Original file line number Diff line number Diff line change
Expand Up @@ -79,6 +79,16 @@ def _parse_error_render_level(error_render_level: str) -> int:
return _ERROR_RENDER_LEVEL.get(error_render_level)


def _parse_seed(seed: Optional[int]) -> int:
if seed is None:
return -1
if not isinstance(seed, int):
raise TypeError(f"Expected `seed` to be int or None, but gets: {seed}")
if seed < 1 or seed > 2147483647:
raise ValueError(f"seed must be in the range [1, 2147483647], but gets: {seed}")
return seed


@_register_object("tir.Schedule")
class Schedule(Object):
"""The user-facing schedule class
Expand All @@ -98,6 +108,7 @@ def __init__(
self,
mod: Union[PrimFunc, IRModule],
*,
seed: Optional[int] = None,
debug_mask: Union[str, int] = "none",
error_render_level: str = "detail",
) -> None:
Expand All @@ -107,6 +118,10 @@ def __init__(
----------
mod : Union[PrimFunc, IRModule]
The IRModule or PrimFunc to be scheduled
seed: Optional[int]
The seed value for schedule's random state
Note that None and -1 means use device random, otherwise only integer between 1 and
2147483647 is allowed.
debug_mask : Union[str, int]
Do extra correctness checking after the class creation and each time
after calling the Replace method.
Expand All @@ -130,6 +145,7 @@ def __init__(
self.__init_handle_by_constructor__(
_ffi_api.TracedSchedule, # type: ignore # pylint: disable=no-member
_parse_mod(mod),
_parse_seed(seed),
_parse_debug_mask(debug_mask),
_parse_error_render_level(error_render_level),
)
Expand All @@ -138,12 +154,14 @@ def __init__(
def _create_non_traced(
mod: Union[PrimFunc, IRModule],
*,
seed: Optional[int] = None,
debug_mask: Union[str, int] = "none",
error_render_level: str = "detail",
) -> "Schedule":
"""Construct a non-traced TensorIR schedule class from an IRModule."""
return _ffi_api.ConcreteSchedule( # type: ignore # pylint: disable=no-member
_parse_mod(mod),
_parse_seed(seed),
_parse_debug_mask(debug_mask),
_parse_error_render_level(error_render_level),
)
Expand Down Expand Up @@ -190,6 +208,16 @@ def seed(self, seed: int) -> None:
"""
return _ffi_api.ScheduleSeed(self, seed) # type: ignore # pylint: disable=no-member

def fork_seed(self) -> int:
"""Returns a forked random state as seed for new schedules
Returns
-------
seed : int
The forked random state, not the same as the current random state
"""
return _ffi_api.ScheduleForkSeed(self) # type: ignore # pylint: disable=no-member

def show(self, rand_var: RAND_VAR_TYPE) -> str:
"""Returns a string representation of the value that the random variable evaluates to
Expand Down Expand Up @@ -268,6 +296,35 @@ def remove_rv(self, rand_var: RAND_VAR_TYPE) -> None:

########## Schedule: Sampling ##########

def sample_categorical(
self,
candidates: List[int],
probs: List[float],
decision: Optional[int] = None,
) -> ExprRV:
"""Sample an integer given the probability distribution
Parameters
----------
candidates : List[int]
The candidates to be sampled from
probs : List[float]
The probability of each candidate
decision : Optional[int]
The sampling decision, if any
Returns
-------
result : ExprRV
The random variable sampled from candidates
"""
return _ffi_api.ScheduleSampleCategorical( # type: ignore # pylint: disable=no-member
self,
candidates,
probs,
decision,
)

########## Schedule: Get blocks & loops ##########
def get_block(
self,
Expand Down
68 changes: 68 additions & 0 deletions src/support/array.h
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
*/
#ifndef TVM_SUPPORT_ARRAY_H_
#define TVM_SUPPORT_ARRAY_H_
#include <tvm/ir/expr.h>
#include <tvm/runtime/container/array.h>

#include <vector>
Expand Down Expand Up @@ -67,6 +68,73 @@ inline bool ArrayWithSameContent(const std::vector<T*>& a, const std::vector<T*>
return true;
}

/*!
* \brief Convert a tvm::runtime::Array to std::vector
* \tparam TSrc The type of elements in the source Array
* \tparam TDst The type of elements in the result vector
* \return The result vector
*/
template <class TSrc, class TDst>
std::vector<TDst> AsVector(const Array<TSrc>& vec);

/********** Implementation details of AsVector<TSrc, TDst> **********/
namespace details {

template <class TSrc, class TDst>
struct AsVectorImpl {};

template <class TSrc>
struct AsVectorImpl<TSrc, TSrc> {
inline std::vector<TSrc> operator()(const Array<TSrc>& vec) const {
return std::vector<TSrc>(vec.begin(), vec.end());
}
};

template <class TSrcObjectRef>
struct AsVectorImpl<TSrcObjectRef, int> {
inline std::vector<int> operator()(const Array<TSrcObjectRef>& vec) const {
std::vector<int> results;
for (const TSrcObjectRef& x : vec) {
const auto* n = x.template as<IntImmNode>();
ICHECK(n) << "TypeError: Expects IntImm, but gets: " << x->GetTypeKey();
results.push_back(n->value);
}
return results;
}
};

template <class TSrcObjectRef>
struct AsVectorImpl<TSrcObjectRef, int64_t> {
inline std::vector<int64_t> operator()(const Array<TSrcObjectRef>& vec) const {
std::vector<int64_t> results;
for (const TSrcObjectRef& x : vec) {
const auto* n = x.template as<IntImmNode>();
ICHECK(n) << "TypeError: Expects IntImm, but gets: " << x->GetTypeKey();
results.push_back(n->value);
}
return results;
}
};

template <class TSrcObjectRef>
struct AsVectorImpl<TSrcObjectRef, double> {
inline std::vector<double> operator()(const Array<TSrcObjectRef>& array) const {
std::vector<double> results;
for (const TSrcObjectRef& x : array) {
const auto* n = x.template as<FloatImmNode>();
ICHECK(n) << "TypeError: Expects FloatImm, but gets: " << x->GetTypeKey();
results.push_back(n->value);
}
return results;
}
};
} // namespace details

template <class TSrc, class TDst>
inline std::vector<TDst> AsVector(const Array<TSrc>& vec) {
return details::AsVectorImpl<TSrc, TDst>()(vec);
}

} // namespace support
} // namespace tvm
#endif // TVM_SUPPORT_ARRAY_H_
30 changes: 28 additions & 2 deletions src/tir/schedule/concrete_schedule.cc
Original file line number Diff line number Diff line change
Expand Up @@ -18,16 +18,19 @@
*/
#include "./concrete_schedule.h"

#include <random>

namespace tvm {
namespace tir {

Schedule Schedule::Concrete(IRModule mod, int debug_mask,
ScheduleErrorRenderLevel error_render_level) {
Schedule Schedule::Concrete(IRModule mod, support::LinearCongruentialEngine::TRandState seed,
int debug_mask, ScheduleErrorRenderLevel error_render_level) {
ObjectPtr<ConcreteScheduleNode> n = make_object<ConcreteScheduleNode>();
n->state_ = ScheduleState(mod, debug_mask);
n->error_render_level_ = error_render_level;
n->symbol_table_ = {};
n->analyzer_ = std::make_unique<arith::Analyzer>();
support::LinearCongruentialEngine(&n->rand_state_).Seed(seed);
return Schedule(std::move(n));
}

Expand Down Expand Up @@ -208,6 +211,29 @@ Schedule ConcreteScheduleNode::Copy() const {
}

/******** Schedule: Schedule: Sampling ********/

void ConcreteScheduleNode::Seed(support::LinearCongruentialEngine::TRandState seed) {
if (seed == -1) {
seed = std::random_device()();
}
support::LinearCongruentialEngine(&rand_state_).Seed(seed);
}

support::LinearCongruentialEngine::TRandState ConcreteScheduleNode::ForkSeed() {
// In order for reproducibility, we computer the new seed using RNG's random state and a different
// set of parameters. Note that both 32767 and 1999999973 are prime numbers.
return (support::LinearCongruentialEngine(&rand_state_)() * 32767) % 1999999973;
}

ExprRV ConcreteScheduleNode::SampleCategorical(const Array<Integer>& candidates,
const Array<FloatImm>& probs,
Optional<Integer> decision) {
TVM_TIR_SCHEDULE_BEGIN();
return CreateRV(tir::SampleCategorical(&this->rand_state_, candidates, probs, &decision));
TVM_TIR_SCHEDULE_END("sample-categorical", this->error_render_level_);
throw;
}

/******** Schedule: Get blocks & loops ********/

BlockRV ConcreteScheduleNode::GetBlock(const String& name, const String& func_name) {
Expand Down
Loading

0 comments on commit a20f09d

Please sign in to comment.