Skip to content

Commit

Permalink
[Relay] Prepare for new plan_devices.cc (part II) (#9130)
Browse files Browse the repository at this point in the history
* Prepare for new plan_devices.cc (part II)

These changes came from changing #9038 to use
tvm.parser.fromtext instead of manual AST construction.

- Demote FunctionOnDeviceAttrs to just a pair of DictAttrs entries so
  that the parser will understand them on Function definitions.
- Connect some special operators to their attributes so parsing understands them
  at call sites.
- Don't silently ignore attributes during parsing.
- Implement OptFunctionOnDevice so won't add device annotations for kUnknownDeviceType.
- Allow the parser to be given an initial metadata map to support examples which
  need constants.
- More DLOG -> VLOG conversions to reduce debug clutter.

* [checkpoint] Keep existing ParseModule ffi to simplify rust bindings

* [checkpoint] Address Christopher's comments.

* [checkpoint] Andrew's comments from #9038

* [checkpoint] Jared's comments from #9038

* [checkpoint] Woops, forgot rename.
  • Loading branch information
mbs-octoml authored Sep 28, 2021
1 parent 5506472 commit 163322c
Show file tree
Hide file tree
Showing 19 changed files with 211 additions and 141 deletions.
25 changes: 25 additions & 0 deletions include/tvm/ir/attrs.h
Original file line number Diff line number Diff line change
Expand Up @@ -357,6 +357,31 @@ inline TFunc WithAttr(TFunc input, const std::string& attr_key, ObjectRef attr_v
return input;
}

/*!
* \brief Copy the function or module, but overrides the attributes with the entries from \p attrs.
*
* \param input The thing to annotate (BaseFunc or IRModule)
* \param attrs Key/values attributes to add to \p input.
*
* \tparam TFunc The corresponding function or module type.
*
* \returns The new function or module with updated attributes.
*/
template <typename TFunc>
inline TFunc WithAttrs(TFunc input, Map<String, ObjectRef> attrs) {
using TNode = typename TFunc::ContainerType;
static_assert(TNode::_type_final, "Can only operate on the leaf nodes");
TNode* node = input.CopyOnWrite();
if (node->attrs.defined()) {
for (const auto& pair : attrs) {
node->attrs.CopyOnWrite()->dict.Set(pair.first, pair.second);
}
} else {
node->attrs = DictAttrs(std::move(attrs));
}
return input;
}

// Namespace containing detail implementations
namespace detail {
using runtime::TVMArgValue;
Expand Down
21 changes: 21 additions & 0 deletions include/tvm/ir/function.h
Original file line number Diff line number Diff line change
Expand Up @@ -189,6 +189,27 @@ constexpr const char* kTarget = "target";
* Type: String
*/
constexpr const char* kGlobalSymbol = "global_symbol";

/*!
* \brief The device type which will hold each of the functions parameters.
*
* Only supported on Relay \p Functions. Generally added by the \p PlanDevices pass, but
* may be included as an annotation on user programs.
*
* Type: Array<Integer> (but interpreted as Array<DLDeviceType>)
*/
constexpr const char* kParamDeviceTypes = "param_device_types";

/*!
* \brief The device type which will hold the function result.
*
* Only supported on Relay \p Functions. Generally added by the \p PlanDevices pass, but
* may be included as an annotation on user programs.
*
* Type: Integer (but interpreted as DLDeviceType)
*/
constexpr const char* kResultDeviceType = "result_device_type";

} // namespace attr
} // namespace tvm
#endif // TVM_IR_FUNCTION_H_
8 changes: 6 additions & 2 deletions include/tvm/parser/parser.h
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@
* \file parser.h
* \brief A parser for TVM IR.
*/
#include <tvm/ir/module.h>
#include <tvm/runtime/packed_func.h>
#include <tvm/runtime/registry.h>

Expand All @@ -32,8 +33,11 @@
namespace tvm {
namespace parser {

IRModule ParseModule(std::string file_name, std::string file_content,
Optional<IRModule> init_module = Optional<IRModule>());
using MetaTable = Map<String, Array<ObjectRef>>;

IRModule ParseModule(const std::string& file_name, const std::string& file_content,
const Optional<IRModule>& init_module = Optional<IRModule>(),
const MetaTable& init_meta_table = MetaTable());

} // namespace parser
} // namespace tvm
Expand Down
41 changes: 34 additions & 7 deletions include/tvm/relay/attrs/annotation.h
Original file line number Diff line number Diff line change
Expand Up @@ -32,17 +32,44 @@ namespace tvm {
namespace relay {

/*!
* \brief Attributes for the "on_device" operator.
* \brief Attributes for the "on_device" special operator.
*
* The relay call
* The Relay call (aka 'annotation'):
* \code
* on_device(expr, device_type=2)
* on_device(sub_expr, device_type=2)
* \endcode
* denotes that the result of \p expr should be stored on the device with \p DLDeviceType 2
* (i.e. \p kDLCuda). Semantically the operator is the identity function.
* constrains \p sub_expr to execute and store its result on a device with \p DLDeviceType \p 2
* (i.e. a \p kDLCuda device). However the annotation itself may appear in an expression to be
* executed and stored on a different device. If so the compiler will automatically insert a
* "device_copy" call to mediate the transition between devices.
*
* See also FunctionOnDeviceAttrs in include/relay/attrs/function.h for the function-level
* companion.
* E.g.: Assuming %x and %y reside on the GPU and %z on the CPU then:
* \code
* multiply(on_device(add(%x, %y), device_type=2), %z)
* \endcode
* indicates the \p add should execute on the GPU but the \p multiply should execute on the CPU.
* The compiler will rewrite this to:
* \code
* multiply(device_copy(add(%x, %y), src_dev_type=2, dst_dev_type=1), %z)
* \endcode
*
* The Relay call
* \code
* on_device(sub_expr, device_type=2, is_fixed=True)
* \endcode
* is similar to the above, however the annotation itself must appear in an expression on the
* same device. The compiler will check the devices are consistent, and will not insert any
* "device_copy" call. This form of annotation shouldn't be necessary in user programs. However
* it is needed by the \p PlanDevices pass to fully specify the results of device planning so that
* the pass is idempotent.
*
* E.g.: The following program is equivalent to the above:
* \code
* let %a = on_device(add(%x, %y), device_type=2, is_fixed=True)
* multiply(device_copy(%a, src_dev_type=2, dst_dev_type=1), %z)
* \endcode
* The "on_device" annotation with \p is_fixed=True indicates unambiguously that \p %a is stored
* on the GPU.
*/
struct OnDeviceAttrs : public tvm::AttrsNode<OnDeviceAttrs> {
// TODO(mbs): Replace device types with TargetDevice.
Expand Down
66 changes: 0 additions & 66 deletions include/tvm/relay/attrs/function.h

This file was deleted.

6 changes: 4 additions & 2 deletions python/tvm/parser/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,8 +26,10 @@ def add(self, name, content):
return _ffi.get_global_func("SourceMapAdd")(self, name, content)


def parse(source, source_name="from_string"):
return _ffi_api.ParseModule(source_name, source)
def parse(source, source_name="from_string", init_module=None, init_meta_table=None):
if init_meta_table is None:
init_meta_table = {}
return _ffi_api.ParseModuleInContext(source_name, source, init_module, init_meta_table)


def parse_expr(source):
Expand Down
6 changes: 3 additions & 3 deletions src/ir/diagnostic.cc
Original file line number Diff line number Diff line change
Expand Up @@ -242,10 +242,10 @@ void ReportAt(const DiagnosticContext& context, std::ostream& out, const Span& s
}

auto source = (*it).second;
DLOG(INFO) << "Source: " << std::endl << source->source;
VLOG(1) << "Source: " << std::endl << source->source;

DLOG(INFO) << "ReportAt "
<< "span = " << span << " msg = " << diagnostic->message;
VLOG(1) << "ReportAt "
<< "span = " << span << " msg = " << diagnostic->message;

auto line_text = source.GetLine(span->line);

Expand Down
3 changes: 1 addition & 2 deletions src/parser/meta_ref.h
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@
#define TVM_PARSER_META_REF_H_

#include <tvm/ir/attrs.h>
#include <tvm/parser/parser.h>
#include <tvm/relay/expr.h>
#include <tvm/relay/function.h>

Expand All @@ -36,8 +37,6 @@ namespace parser {

using namespace relay;

using MetaTable = Map<String, Array<ObjectRef>>;

/*!
* \brief Options for allocating storage.
*/
Expand Down
52 changes: 37 additions & 15 deletions src/parser/parser.cc
Original file line number Diff line number Diff line change
Expand Up @@ -1092,8 +1092,6 @@ class Parser {

Array<TypeVar> generics;
if (Peek()->token_type == TokenType::kLSquare) {
// If we have generics we need to add a type scope.
PushTypeScope();
generics = ParseSequence<TypeVar>(
TokenType::kLSquare, TokenType::kComma, TokenType::kRSquare, [&]() {
auto type_var_name = Match(TokenType::kIdentifier).ToString();
Expand Down Expand Up @@ -1444,6 +1442,10 @@ class Parser {
ICHECK(attr_obj.defined());
attrs = Downcast<Attrs>(attr_obj);
}
} else {
this->diag_ctx.EmitFatal(Diagnostic::Error(op->span)
<< "unable to determine the 'attrs_type_key' with which "
"to represent the call attributes for this operator");
}
}
return true;
Expand Down Expand Up @@ -1867,7 +1869,7 @@ class Parser {
};

Parser InitParser(const std::string& file_name, const std::string& file_content,
Optional<IRModule> init_module) {
const Optional<IRModule>& init_module, const MetaTable& init_meta_table) {
VLOG(0) << "InitParser: file_name: " << file_name << "file_content_size: " << file_content.size();
SourceName src_name = SourceName::Get(file_name);
Source source(src_name, file_content);
Expand All @@ -1886,42 +1888,62 @@ Parser InitParser(const std::string& file_name, const std::string& file_content,
auto tokens_and_table = Tokenize(diag_ctx, source);

auto tokens = tokens_and_table.first;
auto meta_data_table = tokens_and_table.second;
MetaTable meta_data_table = tokens_and_table.second.ToMetadata();

// Merge any entries in init_meta_table into anything captured in the #[metadata] section
// of the file_content. Metadata references within file_content must use indexes which account
// for this ordering.
for (const auto& pair : init_meta_table) {
Array<ObjectRef> items;
if (meta_data_table.count(pair.first)) {
items = meta_data_table[pair.first];
}
for (const auto& obj : pair.second) {
items.push_back(obj);
}
meta_data_table.Set(pair.first, items);
}

return Parser(module, diag_ctx, source, tokens, DefaultOpTable(), meta_data_table.ToMetadata());
return Parser(module, diag_ctx, source, tokens, DefaultOpTable(), std::move(meta_data_table));
}

IRModule ParseModule(std::string file_name, std::string file_content,
Optional<IRModule> init_module) {
IRModule ParseModule(const std::string& file_name, const std::string& file_content,
const Optional<IRModule>& init_module, const MetaTable& init_meta_table) {
VLOG(0) << "ParseModule";
auto parser = InitParser(file_name, file_content, init_module);
auto parser = InitParser(file_name, file_content, init_module, init_meta_table);
auto mod = parser.ParseModule();
ICHECK(mod.defined()) << "The parser must return a non-null module.";
// NB(@jroesch): it is very important that we render any errors before we procede
// if there were any errors which allow the parser to procede we must render them
// NB(@jroesch): it is very important that we render any errors before we proceed
// if there were any errors which allow the parser to proceed we must render them
// here.
parser.diag_ctx.Render();
auto infer_type = tvm::relay::transform::InferType();
ICHECK(infer_type.defined()) << "The type inferencer must be non-null.";
return infer_type(mod);
}

Expr ParseExpr(std::string file_name, std::string file_content) {
Expr ParseExpr(const std::string& file_name, const std::string& file_content) {
VLOG(0) << "ParseExpr";
auto parser = InitParser(file_name, file_content, Optional<IRModule>());
auto parser = InitParser(file_name, file_content, Optional<IRModule>(), MetaTable());
parser.ParseSemVer(false);
parser.PushScope();
auto expr = parser.ParseExpr();
parser.Match(TokenType::kEndOfFile);
// NB(@jroesch): it is very important that we render any errors before we procede
// if there were any errors which allow the parser to procede we must render them
// NB(@jroesch): it is very important that we render any errors before we proceed
// if there were any errors which allow the parser to proceed we must render them
// here.
parser.diag_ctx.Render();
return expr;
}

TVM_REGISTER_GLOBAL("parser.ParseModuleInContext")
.set_body_typed([](const std::string& file_name, const std::string& file_content,
const Optional<IRModule>& init_module, const MetaTable& init_meta_table) {
return ParseModule(file_name, file_content, init_module, init_meta_table);
});

TVM_REGISTER_GLOBAL("parser.ParseModule")
.set_body_typed([](tvm::String file_name, tvm::String file_content) {
.set_body_typed([](const std::string& file_name, const std::string& file_content) {
return ParseModule(file_name, file_content);
});

Expand Down
6 changes: 3 additions & 3 deletions src/parser/source_map.cc
Original file line number Diff line number Diff line change
Expand Up @@ -60,7 +60,7 @@ Source::Source(SourceName src_name, std::string source) {
}

tvm::String Source::GetLine(int line) {
DLOG(INFO) << "Source::GetLine: line=" << line;
VLOG(1) << "Source::GetLine: line=" << line;
ICHECK(line - 1 < static_cast<int64_t>((*this)->line_map.size()))
<< "requested line: " << line << "at index: " << (line - 1)
<< "line_map size: " << (*this)->line_map.size() << "source: " << (*this)->source;
Expand All @@ -69,10 +69,10 @@ tvm::String Source::GetLine(int line) {
auto range = (*this)->line_map.at(line - 1);
int line_start = range.first;
int line_length = range.second;
DLOG(INFO) << "Source::GetLine: line_start=" << line_start << " line_length=" << line_length;
VLOG(1) << "Source::GetLine: line_start=" << line_start << " line_length=" << line_length;
// TODO(@jroesch): expose substring on tvm::String.
auto line_text = std::string((*this)->source).substr(line_start, line_length);
DLOG(INFO) << "Source::GetLine: line_text=" << line_text;
VLOG(1) << "Source::GetLine: line_text=" << line_text;
return line_text;
}

Expand Down
Loading

0 comments on commit 163322c

Please sign in to comment.