[Relay][AutoTVM] Relay op strategy #4644

Merged on Feb 24, 2020 (48 commits)
Commits
42a8989  relay op strategy (icemelon, Oct 17, 2019)
2786bf0  fix bugs (icemelon, Feb 5, 2020)
11769f5  lint (icemelon, Feb 5, 2020)
fa77cb7  address comments (icemelon, Feb 6, 2020)
9084221  add name to op implement (icemelon, Feb 6, 2020)
dfb9b2f  Modify topi tests (#9) (kevinthesun, Feb 6, 2020)
1b94211  fix topi test (icemelon, Feb 7, 2020)
3abc4ee  fix more topi test (icemelon, Feb 8, 2020)
2df6a84  lint (icemelon, Feb 10, 2020)
7908aa4  address comments (icemelon, Feb 11, 2020)
bde152d  x (icemelon, Feb 11, 2020)
ef20785  fix more tests & bugs (icemelon, Feb 11, 2020)
f9a41c6  Modify more tests (#10) (kevinthesun, Feb 11, 2020)
a7768bb  fix more test (icemelon, Feb 12, 2020)
3f33c70  try to update vta using strategy (icemelon, Feb 12, 2020)
8ef29b5  fix cpptest (icemelon, Feb 12, 2020)
220c455  x (icemelon, Feb 12, 2020)
2b29197  fix rebase err (icemelon, Feb 12, 2020)
ea85e73  Fix two tests (#11) (kevinthesun, Feb 12, 2020)
ca702e1  change autotvm log format (icemelon, Feb 12, 2020)
ea9720d  lint (icemelon, Feb 12, 2020)
113a1fa  minor fix (icemelon, Feb 13, 2020)
b180456  try fix vta test (icemelon, Feb 13, 2020)
e3e7e72  fix rebase err (icemelon, Feb 13, 2020)
d872262  tweak (icemelon, Feb 13, 2020)
58789f2  tmp hack for vta pass (icemelon, Feb 13, 2020)
fa6a9d7  fix tutorial (icemelon, Feb 14, 2020)
c1bf725  fix (icemelon, Feb 14, 2020)
206c859  fix more tutorials (icemelon, Feb 14, 2020)
0f36deb  fix vta tutorial (icemelon, Feb 14, 2020)
0bf960b  minor (icemelon, Feb 15, 2020)
dd17aa1  address comments (icemelon, Feb 16, 2020)
e1669e3  fix (icemelon, Feb 16, 2020)
5eee588  address comments (icemelon, Feb 17, 2020)
2f3c719  fix cpptest (icemelon, Feb 17, 2020)
0445e5a  fix docs (icemelon, Feb 17, 2020)
e496ae0  change data structure name and api (icemelon, Feb 17, 2020)
ea3dfaa  address comments (icemelon, Feb 17, 2020)
eb630ef  lint (icemelon, Feb 17, 2020)
6cf45b5  fix rebase err (icemelon, Feb 18, 2020)
8b0081a  updates (icemelon, Feb 19, 2020)
3510177  fix winograd test (icemelon, Feb 19, 2020)
cf43e16  fix doc (icemelon, Feb 19, 2020)
f07d92a  rebase (icemelon, Feb 20, 2020)
59bd399  upgrade tophub version number (icemelon, Feb 23, 2020)
5338a57  fix bug (icemelon, Feb 23, 2020)
24ae797  re-enable vta tsim test after tophub is upgraded (icemelon, Feb 23, 2020)
4f1806a  fix vta test to use the correct args so the config can be found in to… (icemelon, Feb 24, 2020)
31 changes: 20 additions & 11 deletions include/tvm/relay/op_attr_types.h
@@ -29,6 +29,7 @@
#include <tvm/relay/type.h>
#include <tvm/relay/expr.h>
#include <tvm/target/target.h>
#include <tvm/target/generic_func.h>
#include <tvm/tir/data_layout.h>
#include <string>

@@ -105,9 +106,8 @@ using TShapeDataDependant = bool;
*/
using FTVMCompute = runtime::TypedPackedFunc<
Array<te::Tensor>(const Attrs& attrs,
- const Array<te::Tensor>& inputs,
- const Type& out_type,
- const Target& target)>;
+ const Array<te::Tensor>& inputs,
+ const Type& out_type)>;

/*!
* \brief Build the computation schedule for
@@ -120,8 +120,18 @@ using FTVMCompute = runtime::TypedPackedFunc<
*/
using FTVMSchedule = runtime::TypedPackedFunc<
te::Schedule(const Attrs& attrs,
- const Array<te::Tensor>& outs,
- const Target& target)>;
+ const Array<te::Tensor>& outs,
+ const Target& target)>;

/*!
* \brief Generate the strategy of operators. This function is a generic
* function and can be re-defined for different targets.
*
* The function signature of generic function is:
* OpStrategy(const Attrs& attrs, const Array<Tensor>& inputs,
* const Type& out_type, const Target& target)
*/
using FTVMStrategy = GenericFunc;

/*!
* \brief Alternate the layout of operators or replace the
@@ -136,7 +146,8 @@ using FTVMSchedule = runtime::TypedPackedFunc<
using FTVMAlterOpLayout = runtime::TypedPackedFunc<
Expr(const Attrs& attrs,
const Array<Expr>& args,
- const Array<te::Tensor>& tinfos)>;
+ const Array<te::Tensor>& tinfos,
+ const Type& out_type)>;

/*!
* \brief Convert the layout of operators or replace the
@@ -191,9 +202,7 @@ using FForwardRewrite = runtime::TypedPackedFunc<
* \brief Gradient for a specific op.
*
* \param orig_call the original Expr.
- *
* \param output_grad the gradient of the Expr.
- *
* \return the gradient for each parameters.
*/
using FPrimalGradient = runtime::TypedPackedFunc<tvm::Array<Expr>(const Expr& orig_call,
@@ -207,13 +216,13 @@ enum AnyCodegenStrategy {
kVariableDimensions
};

- /* \brief A runtime representation of shape. */
+ /*! \brief A runtime representation of shape. */
using Shape = Array<IndexExpr>;

using FShapeFunc = runtime::TypedPackedFunc<
Array<te::Tensor>(const Attrs& attrs,
- const Array<te::Tensor>& inputs,
- const Array<IndexExpr>& out_ndims)>;
+ const Array<te::Tensor>& inputs,
+ const Array<IndexExpr>& out_ndims)>;

} // namespace relay
} // namespace tvm
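For orientation, here is a minimal sketch of a compute function written against the new three-argument FTVMCompute signature; the Target parameter is gone because target-specific choices now flow through the strategy mechanism introduced below. The include paths and the use of topi::add are assumptions for illustration, not part of this diff.

```cpp
// Sketch only: a compute function matching the new FTVMCompute signature
//   Array<te::Tensor>(const Attrs&, const Array<te::Tensor>&, const Type&)
// Include paths below are assumptions and may differ by TVM version.
#include <tvm/relay/op_attr_types.h>
#include <topi/broadcast.h>  // assumed location of topi::add

namespace tvm {
namespace relay {

// Element-wise add: attrs and out_type are not needed for this simple op,
// and note that no Target is passed any more.
Array<te::Tensor> AddCompute(const Attrs& attrs,
                             const Array<te::Tensor>& inputs,
                             const Type& out_type) {
  return {topi::add(inputs[0], inputs[1])};
}

// The function can be wrapped as the packed-function type used by the op registry.
FTVMCompute add_compute = FTVMCompute(AddCompute);

}  // namespace relay
}  // namespace tvm
```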
164 changes: 164 additions & 0 deletions include/tvm/relay/op_strategy.h
@@ -0,0 +1,164 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/

/*!
* \file tvm/relay/op_strategy.h
* \brief The Relay operator Strategy and related data structure.
*/

#ifndef TVM_RELAY_OP_STRATEGY_H_
#define TVM_RELAY_OP_STRATEGY_H_

#include <tvm/te/tensor.h>
#include <tvm/te/schedule.h>
#include <tvm/relay/expr.h>
#include <tvm/relay/op_attr_types.h>
#include <tvm/target/target.h>
#include <string>

namespace tvm {
namespace relay {

/*!
* \brief Operator implementation that includes compute and schedule function.
*/
class OpImplementationNode : public Object {
public:
/*! \brief Compute function */
FTVMCompute fcompute;
/*! \brief Schedule function */
FTVMSchedule fschedule;
/*! \brief Name of the implementation */
std::string name;
/*! \brief Priority level */
int plevel;

void VisitAttrs(tvm::AttrVisitor* v) {
v->Visit("name", &name);
v->Visit("plevel", &plevel);
}

static constexpr const char* _type_key = "relay.OpImplementation";
TVM_DECLARE_FINAL_OBJECT_INFO(OpImplementationNode, Object);
};

/*!
* \brief Operator implementation class.
*/
class OpImplementation : public ObjectRef {
public:
/*!
* \brief Invoke the operator compute function.
* \param attrs The attribute of the primitive
* \param inputs The input tensors.
* \param out_type The output type information.
* \return The output compute description of the operator.
*/
TVM_DLL Array<te::Tensor> Compute(const Attrs& attrs,
const Array<te::Tensor>& inputs,
const Type& out_type);
/*!
* \brief Build the computation schedule.
* \param attrs The attribute of the node.
* \param outs The output tensors.
* \param target The build target.
* \return The computation schedule.
*/
TVM_DLL te::Schedule Schedule(const Attrs& attrs,
const Array<te::Tensor>& outs,
const Target& target);

TVM_DEFINE_OBJECT_REF_METHODS(OpImplementation, ObjectRef, OpImplementationNode);
};

/*!
* \brief Specialized implementations for operators under certain conditions.
*/
class OpSpecializationNode : public Object {
public:
/*! \brief List of implementations. */
Array<OpImplementation> implementations;
/*! \brief Condition to enable the specialization.
* Could be undefined to represent generic case. */
te::SpecializedCondition condition;

void VisitAttrs(tvm::AttrVisitor* v) {
v->Visit("condition", &condition);
v->Visit("implementations", &implementations);
}

static constexpr const char* _type_key = "relay.OpSpecialization";
TVM_DECLARE_FINAL_OBJECT_INFO(OpSpecializationNode, ExprNode);
};

/*!
* \brief Operator specialization class.
*/
class OpSpecialization : public ObjectRef {
public:
/*!
* \brief Add an implementation.
* \param fcompute Compute function
* \param fschedule Schedule function
* \param name Name of the implementation
* \param plevel Priority level of the implementation
*/
TVM_DLL void AddImplementation(FTVMCompute fcompute, FTVMSchedule fschedule,
std::string name, int plevel);

TVM_DEFINE_MUTABLE_OBJECT_REF_METHODS(OpSpecialization, ObjectRef, OpSpecializationNode);
};

/*!
* \brief Operator strategy to choose implementation.
*/
class OpStrategyNode : public Object {
public:
/*! \brief List of operator specializations. */
Array<OpSpecialization> specializations;

void VisitAttrs(tvm::AttrVisitor* v) {
v->Visit("specializations", &specializations);
}

static constexpr const char* _type_key = "relay.OpStrategy";
TVM_DECLARE_FINAL_OBJECT_INFO(OpStrategyNode, ExprNode);
};

/*!
* \brief Operator strategy class.
*/
class OpStrategy : public ObjectRef {
public:
/*!
* \brief Add an implementation.
* \param fcompute Compute function
* \param fschedule Schedule function
* \param name Name of the implementation
* \param plevel Priority level of the implementation
*/
TVM_DLL void AddImplementation(FTVMCompute fcompute, FTVMSchedule fschedule,
std::string name, int plevel);

TVM_DEFINE_MUTABLE_OBJECT_REF_METHODS(OpStrategy, ObjectRef, OpStrategyNode);
};

} // namespace relay
} // namespace tvm
#endif // TVM_RELAY_OP_STRATEGY_H_
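To show how these pieces fit together, below is a hedged sketch of a strategy function with the FTVMStrategy signature documented in op_attr_types.h, built on OpStrategy::AddImplementation from the header above. The helper bodies and names (GenericCompute, GenericSchedule, "myop.generic") are illustrative assumptions; only the class API and signatures come from this diff.

```cpp
// Sketch only: a per-target strategy function that returns an OpStrategy.
// The compute/schedule bodies are placeholders; the OpStrategy API is the
// one declared in op_strategy.h above. Include paths are assumptions.
#include <tvm/relay/op_strategy.h>
#include <tvm/runtime/memory.h>
#include <topi/elemwise.h>  // assumed location of topi::identity

namespace tvm {
namespace relay {

// Placeholder compute: pass the first input through unchanged.
Array<te::Tensor> GenericCompute(const Attrs& attrs,
                                 const Array<te::Tensor>& inputs,
                                 const Type& out_type) {
  return {topi::identity(inputs[0])};
}

// Placeholder schedule: a default schedule over the output operations.
te::Schedule GenericSchedule(const Attrs& attrs,
                             const Array<te::Tensor>& outs,
                             const Target& target) {
  Array<te::Operation> ops;
  for (const auto& t : outs) ops.push_back(t->op);
  return te::create_schedule(ops);
}

// Matches the FTVMStrategy generic-function signature:
//   OpStrategy(attrs, inputs, out_type, target)
OpStrategy MyOpStrategy(const Attrs& attrs, const Array<te::Tensor>& inputs,
                        const Type& out_type, const Target& target) {
  OpStrategy strategy(runtime::make_object<OpStrategyNode>());
  strategy.AddImplementation(FTVMCompute(GenericCompute),
                             FTVMSchedule(GenericSchedule),
                             "myop.generic", /*plevel=*/10);
  return strategy;
}

}  // namespace relay
}  // namespace tvm
```

A per-target override registered on the FTVMStrategy generic function would follow the same shape, adding higher-plevel implementations for the cases that target handles well.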
49 changes: 49 additions & 0 deletions include/tvm/te/schedule.h
@@ -28,6 +28,7 @@
#include <tvm/tir/expr.h>
#include <tvm/te/tensor.h>
#include <tvm/te/tensor_intrin.h>
#include <tvm/support/with.h>

#include <string>
#include <unordered_map>
@@ -742,6 +743,53 @@ class SingletonNode : public IterVarRelationNode {
TVM_DECLARE_FINAL_OBJECT_INFO(SingletonNode, IterVarRelationNode);
};

/*! \brief Container for specialization conditions. */
class SpecializedConditionNode : public Object {
public:
/*!
* \brief List of conditions in conjunctive normal form (CNF).
* Each condition should be a simple expression, e.g., n > 16, m % 8 == 0, etc.,
* where n, m are tvm::Var that represents a dimension in the tensor shape.
*/
Array<PrimExpr> clauses;

void VisitAttrs(AttrVisitor* v) {
v->Visit("clauses", &clauses);
}

static constexpr const char* _type_key = "SpecializedCondition";
TVM_DECLARE_FINAL_OBJECT_INFO(SpecializedConditionNode, Object);
};

/*!
* \brief Specialized condition to enable op specialization
*/
class SpecializedCondition : public ObjectRef {
public:
/*!
* \brief construct from conditions
* \param conditions The clauses in the specialized condition.
*/
TVM_DLL SpecializedCondition(Array<PrimExpr> conditions); // NOLINT(*)

/*!
* \brief Get the current specialized condition.
* \return the current specialized condition.
*/
TVM_DLL static SpecializedCondition Current();

TVM_DEFINE_OBJECT_REF_METHODS(SpecializedCondition, ObjectRef, SpecializedConditionNode);
class Internal;

private:
// enable with syntax.
friend class Internal;
friend class With<SpecializedCondition>;
/*! \brief Push a new specialized condition onto the thread local stack. */
TVM_DLL void EnterWithScope();
/*! \brief Pop a specialized condition off the thread local context stack. */
TVM_DLL void ExitWithScope();
};

// implementations
inline const StageNode* Stage::operator->() const {
Expand All @@ -765,6 +813,7 @@ inline const IterVarRelationNode* IterVarRelation::operator->() const {
inline const IterVarAttrNode* IterVarAttr::operator->() const {
return static_cast<const IterVarAttrNode*>(get());
}

} // namespace te
} // namespace tvm
#endif // TVM_TE_SCHEDULE_H_
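As a quick illustration of the new scoping API, here is a hedged sketch of entering a specialization scope with the With<> helper that this header befriends. The clause n > 16 and the extra include are assumptions made for the example; SpecializedCondition, Current(), and the With<> integration are what the header above declares.

```cpp
// Sketch only: scoping a specialization condition.
#include <tvm/te/schedule.h>
#include <tvm/tir/op.h>  // assumed location of the PrimExpr comparison operators

void SpecializedConditionExample() {
  using tvm::te::SpecializedCondition;

  tvm::tir::Var n("n");
  // A single CNF clause: this specialization applies when n > 16.
  SpecializedCondition cond({n > 16});

  {
    tvm::With<SpecializedCondition> scope(cond);
    // While this scope is alive, strategy code can query the active condition
    // and attach implementations that are only valid under it.
    SpecializedCondition current = SpecializedCondition::Current();
    // Here current.same_as(cond) holds.
    (void)current;
  }
  // After the scope exits, Current() reverts to the enclosing condition,
  // which may be undefined (the generic case).
}
```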
4 changes: 2 additions & 2 deletions python/tvm/autotvm/__init__.py
@@ -41,8 +41,8 @@
from .measure import measure_option, MeasureInput, MeasureResult, MeasureErrorNo, \
LocalBuilder, LocalRunner, RPCRunner
from .tuner import callback
- from .task import template, get_config, create, ConfigSpace, ConfigEntity, \
- register_topi_compute, register_topi_schedule, \
+ from .task import get_config, create, ConfigSpace, ConfigEntity, \
+ register_topi_compute, register_topi_schedule, register_customized_task, \
DispatchContext, FallbackContext, ApplyHistoryBest as apply_history_best, \
ApplyGraphBest as apply_graph_best
from .env import GLOBAL_SCOPE
5 changes: 4 additions & 1 deletion python/tvm/autotvm/database.py
@@ -125,7 +125,7 @@ def load(self, inp, get_all=False):
current = self.get(measure_str_key(inp))
if current is not None:
records = [decode(x) for x in current.split(RedisDatabase.MAGIC_SPLIT)]
- results = [rec[1] for rec in records]
+ results = [rec[1] for rec in records if rec is not None]
if get_all:
return results
return max(results, key=lambda result: result.timestamp)
@@ -167,9 +167,12 @@ def filter(self, func):
current = self.get(key)
try:
records = [decode(x) for x in current.split(RedisDatabase.MAGIC_SPLIT)]
records = [rec for rec in records if rec is not None]
except TypeError: # got a badly formatted/old format record
continue

if not records:
continue
inps, results = zip(*records)
inp = inps[0]
if not func(inp, results):
5 changes: 4 additions & 1 deletion python/tvm/autotvm/feature.py
@@ -153,7 +153,10 @@ def get_flatten_name(fea):
from .record import decode
# flatten line to feature
line = fea
- inp, _ = decode(line)
+ ret = decode(line)
+ if ret is None:
+     raise ValueError("Unsupported AutoTVM log format")
+ inp, _ = ret
target = _target.create(inp.target)
with target:
s, args = inp.template.instantiate(inp.config)