Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[TensorIR][M1b] Schedule class #7847

Merged
merged 1 commit into from
Apr 17, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
227 changes: 227 additions & 0 deletions include/tvm/tir/schedule/schedule.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,227 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
#ifndef TVM_TIR_SCHEDULE_SCHEDULE_H_
#define TVM_TIR_SCHEDULE_SCHEDULE_H_

#include <tvm/tir/schedule/state.h>

namespace tvm {
namespace tir {

/**************** Random variable: BlockRV ****************/

/*! \brief A random variable that evaluates to a TensorIR block */
class BlockRVNode : public runtime::Object {
public:
void VisitAttrs(tvm::AttrVisitor* v) {}
static constexpr const char* _type_key = "tir.BlockRV";
TVM_DECLARE_FINAL_OBJECT_INFO(BlockRVNode, runtime::Object);
};

/*!
* \brief Managed reference to BlockRVNode
* \sa BlockRVNode
*/
class BlockRV : public runtime::ObjectRef {
public:
/*! \brief Constructor */
TVM_DLL BlockRV();
TVM_DEFINE_NOTNULLABLE_OBJECT_REF_METHODS(BlockRV, runtime::ObjectRef, BlockRVNode);
};

/**************** Random variable: LoopRV ****************/

/*! \brief A random variable that evaluates to a TensorIR for loop */
class LoopRVNode : public runtime::Object {
public:
void VisitAttrs(tvm::AttrVisitor* v) {}
static constexpr const char* _type_key = "tir.LoopRV";
TVM_DECLARE_FINAL_OBJECT_INFO(LoopRVNode, runtime::Object);
};

/*!
* \brief Managed reference to LoopRVNode
* \sa LoopRVNode
*/
class LoopRV : public runtime::ObjectRef {
public:
/*! \brief Constructor */
TVM_DLL LoopRV();
TVM_DEFINE_NOTNULLABLE_OBJECT_REF_METHODS(LoopRV, runtime::ObjectRef, LoopRVNode);
};

/**************** Random variable: IntRV ****************/

/*! \brief An integer random variable */
using IntRV = PrimExpr;

using IntRVNode = PrimExprNode;

/**************** The Schedule class ****************/

class Schedule;

/*! \brief The user-facing schedule class */
class ScheduleNode : public runtime::Object {
junrushao marked this conversation as resolved.
Show resolved Hide resolved
friend class Schedule;

public:
virtual ~ScheduleNode() = default;

static constexpr const char* _type_key = "tir.Schedule";
TVM_DECLARE_BASE_OBJECT_INFO(ScheduleNode, runtime::Object);

public:
/*! \brief Get the IRModule associated with this schedule. */
virtual IRModule mod() const { return state()->mod; }
/*! \return The internal state of scheduling */
virtual ScheduleState state() const = 0;
/*!
* \brief Returns a copy of the schedule, including both its state and its symbol table,
* guaranteeing that
* 1) SRef tree is completely reconstructed;
* 2) The IRModule being scheduled is not modified;
* 3) All the random variables are valid in the copy, pointing to the correpsonding sref
junrushao marked this conversation as resolved.
Show resolved Hide resolved
* reconstructed
*/
virtual Schedule Copy() const = 0;
/*!
* \brief Seed the randomness
junrushao marked this conversation as resolved.
Show resolved Hide resolved
* \param seed The new random seed, -1 if use device random, otherwise non-negative
*/
virtual void Seed(int64_t seed = -1) {
LOG(FATAL) << "ValueError: The schedule cannot be seeded because no randomness is allowed";
}

public:
/******** Lookup/Remove random variables ********/
/*!
* \brief Get the block corresponding to the specific BlockRV
* \param block_rv The BlockRV to be looked up
* \return The corresponding block
*/
virtual Block Get(const BlockRV& block_rv) const = 0;
/*!
* \brief Get the for loop corresponding to the specific LoopRV
* \param loop_rv The LoopRV to be looked up
* \return The corresponding for loop
*/
virtual For Get(const LoopRV& loop_rv) const = 0;
/*!
* \brief Get the value corresponding to the specific random variable
* \param int_rv The random variable to be looked up
* \return The corresponding value
*/
virtual int64_t Get(const IntRV& int_rv) const = 0;
/*!
* \brief Get the block sref corresponding to the specific BlockRV
* \param block_rv The BlockRV to be looked up
* \return The corresponding block sref
*/
virtual StmtSRef GetSRef(const BlockRV& block_rv) const = 0;
/*!
* \brief Get the loop sref corresponding to the specific LoopRV
* \param loop_rv The LoopRV to be looked up
* \return The corresponding loop sref
*/
virtual StmtSRef GetSRef(const LoopRV& loop_rv) const = 0;
/*!
* \brief Get the block/loop sref corresponding to the specific statement
* \param stmt The statement to be looked up
* \return The corresponding block/loop sref
*/
virtual StmtSRef GetSRef(const StmtNode* stmt) const;
junrushao marked this conversation as resolved.
Show resolved Hide resolved
/*!
* \brief Get the block/loop sref corresponding to the specific statement
* \param stmt The statement to be looked up
* \return The corresponding block/loop sref
*/
StmtSRef GetSRef(const Stmt& stmt) const { return this->GetSRef(stmt.get()); }
/*!
* \brief Remove a block random variable from the symbol table
* \param block_rv The random variable to be removed
*/
virtual void RemoveRV(const BlockRV& block_rv) = 0;
/*!
* \brief Remove a loop random variable from the symbol table
* \param loop_rv The random variable to be removed
*/
virtual void RemoveRV(const LoopRV& loop_rv) = 0;
/*!
* \brief Remove an integer random variable from the symbol table
* \param int_rv The random variable to be removed
*/
virtual void RemoveRV(const IntRV& int_rv) = 0;

public:
/******** Block/Loop relation ********/
/*!
* \brief Retrieve a block in a specific function with its name
junrushao marked this conversation as resolved.
Show resolved Hide resolved
* \param name The name of the block to be retrieved
* \param func_name The name of the function
* \return The block retrieved
* \note Indexing error is raised if 0 or multiple blocks exist with the specific name
*/
virtual BlockRV GetBlock(const String& name, const String& func_name = "main") = 0;
/*!
* \brief Get the parent loops of the block in its scope, from outer to inner
* \param block_rv The query block
* \return A list of loops above the given block in its scope, from outer to inner
*/
virtual Array<LoopRV> GetLoops(const BlockRV& block_rv) = 0;
};

/*!
* \brief Managed reference to ScheduleNode
*
* A schedule is a set of transformations that change the order of computation but
* preserve the semantics of computation. Some example of schedules:
* 1) Split a loop into two;
* 2) Reorder two loops;
* 3) Inline the computation of a specific buffer into its consumer
*
* The schedule class stores auxiliary information to schedule correctly and efficiently.
*
* Link to tutorial: https://tvm.apache.org/docs/tutorials/language/schedule_primitives.html
*
* \sa ScheduleNode
*/
class Schedule : public runtime::ObjectRef {
junrushao marked this conversation as resolved.
Show resolved Hide resolved
public:
/*!
* \brief Construct a concrete TensorIR schedule from an IRModule
* \param mod The IRModule to be scheduled
* \param debug_mode Do extra correctness checking after the class creation
junrushao marked this conversation as resolved.
Show resolved Hide resolved
* and each time after calling the Replace method.
* \return The concrete schedule created
* \sa ScheduleDebugMask
* \note The checks performed includes:
* 1) VerifySRefTree
* 2) VerifyAffineBinding
* 3) VerifyRegionCover
* 4) VerifyStagePipeline
*/
TVM_DLL static Schedule Concrete(IRModule mod, int debug_mode);
TVM_DEFINE_MUTABLE_OBJECT_REF_METHODS(Schedule, runtime::ObjectRef, ScheduleNode);
};

} // namespace tir
} // namespace tvm

#endif // TVM_TIR_SCHEDULE_SCHEDULE_H_
1 change: 1 addition & 0 deletions python/tvm/error.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,6 +59,7 @@ def __init__(self, msg):
register_error("TypeError", TypeError)
register_error("AttributeError", AttributeError)
register_error("KeyError", KeyError)
register_error("IndexError", IndexError)


@register_error
Expand Down
2 changes: 1 addition & 1 deletion python/tvm/tir/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,7 @@
from .op import comm_reducer, min, max, sum
from .op import q_multiply_shift

from .schedule import StmtSRef, BlockScope, ScheduleState
from .schedule import StmtSRef, BlockScope, ScheduleState, Schedule

from . import schedule
from . import ir_builder
Expand Down
1 change: 1 addition & 0 deletions python/tvm/tir/schedule/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,3 +19,4 @@

from .block_scope import BlockScope, Dependency, DepKind, StmtSRef
from .state import ScheduleDebugMask, ScheduleState
from .schedule import LoopRV, BlockRV, IntRV, RAND_VAR_TYPE, Schedule
Loading