From b6718139df9a82ff77725d5268112c0ba988b39a Mon Sep 17 00:00:00 2001 From: Mark Shields <87091372+mbs-octoml@users.noreply.github.com> Date: Wed, 29 Sep 2021 10:00:31 -0700 Subject: [PATCH] [Relay] Merge analysis/context_analysis.cc and transforms/device_annotation.cc (#9038) * [Relay] Merge analysis/context_analysis.cc and transforms/device_annotation.cc Currently LowerTEPass (backend/te_compiler.cc) is a 'special' pass because it depends on a side-input DeviceMap. We'd like to remove that side-input, and instead recover the Device (and, ultimately, Target) for each (fused) primitive call from the AST alone. By doing so we also avoid needing to perform device planning twice: - It needs to be done before lowering so we know which primitives need to be compiled for which devices. - It then needs to be re-done after lowering and optimization as a prelude to memory planning. By baking the device plan into the AST we can simply do device planning before lowering, and run memory planning later, both as ordinary passes. While working on that issue we realized we currently have 3 'device planners': - transforms/device_annotation.cc, which supports only a small subset of Relay and uses a simple top-down algorithm to assign a device to every sub-expression. - analysis/context_analysis.cc, which makes a galant effort to support most of Relay, is based on unification rather than a top-down algorithm, but handles higher order functions by ad-hoc and fragile inlining. - transforms/annotate_target.cc, which works on Targets instead of Devices, but is otherwise like 'device planning'. We'd like to bring these together. In this PR we introduce a new transforms/device_planner.cc intended to replace transforms/device_annotation.cc and analysis/context_analysis.cc. We don't delete those two just yet since we need to switch all users off of them in the next PR. We also leave transforms/annotate_target.cc alone pending a proper RFC to bring devices and targets together sensibly, but have it firmly in our sights. transforms/device_planner.cc is based on analysis/context_analysis.cc, but is heavily reworked to: 1. Handle starting from existing "on_device" annotations as well as existing "device_copy" calls. 2. Be idempotent, with the idea we'll probably need to re-run it to 'refine' device planning to account for storge scopes. 3. Robustly handle all of Relay, particularly higher-order functions. For that we replace the inlining approach in analysis/context_analysis.cc with a higher-order unification domain. 4. Be a little more systematic with defaulting. 5. Capture the result of the analysis within the AST as new "device_copy" calls at device boundaries, and new/replaced "on_device" calls wherever the device for a sub-expression is not already 'obvious' from the sub-expression's lexical scope. 6. Provide helper visitors for passes which need to ask for the device for any sub-expression they are processing and/or preserve device information on rewrites. Those passes include: - backend/aot_executor_codegen.cc (AOTOnDemandAllocator) - backend/graph_plan_memory.cc (StorageAllocaBaseVisitor etc) - backend/te_compiler.cc (LowerTensorExprMutator) - backend/vm/lambda_lift.cc (LambdaLifter) - transforms/memory_alloc.cc (DialectRewriter) - transforms/to_a_normal_form.cc (Fill) - backend/vm/compiler.cc (VMFunctionCompiler) However we won't change any of those in this PR. See the draft https://github.com/apache/tvm/pull/8788 for the end game. * [checkpoint] Use Relay script for all unit tests. * [checkpoint] Hoist out DeviceDomain and DeviceDomains. * [checkpoint] Hoist out visitors * [checkpoint] Woops, left debug-only code in --- include/tvm/relay/transform.h | 11 + python/tvm/relay/transform/transform.py | 10 + src/relay/op/annotation/annotation.h | 3 - src/relay/transforms/device_aware_visitors.cc | 285 ++++ src/relay/transforms/device_aware_visitors.h | 317 ++++ src/relay/transforms/device_domains.cc | 482 ++++++ src/relay/transforms/device_domains.h | 304 ++++ src/relay/transforms/device_planner.cc | 1123 ++++++++++++++ .../relay/transforms/device_domains_test.cc | 71 + tests/python/relay/test_pass_plan_devices.py | 1320 +++++++++++++++++ 10 files changed, 3923 insertions(+), 3 deletions(-) create mode 100644 src/relay/transforms/device_aware_visitors.cc create mode 100644 src/relay/transforms/device_aware_visitors.h create mode 100644 src/relay/transforms/device_domains.cc create mode 100644 src/relay/transforms/device_domains.h create mode 100644 src/relay/transforms/device_planner.cc create mode 100644 tests/cpp/relay/relay/transforms/device_domains_test.cc create mode 100644 tests/python/relay/test_pass_plan_devices.py diff --git a/include/tvm/relay/transform.h b/include/tvm/relay/transform.h index cdd4c9c1dbd2..e740776d6d4f 100644 --- a/include/tvm/relay/transform.h +++ b/include/tvm/relay/transform.h @@ -444,6 +444,17 @@ TVM_DLL Pass RelayToTIRTargetHook(); */ TVM_DLL Pass ManifestAlloc(Target target_host, Map targets); +/*! + * \brief Uses existing "on_device" and "device_copy" CallNodes to infer the device on which + * every Relay sub-expression should run (and the result stored). Captures the result of that + * analysis using new "on_device" and "device_copy" CallNodes. See + * tvm::relay::transform::{LexicalOnDeviceMixin,DeviceAwareExprVisitor,DeviceAwareExprMutator} + * for help recovering the device for an arbitrary sub-expression in downstream transformations. + * + * \param default_device_type DLDeviceType for default device. + */ +TVM_DLL Pass PlanDevices(DLDeviceType default_device_type); + } // namespace transform /*! diff --git a/python/tvm/relay/transform/transform.py b/python/tvm/relay/transform/transform.py index 7c79464bdd30..bb91afc06195 100644 --- a/python/tvm/relay/transform/transform.py +++ b/python/tvm/relay/transform/transform.py @@ -1167,6 +1167,16 @@ def SimplifyExpr(): return _ffi_api.SimplifyExpr() +def PlanDevices(default_device): + """ + Uses existing "on_device" and "device_copy" CallNodes to infer the device on which + every Relay sub-expression should run (and the result stored). Captures the result of that + analysis using new "on_device" and "device_copy" CallNodes. Note that the device_id of + the default_device is ignored. + """ + return _ffi_api.PlanDevices(default_device) + + def FoldExplicitPadding(): """ FoldExplicitPadding finds explict padding before an op that can support diff --git a/src/relay/op/annotation/annotation.h b/src/relay/op/annotation/annotation.h index 643a82116b5b..35f8b6bf50b6 100644 --- a/src/relay/op/annotation/annotation.h +++ b/src/relay/op/annotation/annotation.h @@ -81,9 +81,6 @@ OnDeviceProps GetOnDeviceProps(const CallNode* call_node); */ OnDeviceProps GetOnDeviceProps(const Expr& expr); -/*! \brief Returns true if \p expr is an on_device CallNode. */ -inline bool IsOnDeviceCall(const Expr& expr) { return GetOnDeviceProps(expr).body.defined(); } - /*! * \brief Returns \p function annotated with "param_device_types" and "result_device_type" * attributes capturing parameter and result devices types respectively. diff --git a/src/relay/transforms/device_aware_visitors.cc b/src/relay/transforms/device_aware_visitors.cc new file mode 100644 index 000000000000..204bce53207b --- /dev/null +++ b/src/relay/transforms/device_aware_visitors.cc @@ -0,0 +1,285 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! + * \file src/relay/transforms/device_aware_visitors.cc + * \brief Visitors which track the device for the current Relay expression and Relay Vars. + */ + +#include "./device_aware_visitors.h" + +namespace tvm { +namespace relay { +namespace transform { + +// TODO(mbs): We'd probably have less tendious code duplication if we redefined the memoizing +// mutator on top of the generic Functor. + +DLDeviceType LexicalOnDeviceMixin::GetInScopeDeviceType(const Expr& expr) const { + auto props = GetOnDeviceProps(expr); + if (props.body.defined() && props.is_fixed) { + // Look through any fixed "on_device" annotations. + return props.device_type; + } + if (expr->IsInstance()) { + // Lookup variable binding. + auto itr = var_device_types_.find(Downcast(expr)); + if (itr == var_device_types_.end()) { + return kInvalidDeviceType; + } else { + return itr->second; + } + } + // Otherwise use the currently in-scope device type. + if (expr_device_types_.empty()) { + return kInvalidDeviceType; + } else { + return expr_device_types_.back(); + } +} + +void LexicalOnDeviceMixin::EnterFunctionBody() { ++function_nesting_; } + +void LexicalOnDeviceMixin::ExitFunctionBody() { + ICHECK_GT(function_nesting_, 0); + --function_nesting_; +} + +void LexicalOnDeviceMixin::PushDeviceType(DLDeviceType device_type) { + if (device_type == kInvalidDeviceType) { + return; + } + expr_device_types_.emplace_back(device_type); +} + +void LexicalOnDeviceMixin::PopDeviceType() { + if (expr_device_types_.empty()) { + return; + } + expr_device_types_.pop_back(); +} + +void LexicalOnDeviceMixin::PushBoundVar(Var var, DLDeviceType device_type) { + if (device_type == kInvalidDeviceType) { + return; + } + ICHECK(var_device_types_.find(var) == var_device_types_.end()); + var_device_types_.emplace(std::move(var), device_type); +} + +void LexicalOnDeviceMixin::PopBoundVar(const Var& var) { + auto itr = var_device_types_.find(var); + if (itr == var_device_types_.end()) { + return; + } + var_device_types_.erase(itr); +} + +void DeviceAwareExprVisitor::VisitExpr_(const FunctionNode* function_node) { + if (function_node->HasNonzeroAttr(attr::kPrimitive)) { + // No tracking inside primitive functions. + DeviceAwareVisitExpr_(function_node); + } else { + // Function parameters come into scope. + for (size_t i = 0; i < function_node->params.size(); ++i) { + PushBoundVar(function_node->params[i], GetFunctionParamDeviceType(function_node, i)); + } + // Entering scope of function body. + PushDeviceType(GetFunctionResultDeviceType(function_node)); + EnterFunctionBody(); + + DeviceAwareVisitExpr_(function_node); + + // Leaving scope of function body. + ExitFunctionBody(); + PopDeviceType(); + // Function parameters go out of scope. + for (size_t i = 0; i < function_node->params.size(); ++i) { + PopBoundVar(function_node->params[i]); + } + } +} + +void DeviceAwareExprVisitor::VisitExpr_(const LetNode* let_node) { + PreVisitLetBlock_(let_node); + std::vector bindings; + Expr expr = GetRef(let_node); + while (const auto* inner_let_node = expr.as()) { + // Let-bound var (in pre visited version) goes into scope. + // (We'll just assume this is a letrec). + PushBoundVar(inner_let_node->var, GetInScopeDeviceType(inner_let_node->value)); + PreVisitLetBinding_(inner_let_node->var, inner_let_node->value); + bindings.emplace_back(inner_let_node); + expr = inner_let_node->body; + } + + VisitExpr(expr); + + for (auto itr = bindings.rbegin(); itr != bindings.rend(); ++itr) { + // Let-bound var goes out of scope. + PopBoundVar((*itr)->var); + PostVisitLet_(*itr); + } + PostVisitLetBlock_(let_node); +} + +void DeviceAwareExprVisitor::VisitExpr_(const CallNode* call_node) { + auto props = GetOnDeviceProps(call_node); + if (props.body.defined() && props.is_fixed) { + // Entering lexical scope of fixed "on_device" call. + PushDeviceType(props.device_type); + VisitExpr(props.body); + // Leaving lexical scope of "on_device" call. + PopDeviceType(); + } else { + DeviceAwareVisitExpr_(call_node); + } +} + +void DeviceAwareExprVisitor::DeviceAwareVisitExpr_(const FunctionNode* function_node) { + ExprVisitor::VisitExpr_(function_node); +} + +void DeviceAwareExprVisitor::DeviceAwareVisitExpr_(const CallNode* call_node) { + ExprVisitor::VisitExpr_(call_node); +} + +void DeviceAwareExprVisitor::PreVisitLetBlock_(const LetNode* let_node) { + // no-op +} + +void DeviceAwareExprVisitor::PreVisitLetBinding_(const Var& var, const Expr& value) { + VisitExpr(var); + VisitExpr(value); +} + +void DeviceAwareExprVisitor::PostVisitLet_(const LetNode* let_node) { + // no-op +} + +void DeviceAwareExprVisitor::PostVisitLetBlock_(const LetNode* let_node) { + // no-op +} + +Expr DeviceAwareExprMutator::VisitExpr_(const FunctionNode* function_node) { + if (function_node->HasNonzeroAttr(attr::kPrimitive)) { + // No tracking inside primitive functions. + return DeviceAwareVisitExpr_(function_node); + } else { + // Function parameters come into scope. + for (size_t i = 0; i < function_node->params.size(); ++i) { + PushBoundVar(function_node->params[i], GetFunctionParamDeviceType(function_node, i)); + } + // Entering scope of function body. + PushDeviceType(GetFunctionResultDeviceType(function_node)); + EnterFunctionBody(); + + Expr result = DeviceAwareVisitExpr_(function_node); + + // Leaving scope of function body. + ExitFunctionBody(); + PopDeviceType(); + // Function parameters go out of scope. + for (size_t i = 0; i < function_node->params.size(); ++i) { + PopBoundVar(function_node->params[i]); + } + + return result; + } +} + +Expr DeviceAwareExprMutator::VisitExpr_(const LetNode* let_node) { + PreVisitLetBlock_(let_node); + std::vector> bindings; + Expr expr = GetRef(let_node); + while (const auto* inner_let_node = expr.as()) { + // Let-bound var (in pre visited version) goes into scope. + // (We'll just assume this is a letrec.) + PushBoundVar(inner_let_node->var, GetInScopeDeviceType(inner_let_node->value)); + std::pair pair = PreVisitLetBinding_(inner_let_node->var, inner_let_node->value); + bindings.emplace_back(pair.first, pair.second, inner_let_node->span, inner_let_node); + expr = inner_let_node->body; + } + + expr = VisitExpr(expr); + + for (auto itr = bindings.rbegin(); itr != bindings.rend(); ++itr) { + // Let-bound var goes out of scope. + const LetNode* pre_let_node = std::get<3>(*itr); + PopBoundVar(pre_let_node->var); + Let post_let = Let(/*var=*/std::get<0>(*itr), /*value=*/std::get<1>(*itr), + /*body=*/expr, /*span=*/std::get<2>(*itr)); + expr = PostVisitLet_(pre_let_node, post_let.get()); + } + return PostVisitLetBlock_(let_node, expr.as()); +} + +Expr DeviceAwareExprMutator::VisitExpr_(const CallNode* call_node) { + auto props = GetOnDeviceProps(call_node); + if (props.body.defined() && props.is_fixed) { + // Entering lexical scope of fixed "on_device" call. + PushDeviceType(props.device_type); + Expr expr = VisitExpr(props.body); + // Leaving lexical scope of "on_device" call. + PopDeviceType(); + return OnDevice(expr, props.device_type, props.is_fixed); + } else { + return DeviceAwareVisitExpr_(call_node); + } +} + +Expr DeviceAwareExprMutator::DeviceAwareVisitExpr_(const FunctionNode* function_node) { + return ExprMutator::VisitExpr_(function_node); +} + +Expr DeviceAwareExprMutator::DeviceAwareVisitExpr_(const CallNode* call_node) { + return ExprMutator::VisitExpr_(call_node); +} + +void DeviceAwareExprMutator::PreVisitLetBlock_(const LetNode* let_node) { /* no-op */ +} + +std::pair DeviceAwareExprMutator::PreVisitLetBinding_(const Var& var, + const Expr& value) { + return std::make_pair(Downcast(VisitExpr(var)), VisitExpr(value)); +} + +Expr DeviceAwareExprMutator::PostVisitLet_(const LetNode* pre_let_node, + const LetNode* post_let_node) { + if (pre_let_node->var == post_let_node->var && pre_let_node->value == post_let_node->value && + pre_let_node->body == post_let_node->body) { + return GetRef(pre_let_node); + } else { + return GetRef(post_let_node); + } +} + +Expr DeviceAwareExprMutator::PostVisitLetBlock_(const LetNode* pre_let_node, + const LetNode* post_let_node) { + if (pre_let_node->var == post_let_node->var && pre_let_node->value == post_let_node->value && + pre_let_node->body == post_let_node->body) { + return GetRef(pre_let_node); + } else { + return GetRef(post_let_node); + } +} + +} // namespace transform +} // namespace relay +} // namespace tvm diff --git a/src/relay/transforms/device_aware_visitors.h b/src/relay/transforms/device_aware_visitors.h new file mode 100644 index 000000000000..8611f87efa06 --- /dev/null +++ b/src/relay/transforms/device_aware_visitors.h @@ -0,0 +1,317 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! + * \file src/relay/transforms/device_aware_visitors.h + * \brief Visitors which track the device for the current Relay expression and Relay Vars. + */ + +#ifndef TVM_RELAY_TRANSFORMS_DEVICE_AWARE_VISITORS_H_ +#define TVM_RELAY_TRANSFORMS_DEVICE_AWARE_VISITORS_H_ + +#include +#include +#include +#include + +#include +#include +#include + +#include "../op/annotation/annotation.h" + +namespace tvm { +namespace relay { +namespace transform { + +/*! + * \brief Helper class for expression transformers which need to keep track of the device + * holding the results of expressions and bound variables. This is recovered from the + * "on_device" function attributes and fixed "on_device" CallNodes added by the PlanDevices + * pass. + * + * \sa \p DeviceAwareExpr{Visitor,Mutator}. + */ +class LexicalOnDeviceMixin { + protected: + /*! + * \brief Returns the device type on which the result of \p expr should/will be stored, assuming + * Push/Pop DeviceType/BoundVar have been correctly called. Returns \p kInvalidDeviceType if + * stack is empty and no bound vars have device types. + */ + DLDeviceType GetInScopeDeviceType(const Expr& expr) const; + + /*! \brief Indicate a function body is being entered. */ + void EnterFunctionBody(); + + /*! \brief Indicate a function body has been processed. */ + void ExitFunctionBody(); + + /*! \brief Push a device type onto the lexical device stack. Ignore if \p kInvalidDeviceType. */ + void PushDeviceType(const DLDeviceType device_type); + + /*! \brief Pop a device type from the lexical device stack. Ignore if stack is empty. */ + void PopDeviceType(); + + /*! \brief Remember that \p var will be stored on \p device_type. Ignore if \p kInvalidDeviceType. + * + * CAUTION: Despite the name we don't support re-entering the same function body. + */ + void PushBoundVar(Var var, DLDeviceType device_type); + + /*! \brief Remove the binding for \p var to it's device type. Ignore if var is not bound. */ + void PopBoundVar(const Var& var); + + /*! + * \brief Returns the number of function definitions wrapping the currently visited expression. + */ + int function_nesting() const { return function_nesting_; } + + private: + /*! + * \brief The number of function bodies entered. Since many transforms need to distinguish global + * functions from local functions this supports the mixin's \p is_global() helper method. + */ + int function_nesting_ = 0; + + /*! + * \brief The stack of lexically enclosing "on_device" devices types, from outermost to innermost. + * When visiting an expression other than a variable we can assume the expression result is + * to be stored on device_type_.back(). + */ + std::vector expr_device_types_; + /*! + * \brief A map from in-scope variable to their device types. We may assume the variable is only + * ever bound to a value stored on this device at runtime. + */ + std::unordered_map + var_device_types_; +}; + +template +class DeviceAwareExprFunctor; + +/*! + * \brief ExprFunctor which tracks devices. We only support 'visitor' style implementation + * with no additional arguments, thus this is equivalent to \p DeviceAwareExprVisitor without + * any memoization. + */ +template <> +class DeviceAwareExprFunctor : public ExprFunctor, + public LexicalOnDeviceMixin { + private: + using TSuper = ExprFunctor; + + public: + void VisitExpr_(const FunctionNode* function_node) { + if (function_node->HasNonzeroAttr(attr::kPrimitive)) { + // No tracking inside primitive functions. + return DeviceAwareVisitExpr_(function_node); + } else { + // Function parameters come into scope. + for (size_t i = 0; i < function_node->params.size(); ++i) { + PushBoundVar(function_node->params[i], GetFunctionParamDeviceType(function_node, i)); + } + // Entering scope of function body. + PushDeviceType(GetFunctionResultDeviceType(function_node)); + EnterFunctionBody(); + + DeviceAwareVisitExpr_(function_node); + + // Leaving scope of function body. + ExitFunctionBody(); + PopDeviceType(); + // Function parameters go out of scope. + for (size_t i = 0; i < function_node->params.size(); ++i) { + PopBoundVar(function_node->params[i]); + } + } + } + + void VisitExpr_(const LetNode* let_node) { + PreVisitLetBlock_(let_node); + std::vector bindings; + Expr expr = GetRef(let_node); + while (const auto* inner_let_node = expr.as()) { + // Let-bound var (in pre visited version) goes into scope. + // (We'll just assume this is a letrec.) + PushBoundVar(inner_let_node->var, GetInScopeDeviceType(inner_let_node->value)); + PreVisitLetBinding_(inner_let_node->var, inner_let_node->value); + bindings.emplace_back(inner_let_node); + expr = inner_let_node->body; + } + + VisitExpr(expr); + + for (auto itr = bindings.rbegin(); itr != bindings.rend(); ++itr) { + // Let-bound var goes out of scope. + const LetNode* visited_let_node = *itr; + PopBoundVar(visited_let_node->var); + PostVisitLet_(visited_let_node); + } + PostVisitLetBlock_(let_node); + } + + void VisitExpr_(const CallNode* call_node) { + auto props = GetOnDeviceProps(call_node); + if (props.body.defined() && props.is_fixed) { + // Entering lexical scope of fixed "on_device" call. + PushDeviceType(props.device_type); + VisitExpr(props.body); + // Leaving lexical scope of "on_device" call. + PopDeviceType(); + } else { + DeviceAwareVisitExpr_(call_node); + } + } + + /*! + * \brief These are as for VisitExpr_. Devices for expressions and function parameters will be + * tracked automatically. Default implementation defers to ExprMutator::VisitExpr_. For + * functions the function_nesting count will already include that of \p function_node. + */ + + virtual void DeviceAwareVisitExpr_(const FunctionNode* function_node) { + return TSuper::VisitExpr_(function_node); + } + + virtual void DeviceAwareVisitExpr_(const CallNode* call_node) { + return TSuper::VisitExpr_(call_node); + } + + /*! + * \brief Visit the first let in a chain of let expressions before any let bindings or final + * body has been visited. Default implementation is a no-op. + */ + virtual void PreVisitLetBlock_(const LetNode* let_node) { /* no-op */ + } + + /*! + * \brief Visit a let-bound expression before the let body has been visited. Devices for the + * let-bound variable will be tracked automatically. Default implementation just visits var and + * value. + */ + virtual void PreVisitLetBinding_(const Var& var, const Expr& value) { + VisitExpr(var); + VisitExpr(value); + } + + /*! + * \brief Visit a let expression after the let-bound value and body have been visited. + * Default implementation is a no-op. + */ + virtual void PostVisitLet_(const LetNode* let_node) { /* no-op */ + } + + /*! + * \brief Visit the first let in a chain of let expressions after it has been visited. + * Default implementation is a no-op. + */ + virtual void PostVisitLetBlock_(const LetNode* let_node) {} +}; + +/*! \brief ExprVisitor which tracks devices. */ +class DeviceAwareExprVisitor : public ExprVisitor, public LexicalOnDeviceMixin { + public: + using ExprVisitor::VisitExpr_; + + void VisitExpr_(const FunctionNode* function_node) final; + void VisitExpr_(const LetNode* let_node) final; + void VisitExpr_(const CallNode* call_node) final; + + /*! + * \brief These are as for VisitExpr_. Devices for expressions and function parameters will be + * tracked automatically. Default implementation defers to ExprMutator::VisitExpr_. For + * functions the function_nesting count will already include that of \p function_node. + */ + virtual void DeviceAwareVisitExpr_(const FunctionNode* function_node); + virtual void DeviceAwareVisitExpr_(const CallNode* call_node); + + /*! + * \brief Visit the first let in a chain of let expressions before any let bindings or final + * body has been visited. Default implementation is a no-op. + */ + virtual void PreVisitLetBlock_(const LetNode* let_node); + + /*! + * \brief Visit a let-bound expression before the let body has been visited. Devices for the + * let-bound variable will be tracked automatically. Default implementation just visits var and + * value. + */ + virtual void PreVisitLetBinding_(const Var& var, const Expr& value); + + /*! + * \brief Visit a let expression after the let-bound value and body have been visited. + * Default implementation is a no-op. + */ + virtual void PostVisitLet_(const LetNode* let_node); + + /*! + * \brief Visit the first let in a chain of let expressions after it has been visited. + * Default implementation is a no-op. + */ + virtual void PostVisitLetBlock_(const LetNode* let_node); +}; + +/*! \brief ExprMutator which tracks devices. */ +class DeviceAwareExprMutator : public ExprMutator, public LexicalOnDeviceMixin { + public: + Expr VisitExpr_(const FunctionNode* function_node) final; + Expr VisitExpr_(const LetNode* let_node) final; + Expr VisitExpr_(const CallNode* call_node) final; + + /*! + * \brief These are as for VisitExpr_. Devices for expressions and function parameters will be + * tracked automatically. Default implementation defers to ExprMutator::VisitExpr_. For + * functions the function_nesting count will already include that of \p function_node. + */ + virtual Expr DeviceAwareVisitExpr_(const FunctionNode* function_node); + virtual Expr DeviceAwareVisitExpr_(const CallNode* call_node); + + /*! + * \brief Visit the first let in a chain of let expressions before any let bindings or final + * body has been visited. Default implementation is a no-op. + */ + virtual void PreVisitLetBlock_(const LetNode* let_node); + + /*! + * \brief Visit a let-bound expression before the let body has been visited. Devices for the + * let-bound variable will be tracked automatically. Default implementation just visits var and + * value. + */ + virtual std::pair PreVisitLetBinding_(const Var& var, const Expr& value); + + /*! + * \brief Visit a let expression after the let-bound value and body have been visited. + * Default implementation just returns a reference to the post-visited node. + */ + virtual Expr PostVisitLet_(const LetNode* pre_let_node, const LetNode* post_let_node); + + /*! + * \brief Visit the first let in a chain of let expressions after it has been visited. + * Default implementation returns reference to let node. + */ + virtual Expr PostVisitLetBlock_(const LetNode* pre_let_node, const LetNode* post_let_node); +}; + +} // namespace transform +} // namespace relay +} // namespace tvm + +#endif // TVM_RELAY_TRANSFORMS_DEVICE_AWARE_VISITORS_H_ diff --git a/src/relay/transforms/device_domains.cc b/src/relay/transforms/device_domains.cc new file mode 100644 index 000000000000..15784856edbf --- /dev/null +++ b/src/relay/transforms/device_domains.cc @@ -0,0 +1,482 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! + * \file src/relay/analysis/device_domains.cc + * \brief Unification domain for the device planner. + */ + +#include "./device_domains.h" + +#include + +#include "../op/annotation/annotation.h" +#include "../op/memory/device_copy.h" + +namespace tvm { +namespace relay { +namespace transform { + +namespace { + +// Ye olde boost hash mixer. +constexpr size_t mix(size_t h1, size_t h2) { + return h1 ^ (h1 + 0x9e3779b9 + (h2 << 6) + (h2 >> 2)); +} + +/*! + * \brief As for GetDeviceCopyProps, but for the call to the lowered TIR primitives rather + * than the original "device_copy" operator. + * + * See te_compiler.cc for where this rewriting occurs. + */ +DeviceCopyProps GetPrimitiveDeviceCopyProps(const CallNode* call_node) { + auto tir_call_attrs = call_node->attrs.as(); + if (tir_call_attrs == nullptr) { + return {}; + } + if (tir_call_attrs->metadata.count("source_device") != 1 || + tir_call_attrs->metadata.count("dst_device") != 1) { + return {}; + } + ICHECK_EQ(call_node->args.size(), 1) << "device_copy is of arity 1"; + return { + call_node->args[0], + static_cast( + Downcast(tir_call_attrs->metadata["source_device"])->value), + static_cast(Downcast(tir_call_attrs->metadata["dst_device"])->value)}; +} + +} // namespace + +// The following hash and equality helpers give each free first-order domain pointer its own +// distinct identity. + +size_t DeviceDomainHash::operator()(const DeviceDomainPtr& domain) const { + if (domain->is_free()) { + // Give each free first-order domain its own identity. + return static_cast(reinterpret_cast(domain.get())); + } else { + size_t h = domain->args_and_result_.size(); + h = mix(h, std::hash()(static_cast(domain->device_type_))); + for (const auto& sub_domain_ptr : domain->args_and_result_) { + h = mix(h, DeviceDomainHash()(sub_domain_ptr)); + } + return h; + } +} + +bool DeviceDomainEqual::operator()(const DeviceDomainPtr& lhs, const DeviceDomainPtr& rhs) const { + if (lhs->args_and_result_.size() != rhs->args_and_result_.size()) { + // Mismatched arities are never equal. + // (Though we'll never ask to do such a comparison explicitly, the hash map + // may do so implicitly due to hash collisions.) + return false; + } + if (lhs->is_free() && rhs->is_free()) { + // Compare first-order free domains by their address. + return lhs.get() == rhs.get(); + } + if (lhs->args_and_result_.empty()) { + // Compare first-order domains by their device type -- free vs bound will compare as false. + return lhs->device_type_ == rhs->device_type_; + } else { + // Compare higher-order domains pointwise. + for (size_t i = 0; i < lhs->args_and_result_.size(); ++i) { + if (!(*this)(lhs->args_and_result_[i], rhs->args_and_result_[i])) { + return false; + } + } + return true; + } +} + +/* static */ +DeviceDomainPtr DeviceDomains::MakeDomain(const Type& type, DLDeviceType device_type) { + if (const auto* func_type_node = type.as()) { + std::vector args_and_result; + args_and_result.reserve(func_type_node->arg_types.size() + 1); + for (const auto& arg_type : func_type_node->arg_types) { + args_and_result.emplace_back(MakeDomain(arg_type, kInvalidDeviceType)); + } + args_and_result.emplace_back(MakeDomain(func_type_node->ret_type, device_type)); + return std::make_shared(std::move(args_and_result)); + } else { + return std::make_shared(device_type); + } +} + +DeviceDomainPtr DeviceDomains::Lookup(DeviceDomainPtr domain) { + DeviceDomainPtr root = domain; + while (true) { + auto itr = domain_to_equiv_.find(root); + if (itr == domain_to_equiv_.end()) { + break; + } + ICHECK_NE(itr->second, root); + root = itr->second; + ICHECK_NOTNULL(root); + } + // Path compression. + while (domain != root) { + auto itr = domain_to_equiv_.find(domain); + ICHECK(itr != domain_to_equiv_.end()); + domain = itr->second; + ICHECK_NOTNULL(domain); + itr->second = root; + } + return root; +} + +DeviceDomainPtr DeviceDomains::Join(const DeviceDomainPtr& lhs, const DeviceDomainPtr& rhs) { + // TODO(mbs): Proper diagnostics. + ICHECK_EQ(lhs->args_and_result_.size(), rhs->args_and_result_.size()) + << "Device domains:" << std::endl + << ToString(lhs) << std::endl + << "and" << std::endl + << ToString(rhs) << std::endl + << "do not have the same kind and can't be unified."; + if (rhs->is_free()) { + return lhs; + } else if (lhs->is_free()) { + return rhs; + } else if (lhs->args_and_result_.empty()) { + // Must have consistent device types for first order domains. + if (lhs->device_type_ != rhs->device_type_) { + // TODO(mbs): Proper diagnostics. + std::ostringstream os; + os << "Inconsistent device types " << lhs->device_type_ << " and " << rhs->device_type_; + throw Error(os.str()); + } + return lhs; + } else { + // Recurse for higher-order. + std::vector args_and_result; + args_and_result.reserve(lhs->args_and_result_.size()); + for (size_t i = 0; i < lhs->args_and_result_.size(); ++i) { + args_and_result.emplace_back(Unify(lhs->args_and_result_[i], rhs->args_and_result_[i])); + } + return MakeDomain(std::move(args_and_result)); + } +} + +DeviceDomainPtr DeviceDomains::Unify(DeviceDomainPtr lhs, DeviceDomainPtr rhs) { + lhs = Lookup(lhs); + rhs = Lookup(rhs); + auto joined_domain = Join(lhs, rhs); + if (!DeviceDomainEqual()(lhs, joined_domain)) { + domain_to_equiv_.emplace(lhs, joined_domain); + } + if (!DeviceDomainEqual()(rhs, joined_domain)) { + domain_to_equiv_.emplace(rhs, joined_domain); + } + return joined_domain; +} + +void DeviceDomains::UnifyCollapsed(const DeviceDomainPtr& lhs, const DeviceDomainPtr& rhs) { + if (!lhs->is_higher_order() && rhs->is_higher_order()) { + Collapse(lhs, rhs); + } else { + Unify(lhs, rhs); + } +} + +DeviceDomainPtr DeviceDomains::DomainFor(const Expr& expr) { + ICHECK(expr.defined()); + auto itr = expr_to_domain_.find(expr.get()); + if (itr != expr_to_domain_.end()) { + return Lookup(itr->second); + } + auto domain = Free(expr->checked_type()); + expr_to_domain_.emplace(expr.get(), domain); + return domain; +} + +DeviceDomainPtr DeviceDomains::DomainForCallee(const Call& call) { + auto itr = call_to_callee_domain_.find(call.get()); + if (itr != call_to_callee_domain_.end()) { + return Lookup(itr->second); + } + std::vector args_and_result; + + auto on_device_props = GetOnDeviceProps(call.get()); + auto device_copy_props = GetDeviceCopyProps(call.get()); + if (!device_copy_props.body.defined()) { + device_copy_props = GetPrimitiveDeviceCopyProps(call.get()); + } + + if (on_device_props.body.defined()) { + // on_device(expr, device_type=, is_fixed=false) + // on_device : fn():?x? + // + // on_device(expr, device_type=, is_fixed=true) + // on_device: fn(): + args_and_result.emplace_back( + ForDeviceType(on_device_props.body->checked_type(), on_device_props.device_type)); + if (on_device_props.is_fixed) { + args_and_result.emplace_back(args_and_result.front()); + } else { + args_and_result.emplace_back(Free(on_device_props.body->checked_type())); + } + } else if (device_copy_props.body.defined()) { + // device_copy(expr, src_dev_type=, dst_dev_type=) + // device_copy: fn(): + args_and_result.emplace_back( + ForDeviceType(device_copy_props.body->checked_type(), device_copy_props.src_dev_type)); + args_and_result.emplace_back( + ForDeviceType(device_copy_props.body->checked_type(), device_copy_props.dst_dev_type)); + } else if (call->op == alloc_storage_op) { + ICHECK_EQ(call->args.size(), 2U); + // alloc_storage(size, alignment, device_type=) + // alloc_storage: fn(, ): + const auto* attrs = call->attrs.as(); + args_and_result.emplace_back(cpu_domain_); + args_and_result.emplace_back(cpu_domain_); + args_and_result.emplace_back( + ForDeviceType(call->checked_type(), static_cast(attrs->device_type))); + } else if (call->op == alloc_tensor_op) { + ICHECK_EQ(call->args.size(), 3U); + // alloc_tensor(storage, offset, shape) + // alloc_tensor: fn(?x?, , ):?x? + auto free_domain = Free(call->checked_type()); + args_and_result.emplace_back(free_domain); + args_and_result.emplace_back(cpu_domain_); + args_and_result.emplace_back(cpu_domain_); + args_and_result.emplace_back(free_domain); + } else if (call->op == shape_func_op) { + ICHECK_EQ(call->args.size(), 3U); + // shape_func(func, inputs, outputs, is_inputs=[...]) + // shape_func: fn(..., , ): + // where ... is a free domain appropriate for func's type + args_and_result.emplace_back(Free(call->args[0]->checked_type())); + // TODO(mbs): I think this should be on the cpu only when is_input = [false], but + // what do we do when we have multiple arguments with different is_input values? + args_and_result.emplace_back(cpu_domain_); + args_and_result.emplace_back(cpu_domain_); + args_and_result.emplace_back(cpu_domain_); + } else if (call->op == shape_of_op) { + ICHECK_EQ(call->args.size(), 1U); + // shape_of(tensor) + // shape_of: fn(?x?): + args_and_result.emplace_back(Free(call->args[0]->checked_type())); + args_and_result.emplace_back(cpu_domain_); + } else if (call->op == invoke_tvm_op) { + ICHECK_EQ(call->args.size(), 3U); + // invoke_tvm_op(op, inputs, outputs) + // invoke_tvm_op: fn(..., ?x?, ?x?):?x? + // where ... is a free domain appropriate for op's type + auto free_domain = Free(call->checked_type()); + args_and_result.emplace_back(Free(call->args[0]->checked_type())); + args_and_result.emplace_back(free_domain); + args_and_result.emplace_back(free_domain); + args_and_result.emplace_back(free_domain); + } else if (call->op == reshape_tensor_op) { + ICHECK_EQ(call->args.size(), 2U); + // reshape_tensor(data, shape) + // reshape_tensor: fn(?x?, ):?x? + auto free_domain = Free(call->checked_type()); + args_and_result.emplace_back(free_domain); + args_and_result.emplace_back(cpu_domain_); + args_and_result.emplace_back(free_domain); + } else if (call->op->IsInstance()) { + // (arg1, ..., argn) + // : fn(?x?, ..., ?x?):?x? + // (all args and result must be first-order). + auto free_domain = Free(arb_); + for (size_t i = 0; i < call->args.size(); ++i) { + args_and_result.emplace_back(free_domain); + } + args_and_result.emplace_back(free_domain); + } else if (call->op->IsInstance()) { + // (arg1, ..., argn) + // : fn(?x1?, ..., ?xn?):?xr? + // where we force all possibly higher-order ?xi? to be collapsed to the first-order ?xr?. + // TODO(mbs): This assumes we've eta-expanded constructors, thus all constructors appear + // in callee positions. + const auto* func_type_node = call->op->checked_type().as(); + ICHECK_NOTNULL(func_type_node); + ICHECK_EQ(func_type_node->arg_types.size(), call->args.size()); + auto result_domain = Free(func_type_node->ret_type); // first-order + for (const auto& arg_type : func_type_node->arg_types) { + auto param_domain = Free(arg_type); // possibly higher-order + UnifyCollapsed(result_domain, param_domain); // collapse if required + args_and_result.emplace_back(param_domain); + } + args_and_result.emplace_back(result_domain); + } else { + // Defer to normal case where op can be an arbitrary expression. + return DomainFor(call->op); + } + auto domain = MakeDomain(std::move(args_and_result)); + call_to_callee_domain_.emplace(call.get(), domain); + return domain; +} + +void DeviceDomains::UnifyExprExact(const Expr& lhs, const Expr& rhs) { + auto lhs_domain = DomainFor(lhs); + auto rhs_domain = DomainFor(rhs); + try { + Unify(lhs_domain, rhs_domain); + } catch (const Error& e) { + // TODO(mbs): Proper diagnostics. + LOG(FATAL) << "Incompatible devices for expressions:" << std::endl + << PrettyPrint(lhs) << std::endl + << "with device:" << std::endl + << ToString(lhs_domain) << "and:" << std::endl + << PrettyPrint(rhs) << std::endl + << "with device:" << std::endl + << ToString(rhs_domain) << std::endl + << e.what(); + } +} + +void DeviceDomains::UnifyExprExact(const Expr& expr, const DeviceDomainPtr& expected_domain) { + auto actual_domain = DomainFor(expr); + try { + Unify(actual_domain, expected_domain); + } catch (const Error& e) { + // TODO(mbs): Proper diagnostics. + LOG(FATAL) << "Incompatible devices for expression:" << std::endl + << PrettyPrint(expr) << std::endl + << "with actual device:" << std::endl + << ToString(actual_domain) << std::endl + << "and expected device:" << std::endl + << ToString(expected_domain) << std::endl + << e.what(); + } +} + +void DeviceDomains::UnifyExprCollapsed(const Expr& expr, const DeviceDomainPtr& expected_domain) { + auto actual_domain = DomainFor(expr); + try { + UnifyCollapsed(actual_domain, expected_domain); + } catch (const Error& e) { + // TODO(mbs): Proper diagnostics. + LOG(FATAL) << "Incompatible devices for expression:" << std::endl + << PrettyPrint(expr) << std::endl + << "with actual device:" << std::endl + << ToString(actual_domain) << std::endl + << "and expected device:" << std::endl + << ToString(expected_domain) << std::endl + << e.what(); + } +} + +bool DeviceDomains::AnyFree(DeviceDomainPtr domain) { + domain = Lookup(domain); + if (domain->is_free()) { + return true; + } + for (const auto& sub_domain : domain->args_and_result_) { + if (AnyFree(sub_domain)) { + return true; + } + } + return false; +} + +void DeviceDomains::Collapse(const DeviceDomainPtr& first_order_domain, + const DeviceDomainPtr& higher_order_domain) { + for (size_t i = 0; i < higher_order_domain->function_arity(); ++i) { + Unify(higher_order_domain->function_param(i), first_order_domain); + } + Unify(higher_order_domain->function_result(), first_order_domain); +} + +void DeviceDomains::SetDefault(DeviceDomainPtr domain, DLDeviceType default_device_type) { + ICHECK_NE(default_device_type, kInvalidDeviceType); + domain = Lookup(domain); + if (domain->is_free()) { + // Will never throw since lhs is free. + Unify(domain, std::make_shared(default_device_type)); + } else if (!domain->args_and_result_.empty()) { + for (const auto& sub_domain : domain->args_and_result_) { + SetDefault(sub_domain, default_device_type); + } + } +} + +void DeviceDomains::SetResultDefaultThenParams(const DeviceDomainPtr& domain, + DLDeviceType default_device_type) { + if (!domain->is_higher_order()) { + SetDefault(domain, default_device_type); + return; + } + DLDeviceType result_device_type = ResultDeviceType(domain); + if (result_device_type == kInvalidDeviceType) { + // If the function result device is still free use the given default. + result_device_type = default_device_type; + } + // Default any remaining free parameters to the function result device. + SetDefault(domain, result_device_type); +} + +std::string DeviceDomains::ToString(DeviceDomainPtr domain) { + domain = Lookup(domain); + std::ostringstream os; + if (domain->is_free()) { + // first-order free + os << "?" << static_cast(reinterpret_cast(domain.get())) << "?"; + } else if (domain->args_and_result_.empty()) { + // first-order bound + os << "<" << domain->device_type_ << ">"; + } else { + // higher-order + os << "fn("; + for (size_t i = 0; i + 1 < domain->args_and_result_.size(); ++i) { + if (i > 0) { + os << ","; + } + os << ToString(domain->args_and_result_[i]); + } + os << "):" << ToString(domain->args_and_result_.back()); + } + return os.str(); +} + +std::string DeviceDomains::ToString() { + std::ostringstream os; + for (const auto& pair : expr_to_domain_) { + os << "expression:" << std::endl + << PrettyPrint(GetRef(pair.first)) << std::endl + << "domain:" << std::endl + << ToString(pair.second) << std::endl + << std::endl; + } + for (const auto& pair : call_to_callee_domain_) { + os << "call:" << std::endl + << PrettyPrint(GetRef(pair.first)) << std::endl + << "callee domain:" << std::endl + << ToString(pair.second) << std::endl + << std::endl; + } + return os.str(); +} + +DeviceDomainPtr DeviceDomains::ResultDomain(DeviceDomainPtr domain) { + domain = Lookup(domain); + while (!domain->args_and_result_.empty()) { + domain = Lookup(domain->args_and_result_.back()); + } + return domain; +} + +} // namespace transform +} // namespace relay +} // namespace tvm diff --git a/src/relay/transforms/device_domains.h b/src/relay/transforms/device_domains.h new file mode 100644 index 000000000000..a29370a0e807 --- /dev/null +++ b/src/relay/transforms/device_domains.h @@ -0,0 +1,304 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! + * \file src/relay/analysis/device_domains.h + * \brief Unification domain for the device planner. + */ + +#ifndef TVM_RELAY_TRANSFORMS_DEVICE_DOMAINS_H_ +#define TVM_RELAY_TRANSFORMS_DEVICE_DOMAINS_H_ + +#include +#include +#include +#include + +#include +#include +#include +#include +#include + +namespace tvm { +namespace relay { +namespace transform { + +class DeviceDomain; +using DeviceDomainPtr = std::shared_ptr; + +/*! + * \brief Represents the domain over which we collect equality constraints. + * + * \code + * D ::= ?x? -- first order, free + * | -- first order, bound + * | fn(D1, ..., Dn):Dr -- higher order + * \endcode + * + * We require a function value to be on the same device as its result. To support that we need + * a notion of the 'result domain' of a domain: + * \code + * result_domain(?x?) = ?x? + * result_domain() = + * result_domain(fn(D1, ..., Dn):Dr) = result_domain(Dr) + * \endcode + */ +class DeviceDomain { + public: + /*! + * \brief Constructs a first-order domain of \p device_type, which may be + * \p kInvalidDeviceType to indicate the domain is free. + */ + explicit DeviceDomain(DLDeviceType device_type) : device_type_(device_type) {} + + /*! + * \brief Constructs a higher-order domain, where \p args_and_result contain the + * function argument and result domains in order. + */ + explicit DeviceDomain(std::vector args_and_result) + : device_type_(kInvalidDeviceType), args_and_result_(std::move(args_and_result)) {} + + /*! \brief Returns true if domain is first-order and free. */ + bool is_free() const { return device_type_ == kInvalidDeviceType && args_and_result_.empty(); } + + /*! \brief Returns true if domain is higher-order. */ + bool is_higher_order() const { return !args_and_result_.empty(); } + + DLDeviceType first_order_device_type() const { + ICHECK(args_and_result_.empty()); + return device_type_; + } + + size_t function_arity() const { + ICHECK(!args_and_result_.empty()); + return args_and_result_.size() - 1UL; + } + + DeviceDomainPtr function_param(size_t i) const { + ICHECK(!args_and_result_.empty()); + ICHECK_LT(i + 1, args_and_result_.size()); + return args_and_result_[i]; + } + + DeviceDomainPtr function_result() const { + ICHECK(!args_and_result_.empty()); + return args_and_result_.back(); + } + + private: + /*! + * \brief If this is a function domain then always kInvalidDevice. Otherwise will be + * kInvalidDevice if the domain is still free, or the specific concrete device if the domain is + * bound. + */ + const DLDeviceType device_type_; + + /*! + * \brief If this is a function domain then the sub-domains for each of the function's + * arguments, and the domain for its result. Otherwise empty. + */ + const std::vector args_and_result_; + + friend struct DeviceDomainHash; + friend struct DeviceDomainEqual; + friend class DeviceDomains; +}; + +// The following hash and equality helpers give each free first-order domain pointer its own +// distinct identity. +struct DeviceDomainHash { + size_t operator()(const DeviceDomainPtr& domain) const; +}; + +struct DeviceDomainEqual { + public: + bool operator()(const DeviceDomainPtr& lhs, const DeviceDomainPtr& rhs) const; +}; + +/*! + * \brief Tracks the device domains for a set of expressions w.r.t. an equivalence relation + * built up by calls to \p Unify. + */ +class DeviceDomains { + public: + DeviceDomains() = default; + + /*! + * \brief Returns a domain appropriate for \p type who's result domain is bound + * to \p device_type. If \p device_type is \p kInvalidDeviceType then the entire domain + * will be free. + */ + static DeviceDomainPtr MakeDomain(const Type& type, DLDeviceType device_type); + + /*! + * \brief Returns a higher-order domain with \p args_and_results. + */ + static DeviceDomainPtr MakeDomain(std::vector arg_and_results) { + return std::make_shared(std::move(arg_and_results)); + } + + /*! \brief Returns a domain with the given result device type appropriate \p device_type. */ + static DeviceDomainPtr ForDeviceType(const Type& type, DLDeviceType device_type) { + ICHECK_NE(device_type, kInvalidDeviceType); + return MakeDomain(type, device_type); + } + + /*! \brief Returns a free domain appropriate for \p type. */ + static DeviceDomainPtr Free(const Type& type) { return MakeDomain(type, kInvalidDeviceType); } + + /*! \brief Returns the domain representing the equivalence class containing \p domain. */ + DeviceDomainPtr Lookup(DeviceDomainPtr domain); + + /*! + * \brief Returns the domain accounting for all bound devices in \p lhs and \p rhs. + * + * Throws \p Error on failure. + */ + DeviceDomainPtr Join(const DeviceDomainPtr& lhs, const DeviceDomainPtr& rhs); + + /*! + * \brief Unifies \p lhs and \p rhs, returning the most-bound of the two. Fails if \p lhs and \p + * rhs disagree on bound device type. + * + * Throws \p Error on failure. + */ + // TODO(mbs): I don't think we need an occurs check since the program is well-typed, but + // given we have refs to functions I'm prepared to be surprised. + DeviceDomainPtr Unify(DeviceDomainPtr lhs, DeviceDomainPtr rhs); + + /*! + * \brief Unifies \p lhs and \p rhs. If \p lhs is first-order and \p rhs is higher-order, + * require all arguments and result of \p rhs to unify with \p lhs. Otherwise same as + * \p Unify. + * + * Throws \p Error on failure. + */ + void UnifyCollapsed(const DeviceDomainPtr& lhs, const DeviceDomainPtr& rhs); + + /*! \brief Returns true if a domain is known for \p expr. */ + bool contains(const Expr& expr) const { return expr_to_domain_.count(expr.get()); } + + /*! \brief Returns the domain representing \p expr. */ + DeviceDomainPtr DomainFor(const Expr& expr); + + /*! + * \brief Returns the domain representing the callee (ie 'op') in \p call expression. If the + * callee is a primitive or special operation we handle it specially. Otherwise defers to \p + * DomainFor(call->op). + * + * This special handling is needed: + * - To handle the "on_device" and "device_copy" ops which constrain devices to the given devices. + * - To handle some special ops which constrain devices to the CPU. + * - To allow the same primitive to be called on different devices at different call sites. + * Since each call to the op can have a different domain we index the ops by the call expression + * rather than the op itself. + */ + DeviceDomainPtr DomainForCallee(const Call& call); + + /*! \brief Unifies the domains for expressions \p lhs and \p rhs. */ + void UnifyExprExact(const Expr& lhs, const Expr& rhs); + + /*! + * \brief Unifies the domain for \p expr with \p expected_domain. + */ + void UnifyExprExact(const Expr& expr, const DeviceDomainPtr& expected_domain); + + /*! + * \brief Unifies the domain for \p expr with \p expected_domain. + * If \p expected_domain is higher-order but \p expr is first-order, require all arguments + * and the result of \p expected_domain to have the same domain as for \p expr. + */ + void UnifyExprCollapsed(const Expr& expr, const DeviceDomainPtr& expected_domain); + + /*! \brief Returns true if \p domain contains any free sub-domains. */ + bool AnyFree(DeviceDomainPtr domain); + + /* + * \brief Force all domains in \p higher_order_domain to unify with \p first_order_domain. + * This can be used to handle functions within tuples, references and ADTs since we don't + * attempt to track anything beyond 'the device' for expressions of those first-order types. + * + * Throws \p Error on failure. + */ + void Collapse(const DeviceDomainPtr& first_order_domain, + const DeviceDomainPtr& higher_order_domain); + + /*! \brief Force all free domains in \p domain to default to \p default_device_type. */ + void SetDefault(DeviceDomainPtr domain, DLDeviceType default_device_type); + + /*! + * \brief If \p domain is higher-order and its result domain is free, force it to + * \p default_device_type. Then force any remaining free domains to the result domain + * (freshly defaulted or original). If \p domain is first-order same as \p SetDefault. + */ + void SetResultDefaultThenParams(const DeviceDomainPtr& domain, DLDeviceType default_device_type); + + /*! \brief Returns one-line description of \p domain for debugging. */ + std::string ToString(DeviceDomainPtr domain); + + /*! \brief Returns description of entire system of constraints for debugging */ + std::string ToString(); + + /*! + * \brief Returns the result domain for \p domain (see defn in DeviceDomain comment). + */ + DeviceDomainPtr ResultDomain(DeviceDomainPtr domain); + + /*! + * \brief Returns the result (possibly free) device type for \p domain (see defn in DeviceDomain + * comment). + */ + DLDeviceType ResultDeviceType(const DeviceDomainPtr& domain) { + return ResultDomain(domain)->first_order_device_type(); + } + + private: + /*! \brief Intrinsics we need to handle specially. */ + const Op& alloc_storage_op = Op::Get("memory.alloc_storage"); + const Op& alloc_tensor_op = Op::Get("memory.alloc_tensor"); + const Op& shape_of_op = Op::Get("vm.shape_of"); + const Op& invoke_tvm_op = Op::Get("vm.invoke_tvm_op"); + const Op& shape_func_op = Op::Get("vm.shape_func"); + const Op& reshape_tensor_op = Op::Get("vm.reshape_tensor"); + /*! \brief The CPU device type for special operators such as dynamic shape functions. */ + const DLDeviceType cpu_device_type_ = kDLCPU; + /*! \brief Placeholder for any first-order type. */ + Type arb_ = TupleType(); + /*! \brief The domain for first-order expressions on the CPU. */ + DeviceDomainPtr cpu_domain_ = ForDeviceType(arb_, cpu_device_type_); + + /*! \brief Maps expressions to their domains as determined during analysis. */ + std::unordered_map expr_to_domain_; + + /*! + * \brief Maps call expressions to the domains for their callee where the callee is a primitive. + */ + std::unordered_map call_to_callee_domain_; + + /*! \brief Maps device domains to their equivalent domains as determined during unification. */ + std::unordered_map + domain_to_equiv_; +}; + +} // namespace transform +} // namespace relay +} // namespace tvm + +#endif // TVM_RELAY_TRANSFORMS_DEVICE_DOMAINS_H_ diff --git a/src/relay/transforms/device_planner.cc b/src/relay/transforms/device_planner.cc new file mode 100644 index 000000000000..35bf406263e4 --- /dev/null +++ b/src/relay/transforms/device_planner.cc @@ -0,0 +1,1123 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! + * \file src/relay/analysis/device_planner.cc + * \brief Determines a unique device to hold the result of every Relay sub-expression. + * + * We say a Relay expression E is 'on device D' if the result of executing E is stored on D. + * Currently we only track the 'device_type' of D and not its 'device id'. We do not track the + * specific target associated with D (this is recovered independently via a TargetMap), and we + * do not track the storage scope within D (this is yet to be implemented). + * + * Note that 'stored on device D' is almost but not quite the same as 'executes on device D', + * see below. + * + * This pass assumes the module already contains some "on_device" and/or "device_copy" CallNodes: + * - "device_copy" CallNodes (with a \p DeviceCopyAttrs attribute) specify a 'src_dev_type' and + * 'dst_dev_type' device type, which constrain the argument and context of the call + * respectively. It is ok if source and destination devices are the same, such no-op copies + * will be removed after accounting for the device preference. + * - "on_device" CallNodes (with a \p OnDeviceAttrs attribute) specify a 'device_type', which + * constrains the argument of the call, but (usually, see below) leaves the context + * unconstrained. These are called 'annotations' in the rest of the code, have no operational + * significance by themselves, but may trigger the insertion of a new "device_copy". + * - In two situations the result of an "on_device" CallNode may also be constrained to the + * given device: + * - The "on_device" call occurs at the top-level of a function body, or occurs as an + * immediately let-bound expression. In this situation the extra degree of freedom in + * the function result and let-binding leads to surprising device copies, so we simply + * force the function result or let-bound variable to the given device. + * - The \p OnDeviceAttrs has an \p is_fixed field of \p true, which indicates we inserted + * it ourselves during an earlier invocation of this pass. This helps make this pass + * idempotent. + * + * We proceed in four phases: + * + * Phase 0 + * ------- + * We rewrite the programs to handle some special cases: + * - "on_device" calls at the top-level of function or immediately let-bound are rewritten + * to have \code is_fixed=true \endcode. + * - We wish to treat \code on_device(expr, device_type=d).0 \endcode as if it were written + * \code on_device(expr.0, device_type_d) \endcode. I.e. we prefer to copy the projection from + * the tuple rather than project from a copy of the tuple. We'll do this by rewriting. + * + * Phase 1 + * ------- + * We flow constraints from the "on_device" and "device_copy" calls (and some special ops, see + * below) to all other Relay sub-expressions. (For idempotence we also respect any existing + * "param_device_types" and "result_device_type" function attributes we introduce below.) + * + * For a primitive such as \code add(e1, e2) \endcode all arguments and results must be on the + * same device. However each call site can use a different device. In other words primitives are + * 'device polymorphic' since we compile and execute them for each required device. + * + * For most Relay expressions the device for the overall expression is the same as the device + * for it's sub-expressions. E.g. each field of a tuple must be on the same device as the tuple + * itself, the condition and arms of an \p if must all be on the same device as the overall if, + * and so on. + * + * Some special ops (or 'dialects') are handled: + * - Relay supports computing the shape of tensors and operators at runtime using "shape_of", + * "shape_func", and "reshape_tensor". Shapes must only be held on the CPU, but the tensors + * they describe may reside on any device. + * - Explicit memory allocation is done using the "alloc_storage" and "alloc_tensor". Again + * shapes reside on the CPU, but the allocated tensors may reside on any device. + * + * Two Relay expression have special handling: + * - For \code let x = e1; e2 \endcode the result of \p e2 must be on the same device as the + * overall let. However the result of \p e1 may be on a different device. + * - For a function \code fn(x, y) { body } \endcode the result of the function must be on the + * same device as \p body. However parameters \p x and \p may be on different devices, even + * different from each other. Every call to the function must use the same choice of parameter + * and result devices -- there is no 'device polymorphism' for Relay functions. + * + * Phase 2 + * ------- + * After flowing constraints we apply some defaulting heuristics (using a global default device) + * to fix the device for any as-yet unconstrained sub-expressions. + * - Unconstrained function result devices default to the global default device. + * - Unconstrained function parameters devices default to the device for the function result. + * - Unconstrained let-bound expression devices default to the device for the overall let. + * TODO(mbs): I may have over-innovated here and we simply want to bind all free domaints to + * the global default device. Worth a design doc with motivating examples I think. + * + * Phase 3 + * ------- + * Finally, the result of this analysis is reified into the result as: + * - Additional "param_device_types" (an Array) and "result_device_type" (Integer) + * attributes for every function (both top-level and local). These describe the devices for + * the function's parameters and the result. + * - Additional "device_copy" CallNodes where a copy is required in order to respect the + * intent of the original "on_device" CallNodes. + * - Additional "on_device" CallNodes where the device type of an expression does not match + * that of the lexically enclosing "on_device" CallNode or function attribute. In practice + * this means "on_device" CallNodes may appear in two places: + * - On a let-bound expression if its device differs from the overall let expression. + * - On a call argument if its device differs from the call result. In particular, the + * argument to a "device_copy" call will always be wrapped in an "on_device". (That may + * seem pedantic but simplifies downstream handling.) + * However since we make it easy to track devices for variables we never wrap an "on_device" + * around a var or global var. These uses of "on_device" imply both the argument and result are + * on the same device. We signal this by setting the 'is_fixed' OnDeviceAttrs field to true, + * which helps make this pass idempotent. + * + * Helper visitors (in device_aware_visitors.h) can be used by downstream transforms to recover + * the device for any expression for their own use, e.g. during memory planning. All downstream + * passes must preserve the lexical scoping of the "on_device" CallNodes. E.g. conversion + * to ANF must respect the lexical scoping convention: + * \code + * f(on_device(g(h(a, b), c), device_type=CPU)) + * ==> + * let %x0 = on_device(h(a, b), device_type=CPU) + * let %x1 = on_device(g(%x0), device-type=CPU) + * f(on_device(%x1, device_type=CPU)) + * \endcode + * + * This pass can be run before FuseOps it can use device-specific fusion rules. + * + * 'Stored on' vs 'Executes on' + * ---------------------------- + * Obviously for a primitive call \code add(x, y) \endcode we can execute the primitive on the + * same device as will hold its result. Thus 'executes on' is the same as 'stored on' for + * primitives. + * + * But what about for arbitrary Relay expressions? Most backends (interpreter, graph, VM) are + * implicitly executed on the 'host' CPU, with only primitive evaluation handed off to specific + * devices, thus the notion of 'executes on' is mute. AOT backends on the other hand need to + * know exactly which device (possibly one of a number of available 'CPU'-like devices) is + * responsible for execution. Currently that's handled independently by the \p AnnotateTargets + * pass, but we'd like to fold that into device planning here to ensure everything is consistent. + * + * Obviously since tensors are passed-by-pointer it's quite possible to execute a Relay + * expression (eg an if expression) on one device even though the tensor data resides on + * another. But for AOT that flexibility seems excessive. So we'd like to just take 'executes on' + * to be 'stored on' exactly. In particular, for a Relay function, we'd like to be able to just + * compile the function body for the function's result device. + * + * This works after conversion to ANF provided the compilation for a let expression is prepared + * to make a cross-device call. However we leave it to a downstream transformation to heuristically + * minimize cross-device calls by moving device copies out of functions. E.g.: + * \code + * def @f() { // execute on CPU + * let x = on_device(...GPU computation..., device_type=GPU); + * device_copy(...GPU computation..., src_dev_type=GPU, dst_dev_type=CPU) + * } + * def @main() { + * ... call @f() on CPU ... + * } + * \endcode + * could be rewritten to: + * \code + * def @f() { // execute on GPU + * let x = ...GPU computation...; + * ...GPU computation... + * } + * def @main() { + * let x = device_copy(@f(), src_dev_type=GPU, dst_dev_type=CPU) + * ... use x on CPU ... + * } + * \endcode + * + * Higher-order shenanigans + * ------------------------ + * Relay is a 'mostly' higher-order language -- we can let-bind functions, pass functions + * as arguments (even anonymous functions), return functions, evaluate conditional expressions + * over functions, and so on. We handle this during constraint solving using the domain: + * \code + * D ::= -- first-order + * | fn(D,...,D):D -- higher-order + * \endcode + * In this way we can determine the device for all function parameters and results. E.g. for + * \code + * let f = fn(x, y) { ... } + * let g = fn(f, z) { f(z, z) } + * g(f, on_device(..., device_type=CPU)) + * \endcode + * the parameters \p x and \p y will be on the CPU. + * + * But now look closely at the call \code e1(e2, e3) \endcode. We know \p e1 must evaluate to a + * function. Our analysis must guarantee that the function's parameters and result devices are + * consistent for \p e2, \p e3, and the context of the call. But: + * - Which device holds the closure result of evaluating \p e1 ? + * - If \p e2 is of function type, what does that mean when we say every function parameter + * is on a device? + * - If \p e1 returns a function, what does that mean when we say every function result is + * on a device? + * + * Since higher-order aspects are later compiled away (by 'defunctionalization' + * aka 'firstification') we'd prefer not to have to answer any of those questions. In particular, + * we really don't want our domain \p D to allow for yet another device for the function closure. + * So we'll just force the 'device for a function' to be the same as the device for the function's + * result using the notion of the 'result domain' for a domain: + * \code + * result_domain() = + * result_domain(fn(D1,...,Dn):Dr) = result_domain(Dr) + * \endcode + * + * Similarly the domain does not have entries for tuples, references, or ADTs. Whenever the + * analysis encounters a function inside one of those it simply forces all argument and result + * devices for the function to match the device for the first-order expression. For example, + * if the tuple \code (fn(x, y) { ... }, 3) \endcode is on the GPU then the inner function + * parameters and result must similarly be on the GPU. + * + * ------- + * | AOR | This pass supports all of Relay. + * ------- + * ^ + * | + * `-- Mark's stamp of completeness :-) + * + * TODO(mbs): + * * Though on_device is the identity for all types we can't wrap it around functions/constructors + * taking type args (or at least not without changing type_infer.cc to see through them). + * This is not currently handled generally. + * * Proper diagnostics for unification failure using spans. + * * Make sure the pass is idempotent even after FuseOps etc. + * * Support application of constructors properly. Are they device polymorphic? + * * Replace DLDeviceType with TargetDevice, and unify 'target annotation' with 'device planning'. + * * Support running the pass post FuseOps (so need to understand primitive functions, both + * outlines and lined) and post the VM transforms (probably need to support more intrinsic + * forms?). + * * Don't hardcode the 'CPU' device for shape funcs etc, and distinguish between the default + * device for primitives vs the default device for the rest of Relay. + * * We'll probably need some support for partial 'device polymorphism' for functions once we + * incorporate targets and memory scopes into the domain. For example it's ok for the function + * body to be executed on different device ids provided they have the same target and memory + * scope. + * * Might be simpler to just let every type have a device annotation rather than work in + * a separate domain? + * * Switch to expr.CopyWith(...) form once implemented to avoid unnecessary copies. + * * The original device_annotation.cc RewriteAnnotatedOps removed all "on_device" calls + * in tuples at the top level of function bodies or main expression, irrespective of the + * "on_device" body. What's up with that? + */ + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#include + +#include "../op/annotation/annotation.h" +#include "../op/memory/device_copy.h" +#include "./device_domains.h" + +namespace tvm { +namespace relay { +namespace transform { + +namespace { + +/****** +******* Phase 0 +*******/ + +/*! + * \brief Rewrites "on_device" calls to handle some special cases. + * + * \code + * let %x = on_device(e, device_type=d) + * ==> let %x = on_device(e, device_type=d, is_fixed=True) + * + * fn(%x) { on_device(e, device_type=d) } + * ==> fn(%x) { on_device(e, device_type=d, is_fixed=True) + * + * on_device(e).0 + * ==> on_device(e.0) + * \endcode + */ +class RewriteOnDevices : public ExprMutator { + public: + RewriteOnDevices() = default; + + private: + Expr VisitExpr_(const TupleGetItemNode* tuple_get_item_node) final { + Expr tuple = VisitExpr(tuple_get_item_node->tuple); + // TODO(mbs): Avoid copy. + Expr tuple_get_item = + TupleGetItem(tuple, tuple_get_item_node->index, tuple_get_item_node->span); + auto props = GetOnDeviceProps(tuple); + if (props.body.defined() && !props.is_fixed) { + VLOG(1) << "wrapping tuple get item:" << std::endl + << PrettyPrint(GetRef(tuple_get_item_node)) << std::endl + << "with \"on_device\" for device " << props.device_type; + return OnDevice(tuple_get_item, props.device_type, /*is_fixed=*/false); + } else { + return tuple_get_item; + } + } + + Expr VisitExpr_(const LetNode* let_node) final { + auto expr = GetRef(let_node); + std::vector> bindings; + while (const auto* inner_let_node = expr.as()) { + Expr inner_let = GetRef(inner_let_node); + Expr value = VisitExpr(inner_let_node->value); + auto props = GetOnDeviceProps(value); + if (props.body.defined() && !props.is_fixed) { + VLOG(1) << "revising let-bound expression of let:" << std::endl + << PrettyPrint(expr) << std::endl + << "to be fixed to device " << props.device_type; + value = OnDevice(props.body, props.device_type, /*is_fixed=*/true); + } + bindings.emplace_back(inner_let_node->var, value, inner_let_node->span); + expr = inner_let_node->body; + } + expr = VisitExpr(expr); + // TODO(mbs): Avoid copy. + for (auto itr = bindings.rbegin(); itr != bindings.rend(); ++itr) { + expr = Let(/*var=*/std::get<0>(*itr), /*value=*/std::get<1>(*itr), expr, + /*span=*/std::get<2>(*itr)); + } + return expr; + } + + Expr VisitExpr_(const FunctionNode* function_node) final { + Expr body = VisitExpr(function_node->body); + auto props = GetOnDeviceProps(body); + if (props.body.defined() && !props.is_fixed) { + VLOG(1) << "revising body of function:" << std::endl + << PrettyPrint(GetRef(function_node)) << std::endl + << "to be fixed to device " << props.device_type; + body = OnDevice(props.body, props.device_type, /*is_fixed=*/true); + } + // TODO(mbs): Avoid copy + return Function(function_node->params, body, function_node->ret_type, + function_node->type_params, function_node->attrs, function_node->span); + } +}; + +/****** +******* Phase 1 +*******/ + +/* + * \brief Collects the system of device constraints for all sub-expressions in a module. + * It is possible some devices remain free and will need to be defaulted by \p DeviceDefaulter. + * + * Eg from \code add(%x, %y) \endcode we know \p %x and \p %y must be on the same device. Later, + * from \code on_device(%x, device_type=d) \endcode we know \p %x must be on device \p d, and thus + * so must \p %y. + * + * Constraints can flow in interesting ways. E.g. in: + * \code + * let %f = fn(%x, %y) { add(%x, on_device(%y, device_type=d)) } + * let %g = fn(%f, %x, %y) { %f(%x, %y) } + * %g(%f, %a, %b) + * \endcode + * we discover \p %b must be on device \p d. + */ +class DeviceAnalyzer : public ExprVisitor { + public: + explicit DeviceAnalyzer(IRModule mod) + : mod_(std::move(mod)), domains_(std::make_unique()) {} + + /*! + * \brief Returns the expression-to-device-domain map for all expressions in all the global + * function definitions in the module. Expressions may have free domains, these will be resolved + * by \p DeviceDefaulter below. + */ + std::unique_ptr Analyze() { + VLOG_CONTEXT << "DeviceAnalyzer"; + for (const auto& pair : mod_->functions) { + VLOG(1) << "collecting constraints for '" << PrettyPrint(pair.first) << "'"; + domains_->UnifyExprExact(pair.first, pair.second); + VisitExpr(pair.second); + } + return std::move(domains_); + } + + private: + void VisitExpr_(const CallNode* call_node) final { + auto call = GetRef(call_node); + + // Find the higher-order domain for the callee. See DomainForCallee for the special rules + // for primitives. + VisitExpr(call_node->op); + auto func_domain = domains_->DomainForCallee(call); // higher-order + + // Build the domain for the function implied by its arguments and call context. + ICHECK_EQ(func_domain->function_arity(), call_node->args.size()); + std::vector args_and_result_domains; + args_and_result_domains.reserve(call_node->args.size() + 1); + for (const auto& arg : call_node->args) { + args_and_result_domains.emplace_back(domains_->DomainFor(arg)); + VisitExpr(arg); + } + args_and_result_domains.emplace_back(domains_->DomainFor(call)); + auto implied_domain = + DeviceDomains::MakeDomain(std::move(args_and_result_domains)); // higher-order + + VLOG(1) << "initial call function domain:" << std::endl + << domains_->ToString(func_domain) << std::endl + << "and implied domain:" << std::endl + << domains_->ToString(implied_domain) << std::endl + << "for call:" << std::endl + << PrettyPrint(call); + + // The above must match. + try { + domains_->Unify(func_domain, implied_domain); // higher-order + } catch (const Error& e) { + // TODO(mbs): Proper diagnostics. + LOG(FATAL) << "Function parameters and result devices do not match those of call. Call:" + << std::endl + << PrettyPrint(call) << std::endl + << "with function devices:" << std::endl + << domains_->ToString(func_domain) << std::endl + << "and implied call devices:" << std::endl + << domains_->ToString(implied_domain) << std::endl + << e.what(); + } + + VLOG(1) << "final call function domain:" << std::endl + << domains_->ToString(func_domain) << std::endl + << "for call:" << std::endl + << PrettyPrint(call); + } + + void VisitExpr_(const LetNode* let_node) final { + Expr expr = GetRef(let_node); + // Iteratively visit let nodes to avoid stack overflow. + while (expr->IsInstance()) { + Let let = Downcast(expr); + // Let var must be same device as value it is bound to. + domains_->UnifyExprExact(let->var, let->value); // may be higher-order + // Let body must be same device as overall let. + domains_->UnifyExprExact(let, let->body); // may be higher-order + + VisitExpr(let->var); + VisitExpr(let->value); + + expr = let->body; + } + + // Visit the last body + VisitExpr(expr); + } + + void VisitExpr_(const FunctionNode* function_node) final { + // No need to step into fused primitive functions as they are lowered individually according + // to the devices of all their call sites. + if (function_node->HasNonzeroAttr(attr::kPrimitive)) { + return; + } + + auto function = GetRef(function_node); + auto func_domain = domains_->DomainFor(function); // higher-order + + // The function body domain must match the function result domain. + domains_->UnifyExprExact(function_node->body, + func_domain->function_result()); // may be higher-order + + VLOG(1) << "initial function domain:" << std::endl + << domains_->ToString(func_domain) << std::endl + << "and function body domain:" << std::endl + << domains_->ToString(domains_->DomainFor(function_node->body)) << std::endl + << "for function:" << std::endl + << PrettyPrint(function); + + ICHECK_EQ(func_domain->function_arity(), function_node->params.size()); + for (size_t i = 0; i < function_node->params.size(); ++i) { + // The parameter domains must match the function argument domains. + domains_->UnifyExprExact(function_node->params[i], + func_domain->function_param(i)); // may be higher-order + VisitExpr(function_node->params[i]); + } + + // If the function already has device attributes then we can further constrain the + // function's domain to match them. + if (GetFunctionResultDeviceType(function_node) != kInvalidDeviceType) { + std::vector args_and_result; + for (size_t i = 0; i < function_node->params.size(); ++i) { + args_and_result.emplace_back( + domains_->ForDeviceType(function_node->params[i]->checked_type(), + GetFunctionParamDeviceType(function_node, i))); + } + args_and_result.emplace_back(domains_->ForDeviceType( + function_node->body->checked_type(), GetFunctionResultDeviceType(function_node))); + auto annotation_domain = domains_->MakeDomain(std::move(args_and_result)); + try { + domains_->Unify(func_domain, annotation_domain); // higher-order + } catch (const Error& e) { + // TODO(mbs): Proper diagnostics. + LOG(FATAL) + << "Function devices are incompatible with its \"on_device\" annotation. Function:" + << std::endl + << PrettyPrint(function) << std::endl + << "with function devices:" << std::endl + << domains_->ToString(func_domain) << std::endl + << "and annotation devices:" << std::endl + << domains_->ToString(annotation_domain) << std::endl + << e.what(); + } + } + + VisitExpr(function_node->body); + + VLOG(1) << "final function domain:" << std::endl + << domains_->ToString(func_domain) << std::endl + << "and function body domain:" << std::endl + << domains_->ToString(domains_->DomainFor(function_node->body)) << std::endl + << "for function:" << std::endl + << PrettyPrint(function); + } + + void VisitExpr_(const TupleNode* tuple_node) final { + Tuple tuple = GetRef(tuple_node); + for (size_t i = 0; i < tuple->fields.size(); i++) { + auto domain = domains_->DomainFor(tuple->fields[i]); // may be higher-order + domains_->UnifyExprCollapsed(tuple, domain); // collapse to first-order if needed + VisitExpr(tuple->fields[i]); + } + } + + void VisitExpr_(const TupleGetItemNode* tuple_get_item_node) final { + TupleGetItem tuple_get_item = GetRef(tuple_get_item_node); + auto domain = domains_->DomainFor(tuple_get_item); // may be higher-order + domains_->UnifyExprCollapsed(tuple_get_item_node->tuple, + domain); // collapse to first-order if needed + VisitExpr(tuple_get_item_node->tuple); + } + + class DevicePatternAnalyzer : public PatternVisitor { + public: + DevicePatternAnalyzer(DeviceDomains* domains, const ExprNode* adt_node) + : domains_(domains), adt_node_(adt_node) {} + + private: + void VisitPattern_(const PatternVarNode* pattern_var_node) final { + auto var_domain = domains_->DomainFor(pattern_var_node->var); // may be higher order + domains_->UnifyExprCollapsed(GetRef(adt_node_), + var_domain); // collapse to first-order if needed + } + + /*! \brief (Mutable borrow of) the domains for all expressions processed so far. */ + DeviceDomains* domains_; + /*! \brief The expression for the ADT we are matching over. */ + const ExprNode* adt_node_; + }; + + void VisitPattern(const Pattern& pattern) final {} + + void VisitExpr_(const MatchNode* match_node) final { + // For match node, we unify the value and the rhs of each clause + Match match = GetRef(match_node); + auto match_domain = domains_->DomainFor(match); // may be higher-order + DevicePatternAnalyzer pattern_analyzer(domains_.get(), match->data.get()); + domains_->UnifyExprCollapsed(match->data, match_domain); // collapse to first-order if needed + for (const auto& clause : match->clauses) { + pattern_analyzer.VisitPattern(clause->lhs); + domains_->UnifyExprExact(clause->rhs, match_domain); + VisitExpr(clause->rhs); + } + VisitExpr(match_node->data); + } + + void VisitExpr_(const GlobalVarNode* global_var_node) final { + domains_->DomainFor(GetRef(global_var_node)); + } + + void VisitExpr_(const VarNode* var_node) final { domains_->DomainFor(GetRef(var_node)); } + + void VisitExpr_(const ConstantNode* constant_node) final { + domains_->DomainFor(GetRef(constant_node)); + } + + void VisitExpr_(const ConstructorNode* constructor_node) final { + // no-op, constructors are handled at their call-sites. + // TODO(mbs): Assumes eta-expansion + } + + void VisitExpr_(const IfNode* if_node) final { + auto ife = GetRef(if_node); + auto domain = domains_->DomainFor(ife); // may be higher-order + domains_->UnifyExprCollapsed(if_node->cond, domain); // collapse to first-order if needed + domains_->UnifyExprExact(if_node->true_branch, domain); + domains_->UnifyExprExact(if_node->false_branch, domain); + VisitExpr(if_node->cond); + VisitExpr(if_node->true_branch); + VisitExpr(if_node->false_branch); + } + + void VisitExpr_(const OpNode* op) final { + // no-op, primitive operators are handled at their call-sites. + } + + void VisitExpr_(const RefCreateNode* ref_create_node) final { + auto ref_create = GetRef(ref_create_node); + auto domain = domains_->DomainFor(ref_create_node->value); // may be higher-order + domains_->UnifyExprCollapsed(ref_create, domain); // collapse to first-order if needed + VisitExpr(ref_create_node->value); + } + + void VisitExpr_(const RefReadNode* ref_read_node) final { + auto ref_read = GetRef(ref_read_node); + auto domain = domains_->DomainFor(ref_read); // may be higher-order + domains_->UnifyExprCollapsed(ref_read_node->ref, domain); // collapse to first-order if needed + VisitExpr(ref_read_node->ref); + } + + void VisitExpr_(const RefWriteNode* ref_write_node) final { + auto ref_write = GetRef(ref_write_node); + auto domain = domains_->DomainFor(ref_write->value); // may be higher-order + domains_->UnifyExprCollapsed(ref_write->ref, domain); // collapse to first-order if needed + domains_->UnifyExprCollapsed(ref_write, domain); // collapse to first-order if needed + VisitExpr(ref_write_node->ref); + VisitExpr(ref_write_node->value); + } + + /*! \brief The module we are analyzing. */ + IRModule mod_; + /*! \brief The domains for all expressions processed so far. */ + std::unique_ptr domains_; +}; + +/****** +******* Phase 2 +*******/ + +/*! + * \brief Ensures every sub-expression in a module has a device type, using both the global + * default and some local heuristics to avoid unnecessary additional "device_copy" CallNodes. + * + * E.g. in: + * \code + * def @main(%x, %y, %z) { + * let %a = add(%x, %y); + * multiply(%a, on_device(%z, device_type=d)) + * \endcode + * we know the parameter \p %z must be on device \p d, but the devices for \p %x and \p %y, + * and the device for the function result, are still 'free'. The global 'default' device type + * is first used to 'fix' \p @main's result type, which in turn 'fixes' \p %x and \p %y, which + * in turn 'fixes' the device on which the \p add and \p multiply are executed. + * + * TODO(mbs): I think this is deterministic? We do however visit the top-level defs in hashmap + * order. + */ +class DeviceDefaulter : public ExprVisitor { + public: + DeviceDefaulter(IRModule mod, std::unique_ptr domains, + DLDeviceType default_device_type) + : mod_(std::move(mod)), + domains_(std::move(domains)), + default_device_type_(default_device_type) {} + + std::unique_ptr Default() { + VLOG_CONTEXT << "DeviceDefaulter"; + for (const auto& pair : mod_->functions) { + VLOG(1) << "defaulting devices for '" << PrettyPrint(pair.first) << "'"; + VisitExpr(pair.second); + } + return std::move(domains_); + } + + private: + void VisitExpr_(const FunctionNode* function_node) final { + if (function_node->HasNonzeroAttr(attr::kPrimitive)) { + return; + } + + auto function = GetRef(function_node); + auto func_domain = domains_->DomainFor(function); // higher-order + ICHECK_EQ(func_domain->function_arity(), function_node->params.size()); + if (domains_->AnyFree(func_domain)) { + VLOG(1) << "before defaulting function:" << std::endl << domains_->ToString(func_domain); + domains_->SetResultDefaultThenParams(func_domain, default_device_type_); + VLOG(1) << "after defaulting function:" << std::endl << domains_->ToString(func_domain); + } + VisitExpr(function_node->body); + } + + void VisitExpr_(const CallNode* call_node) final { + auto call = GetRef(call_node); + auto func_domain = domains_->DomainForCallee(call); // higher-order + ICHECK_EQ(func_domain->function_arity(), call_node->args.size()); + if (domains_->AnyFree(func_domain)) { + // For calls to Relay functions this step is identical to that for VisitExpr_(FunctionNode*) + // above. But for calls to primitives we may still need to force free domains to be + // defaulted. + VLOG(1) << "before defaulting callee:" << std::endl << domains_->ToString(func_domain); + domains_->SetResultDefaultThenParams(func_domain, default_device_type_); + VLOG(1) << "after defaulting callee:" << std::endl << domains_->ToString(func_domain); + } + return ExprVisitor::VisitExpr_(call_node); + } + + void VisitExpr_(const LetNode* let_node) final { + Expr expr = GetRef(let_node); + // Iteratively visit let nodes to avoid stack overflow. + while (expr->IsInstance()) { + Let let = Downcast(expr); + // If the let-var device is still free force it to match the overall let. + auto let_domain = domains_->DomainFor(let); // may be higher-order + DLDeviceType let_device_type = domains_->ResultDeviceType(let_domain); + ICHECK_NE(let_device_type, kInvalidDeviceType); + auto let_var_domain = domains_->DomainFor(let->var); // may be higher-order + if (domains_->AnyFree(let_var_domain)) { + VLOG(1) << "before defaulting let-var:" << std::endl << domains_->ToString(let_var_domain); + domains_->SetDefault(let_var_domain, let_device_type); + VLOG(1) << "after defaulting let-var:" << std::endl << domains_->ToString(let_var_domain); + } + VisitExpr(let->var); + VisitExpr(let->value); + expr = let->body; + } + VisitExpr(expr); + } + + /*! \brief The module we are processing. */ + IRModule mod_; + /*! \brief The domains for all expressions. */ + std::unique_ptr domains_; + /*! \brief The default device type. */ + DLDeviceType default_device_type_; +}; + +/****** +******* Phase 3 +*******/ + +/*! + * \brief Inserts missing "device_copy" CallNodes, and ensures the device type of every + * sub-expression in a module can be easily recovered by a later transformation using simple + * lexical scoping rules (e.g. for memory planning). + * + * - Discard any existing "on_device" CallNodes since their job is done. Similarly, discard + * any existing "device_copy" CallNodes which are no-ops. + * + * - Functions are given "param_device_types" and "result_device_type" attributes to capture + * the device type for its parameters and result. + * + * - Additional "device_copy" CallNodes are inserted wherever there's a transition between + * storage device types. Since the DeviceAnalyzer phase succeeded this can only happen + * where the original program explicitly allowed a transition using an "on_device" CallNode. + * That is, we do not not try to 'fix' a program with inconsistent devices. + * + * - Additional "on_device" CallNodes are inserted so that a later transform can discover + * the device for an arbitrary sub-expression by looking only for the lexically enclosing + * "on_device" CallNode or "on_device" function attribute. In particular, since function + * arguments and let-bound expressions can be on a device different from the function + * or let body itself we will insert "on_device" CallNodes to spell out any differences. This + * applies even to the argument to a "device_copy" CallNode, which may look pedantic but + * keeps downstream processing simple. The "on_device" calls should be removed before code gen, + * which is easily done on-the-fly. + * + * For example, we'll end up with programs that look like: + * \code + * def @main(%x, %y, param_device_types=[...], result_device_type=...) { + * let %a = on_device(..., device_type=..., is_fixed=True) + * @f(%a, device_copy(on_device(..., device_type=..., is_fixed=True), + * src_device_type=..., dst_device_type=...)) + * } + * \endcode + */ +class DeviceCapturer : public ExprMutator { + public: + DeviceCapturer(IRModule mod, std::unique_ptr domains) + : mod_(std::move(mod)), domains_(std::move(domains)) {} + + IRModule Capture() { + VLOG_CONTEXT << "CaptureDevices"; + IRModule result(/*functions=*/{}, mod_->type_definitions, mod_->Imports(), mod_->source_map); + for (const auto& pair : mod_->functions) { + VLOG(1) << "capturing devices for '" << PrettyPrint(pair.first) << "'"; + result->Add(pair.first, Downcast(Mutate(pair.second))); + } + return result; + } + + private: + // Nothing interesting for VarNode, ConstantNode, GlobalVarNode, OpNode and ConstructorNode + + Expr VisitExpr_(const TupleNode* tuple_node) final { + auto tuple = GetRef(tuple_node); + Array fields; + fields.reserve(tuple_node->fields.size()); + for (const auto& field : tuple_node->fields) { + fields.push_back(VisitChild(tuple, field)); + } + // TODO(mbs): Avoid copy + return Tuple(std::move(fields), tuple_node->span); + } + + Expr VisitExpr_(const FunctionNode* function_node) final { + if (function_node->HasNonzeroAttr(attr::kPrimitive)) { + return GetRef(function_node); + } + + auto function = GetRef(function_node); + auto func_domain = domains_->DomainFor(function); // higher-order + VLOG(1) << "capturing function:" << std::endl + << PrettyPrint(function) << std::endl + << "with domain:" << std::endl + << domains_->ToString(func_domain); + + // Gather the parameter and result device types for the function attributes. + ICHECK_EQ(func_domain->function_arity(), function_node->params.size()); + DLDeviceType result_device_type = domains_->ResultDeviceType(func_domain); + ICHECK_NE(result_device_type, kInvalidDeviceType); + Array param_device_types; + param_device_types.reserve(function_node->params.size()); + for (size_t i = 0; i < function_node->params.size(); ++i) { + DLDeviceType param_device_type = domains_->ResultDeviceType(func_domain->function_param(i)); + ICHECK_NE(param_device_type, kInvalidDeviceType); + param_device_types.push_back(param_device_type); + } + + // Rewrite the body. Note that the body may have begun with an "on_device" so + // be prepared to insert a "device_copy". + Expr body = VisitChild( + /*lexical_device_type=*/result_device_type, + /*expected_device_type=*/result_device_type, + /*child_device_type=*/GetDeviceType(function_node->body), function_node->body); + + // TODO(mbs): Avoid copy + Function func = Function(function_node->params, body, function_node->ret_type, + function_node->type_params, function_node->attrs, function_node->span); + return FunctionOnDevice(func, param_device_types, result_device_type); + } + + Expr VisitExpr_(const CallNode* call_node) final { + auto call = GetRef(call_node); + DLDeviceType call_device_type = GetDeviceType(call); + + auto on_device_props = GetOnDeviceProps(call_node); + if (on_device_props.body.defined()) { + // We're done with the original "on_device" calls and can pinch them out. + // Note that this step has already been simulated by GetDeviceType. + return VisitExpr(on_device_props.body); + } + + auto device_copy_props = GetDeviceCopyProps(call_node); + if (device_copy_props.body.defined()) { + DLDeviceType src_device_type = device_copy_props.src_dev_type; + ICHECK_EQ(call_device_type, device_copy_props.dst_dev_type); + if (call_device_type == src_device_type) { + // We can pinch out existing "device_copy" CallNodes if their source and destinations + // match. + return VisitExpr(device_copy_props.body); + } + // else: handle as for any other call. + } + + auto func_domain = domains_->DomainForCallee(call); // higher-order + VLOG(1) << "considering call:" << std::endl + << PrettyPrint(call) << std::endl + << "on device " << call_device_type << " with function domain:" << std::endl + << domains_->ToString(func_domain); + DLDeviceType result_device_type = domains_->ResultDeviceType(func_domain); + ICHECK_NE(result_device_type, kInvalidDeviceType); + + // The callee is on the current device. + Expr op = VisitChild( + /*lexical_device_type=*/call_device_type, + /*expected_device_type=*/call_device_type, + /*child_device_type=*/result_device_type, call_node->op); + + // Each argument can be on the device for the corresponding function parameter. However if + // any of those differ from the overall call device then wrap them in an "on_device" to + // help downstream transforms track devices lexically. + Array args; + args.reserve(call_node->args.size()); + ICHECK_EQ(func_domain->function_arity(), call->args.size()); + for (size_t i = 0; i < call_node->args.size(); ++i) { + DLDeviceType param_device_type = domains_->ResultDeviceType(func_domain->function_param(i)); + ICHECK_NE(param_device_type, kInvalidDeviceType) + << "for parameter " << i << " for call:" << std::endl + << PrettyPrint(call); + args.push_back(VisitChild(/*lexical_device_type=*/call_device_type, + /*expected_device_type=*/param_device_type, + /*child_device_type=*/GetDeviceType(call_node->args[i]), + call_node->args[i])); + } + // TODO(mbs): Avoid copy + return Call(std::move(op), std::move(args), call_node->attrs, call_node->type_args, + call_node->span); + } + + Expr VisitExpr_(const LetNode* let_node) final { + Expr expr = GetRef(let_node); + // Iterate through chained lets, provided they all agree on their device type. + DLDeviceType let_device_type = GetDeviceType(expr); + std::vector> bindings; + while (const auto* inner_let_node = expr.as()) { + Expr inner_let = GetRef(inner_let_node); + if (GetDeviceType(inner_let) != let_device_type) { + // We have a device transition which needs to be handled. + break; + } + // The let-bound value can be on a different device than the overall let. However if those + // devices don't agree wrap the let-bound value in an "on_device" to help downstream + // transforms track devices lexically. + Expr value = VisitChild(/*lexical_device_type=*/let_device_type, + /*expected_device_type=*/GetDeviceType(inner_let_node->var), + /*child_device_type=*/GetDeviceType(inner_let_node->value), + inner_let_node->value); + bindings.emplace_back(inner_let_node->var, value, inner_let_node->span); + expr = inner_let_node->body; + } + Expr body = VisitChild(/*lexical_device_type=*/let_device_type, + /*expected_device_type=*/let_device_type, + /*child_device_type=*/GetDeviceType(expr), expr); + for (auto itr = bindings.rbegin(); itr != bindings.rend(); ++itr) { + body = Let(/*var=*/std::get<0>(*itr), /*value=*/std::get<1>(*itr), body, + /*span=*/std::get<2>(*itr)); + } + return body; + } + + Expr VisitExpr_(const IfNode* if_node) final { + auto ife = GetRef(if_node); + Expr cond = VisitChild(ife, if_node->cond); + Expr true_branch = VisitChild(ife, if_node->true_branch); + Expr false_branch = VisitChild(ife, if_node->false_branch); + // TODO(mbs): Avoid copy + return If(cond, true_branch, false_branch, if_node->span); + } + + Expr VisitExpr_(const TupleGetItemNode* tuple_get_item_node) final { + auto tuple_get_item = GetRef(tuple_get_item_node); + Expr tuple = VisitChild(tuple_get_item, tuple_get_item_node->tuple); + // TODO(mbs): Avoid copy + return TupleGetItem(tuple, tuple_get_item_node->index, tuple_get_item_node->span); + } + + Expr VisitExpr_(const RefCreateNode* ref_create_node) final { + auto ref_create = GetRef(ref_create_node); + Expr value = VisitChild(ref_create, ref_create_node->value); + // TODO(mbs): Avoid copy + return RefCreate(value, ref_create_node->span); + } + + Expr VisitExpr_(const RefReadNode* ref_read_node) final { + auto ref_read = GetRef(ref_read_node); + Expr ref = VisitChild(ref_read, ref_read_node->ref); + // TODO(mbs): Avoid copy + return RefRead(ref, ref_read_node->span); + } + + Expr VisitExpr_(const RefWriteNode* ref_write_node) final { + auto ref_write = GetRef(ref_write_node); + Expr ref = VisitChild(ref_write, ref_write_node->ref); + Expr value = VisitChild(ref_write, ref_write_node->value); + // TODO(mbs): Avoid copy + return RefWrite(ref, value, ref_write_node->span); + } + + Expr VisitExpr_(const MatchNode* match_node) final { + auto match = GetRef(match_node); + Expr data = VisitChild(match, match_node->data); + Array clauses; + clauses.reserve(match_node->clauses.size()); + for (const auto& clause : match_node->clauses) { + Pattern lhs = VisitPattern(clause->lhs); // actually a no-op, so we're not checking vars + Expr rhs = VisitChild(match, clause->rhs); + clauses.push_back(Clause(lhs, rhs)); + } + // TODO(mbs): Avoid copy + return Match(data, std::move(clauses), match_node->complete, match_node->span); + } + + DLDeviceType GetDeviceType(const Expr& expr) { + // Look through any "on_device" CallNodes, to mimic how we will be pinching them out. + auto props = GetOnDeviceProps(expr); + Expr true_expr = props.body.defined() ? props.body : expr; + ICHECK(domains_->contains(true_expr)); + // If expr is higher order we'll return only the result domain's device type. + DLDeviceType device_type = domains_->ResultDeviceType(domains_->DomainFor(true_expr)); + ICHECK_NE(device_type, kInvalidDeviceType) + << "no device type was determined for expression:" << std::endl + << PrettyPrint(true_expr); + return device_type; + } + + /*! + * \brief Reconcile the \p child_device_type for \p child with both the \p expected_device_type + * (as required by the expression context the \p child is in) and the \p lexical_device_type + * (as a downstream transform would infer based only on lexically enclosing "on_device" + * CallNodes and function attributes.) Generally \p lexical_device_type and \p + * expected_device_type are the same by definition, but may differ in arguments to functions + * and let-bound expressions. + * + * If \p child_device_type differs from \p expected_device_type, wrap it as: + * \code + * device_copy(on_device(child', device_type=child_device_type), + * src_dev_type=child_device_type, dst_dev_type=expected_device_type) + * \endcode + * (where child is rewritten to child'). Note the pedantic spelling out of "on_device" on the + * child. + * + * If \p expected_device_type differs from \p lexical_device_type, then (also) wrap + * the expression as: + * \code + * on_device(..., device_type=expected_device_type) + * \endcode + * + * TODO(mbs): There's no attempt at sharing here. If usage of child's node could be wrapped + * by a "device_copy", even though those copies will generally all be to the same destination + * device. + */ + Expr VisitChild(DLDeviceType lexical_device_type, DLDeviceType expected_device_type, + DLDeviceType child_device_type, const Expr& child) { + ICHECK_NE(lexical_device_type, kInvalidDeviceType); + ICHECK_NE(expected_device_type, kInvalidDeviceType); + if (child->IsInstance()) { + // Primitive operators don't need to be rewritten and can have a different domain for + // each call site. + return child; + } + Expr result = VisitExpr(child); + if (child_device_type != expected_device_type) { + VLOG(1) << "creating " << DeviceCopyOp()->name << " from device type " << child_device_type + << " to device type " << expected_device_type << " for:" << std::endl + << PrettyPrint(result); + // Also wrap the child in an "on_device" so downstream transforms can track devices + // lexically. + result = MaybeOnDevice(result, child_device_type, /*is_fixed=*/true); + result = DeviceCopy(result, child_device_type, expected_device_type); + } + if (expected_device_type != lexical_device_type) { + VLOG(1) << "creating " << OnDeviceOp()->name << " for device type " << expected_device_type + << " for:" << std::endl + << PrettyPrint(result); + result = MaybeOnDevice(result, expected_device_type, /*is_fixed=*/true); + } + return result; + } + + /*! + * Common case of visiting a direct \p child of \p parent where by default the \p child + * is expected to be on the same device as the \p parent. + */ + Expr VisitChild(const Expr& parent, const Expr& child) { + DLDeviceType expected_device_type = GetDeviceType(parent); + DLDeviceType child_device_type = GetDeviceType(child); + return VisitChild(expected_device_type, expected_device_type, child_device_type, child); + } + + /*! \brief Module we are rewriting, so we can lookup global variables. */ + IRModule mod_; + /*! \brief Device domain for every expression from DeviceAnalyzer. */ + std::unique_ptr domains_; +}; + +/*! \brief Rewrite the "on_device" calls (and implicitly re-type-check). */ +tvm::transform::Pass Rewrite() { + auto pass_func = [](Function f, IRModule m, transform::PassContext ctxt) { + return Downcast(RewriteOnDevices().Mutate(f)); + }; + return tvm::relay::transform::CreateFunctionPass(pass_func, 0, "PlanDevicesRewrite", {}); +} + +/*! \brief Run the remaining phases. */ +tvm::transform::Pass PlanDevicesCore(DLDeviceType default_device_type) { + return tvm::transform::CreateModulePass( + [=](IRModule mod, tvm::transform::PassContext pass_cnxt) -> IRModule { + // Collect the system of constraints for every sub-expression using existing "on_device" + // and "device_copy" calls. + std::unique_ptr domains = DeviceAnalyzer(mod).Analyze(); + VLOG(1) << "Domains after analysis:" << std::endl << domains->ToString(); + + // Choose sensible default devices for every sub-expression if otherwise unconstrained + // by existing "on_device" or "device_copy" calls. + domains = DeviceDefaulter(mod, std::move(domains), default_device_type).Default(); + VLOG(1) << "Domains after defaulting: " << std::endl << domains->ToString(); + + // Insert "device_copy" and "on_device" CallNodes where needed to unambiguously capture + // the above map, and attach additional "param_device_types" and "result_device_type" + // attributes to all function definitions. + return DeviceCapturer(mod, std::move(domains)).Capture(); + }, + /*opt_level=*/0, "PlanDevicesCore", {}); +} + +} // namespace + +/****** +******* Overall composite Pass +*******/ + +// This function is declared in the public . +TVM_DLL tvm::transform::Pass PlanDevices(DLDeviceType default_device_type) { + std::vector passes; + passes.emplace_back(Rewrite()); + passes.emplace_back(PlanDevicesCore(default_device_type)); + return tvm::transform::Sequential(std::move(passes), "PlanDevices"); +} + +TVM_REGISTER_GLOBAL("relay._transform.PlanDevices") + .set_body_typed([](const Device& default_device) { + return PlanDevices(default_device.device_type); + }); + +} // namespace transform +} // namespace relay +} // namespace tvm diff --git a/tests/cpp/relay/relay/transforms/device_domains_test.cc b/tests/cpp/relay/relay/transforms/device_domains_test.cc new file mode 100644 index 000000000000..8f263c3b3273 --- /dev/null +++ b/tests/cpp/relay/relay/transforms/device_domains_test.cc @@ -0,0 +1,71 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/* + * Just a smoke test for the device planner's unification domain, mostly to tease out how we'd + * like to organize our cpp unit tests for functionality that's not obviously a Pass or should + * be exposed via FFI. + */ + +// TODO(mbs): Revisit cpp unit test layout or setup include dir at root of src/ +#include "../../../src/relay/transforms/device_domains.h" + +#include +#include +#include + +namespace tvm { +namespace relay { +namespace transform { +namespace { + +IRModule TestModule() { + return InferType()(tvm::parser::ParseModule("test", R"( + #[version = "0.0.5"] + def @f(%x : Tensor[(3, 7), float32], %y : Tensor[(3, 7), float32]) { + add(%x, %y) + } + )")); +} + +TEST(DeviceDomains, SmokeTest) { + DeviceDomains domains; + IRModule mod = TestModule(); + Function f = Downcast(mod->Lookup("f")); + + DeviceDomainPtr actual_add_domain = domains.DomainForCallee(Downcast(f->body)); + DeviceDomainPtr x_domain = domains.DomainFor(f->params[0]); + DeviceDomainPtr y_domain = domains.DomainFor(f->params[1]); + DeviceDomainPtr result_domain = DeviceDomains::Free(f->ret_type); + std::vector arg_and_results; + arg_and_results.push_back(x_domain); + arg_and_results.push_back(y_domain); + arg_and_results.push_back(result_domain); + DeviceDomainPtr implied_add_domain = DeviceDomains::MakeDomain(std::move(arg_and_results)); + domains.Unify(actual_add_domain, implied_add_domain); + domains.Unify(x_domain, DeviceDomains::ForDeviceType(f->params[0]->checked_type(), kDLCUDA)); + + EXPECT_EQ(domains.ResultDeviceType(y_domain), kDLCUDA); + EXPECT_EQ(domains.ResultDeviceType(result_domain), kDLCUDA); +} + +} // namespace +} // namespace transform +} // namespace relay +} // namespace tvm diff --git a/tests/python/relay/test_pass_plan_devices.py b/tests/python/relay/test_pass_plan_devices.py new file mode 100644 index 000000000000..2252d8a235c9 --- /dev/null +++ b/tests/python/relay/test_pass_plan_devices.py @@ -0,0 +1,1320 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License + + +"""Unit tests for the PlanDevices pass. We check: + - The pass alone given the expected AST, though we need to manually run InferTypes. + - The pass is idempotent. + - Execution on the VM backend yields the correct result.""" + +import tvm +from tvm import relay +import tvm.testing +import numpy as np + +CPU = tvm.device("cpu") # device_type=1 +GPU = tvm.device("cuda") # device_type=2 +DEFAULT = GPU + +core = tvm.IRModule() +core.import_from_std("core.rly") + + +def rewrite_and_assert(in_mod, expected_mod): + """Manually run the pass and assert it's structurally equals to the expected.""" + actual_mod = relay.transform.InferType()(in_mod) + actual_mod = relay.transform.PlanDevices(DEFAULT)(actual_mod) + actual_mod = relay.transform.InferType()(actual_mod) + expected_mod = relay.transform.InferType()(expected_mod) + if not tvm.ir.structural_equal(actual_mod, expected_mod, True): + # Print everything in full so we can see what's going on when things fail. + print("Input module:") + print(in_mod) + print("Expected module:") + print(expected_mod) + print("Actual module:") + print(actual_mod) + # Assert again so as to see the actual disagreeing sub-expressions. + tvm.ir.assert_structural_equal(actual_mod, expected_mod, True) + + +def eval_and_assert(in_mod: tvm.IRModule, reference_func, args): + """Test the standard compilation flow gives us a function which agrees with the Numpy + reference implementation.""" + if not tvm.runtime.enabled("cuda"): + print("Not evaluating since GPU is not available") + return + with tvm.transform.PassContext(opt_level=3): + compiled = relay.create_executor("vm", mod=in_mod, device=GPU, target="cuda").evaluate() + actual = compiled(*args).numpy() + expected = reference_func(*args) + tvm.testing.assert_allclose(actual, expected) + + +def rand(shape): + return np.random.rand(*shape).astype("float32") + + +def rands(shape, n): + return [rand(shape) for i in range(n)] + + +def exercise(in_mod: tvm.IRModule, expected_mod: tvm.IRModule, reference_func, args): + """Test in_mod against expected_mod and reference_func using args.""" + # Correctness + rewrite_and_assert(in_mod, expected_mod) + # Idempotence + rewrite_and_assert(expected_mod, expected_mod) + # The VM can compile and possibly even run the module + # TODO(mbs): Disabled until VM supports new device planning. + # if not (reference_func is None) and not (args is None): + # eval_and_assert(in_mod, reference_func, args) + + +def test_plain(): + # Everything defaults to GPU + def input(): + return tvm.parser.fromtext( + """ + #[version = "0.0.5"] + def @main(%a: Tensor[(5, 7), float32], %b: Tensor[(5, 7), float32], + %c: Tensor[(5, 7), float32], %d: Tensor[(5, 7), float32]) { + %0 = add(%a, %b); + %1 = add(%c, %d); + subtract(%0, %1) + } + """ + ) + + def expected(): + return tvm.parser.fromtext( + """ + #[version = "0.0.5"] + def @main(%a: Tensor[(5, 7), float32], %b: Tensor[(5, 7), float32], + %c: Tensor[(5, 7), float32], %d: Tensor[(5, 7), float32], + param_device_types=[2, 2, 2, 2], result_device_type=2) { + %0 = add(%a, %b); + %1 = add(%c, %d); + subtract(%0, %1) + } + """ + ) + + def ref(a, b, c, d): + return np.subtract(np.add(a, b), np.add(c, d)) + + exercise(input(), expected(), ref, rands((5, 7), 4)) + + +def test_left_add_on_cpu(): + # Force some args to be on CPU, rest default to GPU. + def input(): + return tvm.parser.fromtext( + """ + #[version = "0.0.5"] + def @main(%a: Tensor[(5, 7), float32], %b: Tensor[(5, 7), float32], + %c: Tensor[(5, 7), float32], %d: Tensor[(5, 7), float32]) { + %0 = add(%a, %b); + %1 = on_device(%0, device_type=1); + %2 = add(%c, %d); + subtract(%1, %2) + } + """ + ) + + def expected(): + return tvm.parser.fromtext( + """ + #[version = "0.0.5"] + def @main(%a: Tensor[(5, 7), float32], %b: Tensor[(5, 7), float32], + %c: Tensor[(5, 7), float32], %d: Tensor[(5, 7), float32], + param_device_types=[1, 1, 2, 2], result_device_type=2) { + %0 = add(%a, %b); + %1 = on_device(%0, device_type=1, is_fixed=True); + %2 = device_copy(%1, src_dev_type=1, dst_dev_type=2); + %3 = add(%c, %d); + subtract(%2, %3) + } + """ + ) + + def ref(a, b, c, d): + return np.subtract(np.add(a, b), np.add(c, d)) + + exercise(input(), expected(), ref, rands((5, 7), 4)) + + +def test_left_add_on_cpu_via_copy(): + # As for test_left_add_on_cpu, but with an explicit device_copy. + def input(): + return tvm.parser.fromtext( + """ + #[version = "0.0.5"] + def @main(%a: Tensor[(5, 7), float32], %b: Tensor[(5, 7), float32], + %c: Tensor[(5, 7), float32], %d: Tensor[(5, 7), float32]) { + %0 = add(%a, %b); + %1 = device_copy(%0, src_dev_type=1, dst_dev_type=2); + %2 = add(%c, %d); + subtract(%1, %2) + } + """ + ) + + def expected(): + return tvm.parser.fromtext( + """ + #[version = "0.0.5"] + def @main(%a: Tensor[(5, 7), float32], %b: Tensor[(5, 7), float32], + %c: Tensor[(5, 7), float32], %d: Tensor[(5, 7), float32], + param_device_types=[1, 1, 2, 2], result_device_type=2) { + %0 = add(%a, %b); + %1 = on_device(%0, device_type=1, is_fixed=True); + %2 = device_copy(%1, src_dev_type=1, dst_dev_type=2); + %3 = add(%c, %d); + subtract(%2, %3) + } + """ + ) + + def ref(a, b, c, d): + return np.subtract(np.add(a, b), np.add(c, d)) + + exercise(input(), expected(), ref, rands((5, 7), 4)) + + +def test_both_adds_on_cpu(): + def input(): + return tvm.parser.fromtext( + """ + #[version = "0.0.5"] + def @main(%a: Tensor[(5, 7), float32], %b: Tensor[(5, 7), float32], + %c: Tensor[(5, 7), float32], %d: Tensor[(5, 7), float32]) { + %0 = add(%a, %b); + %1 = add(%c, %d); + %2 = on_device(%0, device_type=1); + %3 = on_device(%1, device_type=1); + subtract(%2, %3) + } + """ + ) + + def expected(): + return tvm.parser.fromtext( + """ + #[version = "0.0.5"] + def @main(%a: Tensor[(5, 7), float32], %b: Tensor[(5, 7), float32], + %c: Tensor[(5, 7), float32], %d: Tensor[(5, 7), float32], + param_device_types=[1, 1, 1, 1], result_device_type=2) { + %0 = add(%a, %b); + %1 = on_device(%0, device_type=1, is_fixed=True); + %2 = add(%c, %d); + %3 = on_device(%2, device_type=1, is_fixed=True); + %4 = device_copy(%1, src_dev_type=1, dst_dev_type=2); + %5 = device_copy(%3, src_dev_type=1, dst_dev_type=2); + subtract(%4, %5) + } + """ + ) + + def ref(a, b, c, d): + return np.subtract(np.add(a, b), np.add(c, d)) + + exercise(input(), expected(), ref, rands((5, 7), 4)) + + +def test_sharing(): + # The same add sub-expression is annotated twice. + def input(): + return tvm.parser.fromtext( + """ + #[version = "0.0.5"] + def @main(%a: Tensor[(5, 7), float32], %b: Tensor[(5, 7), float32]) { + %0 = add(%a, %b); + %1 = on_device(%0, device_type=1); + %2 = on_device(%0, device_type=1); + subtract(%1, %2) + } + """ + ) + + def expected(): + return tvm.parser.fromtext( + """ + #[version = "0.0.5"] + def @main(%a: Tensor[(5, 7), float32], %b: Tensor[(5, 7), float32], + param_device_types=[1, 1], result_device_type=2) { + %0 = add(%a, %b); + %1 = on_device(%0, device_type=1, is_fixed=True); + %2 = on_device(%0, device_type=1, is_fixed=True); + %3 = device_copy(%1, src_dev_type=1, dst_dev_type=2); + %4 = device_copy(%2, src_dev_type=1, dst_dev_type=2); + subtract(%3, %4) + } + """ + ) + + def ref(a, b): + x = np.add(a, b) + return np.subtract(x, x) + + exercise(input(), expected(), ref, rands((5, 7), 2)) + + +def test_let_on_cpu(): + # The device for a let-bound expression can flow from uses of the let-bound var. + def input(): + return tvm.parser.fromtext( + """ + #[version = "0.0.5"] + def @main(%a: Tensor[(5, 7), float32], %b: Tensor[(5, 7), float32], + %c: Tensor[(5, 7), float32], %d: Tensor[(5, 7), float32]) { + let %l = add(%a, %b); + let %r = add(%c, %d); + %0 = on_device(%l, device_type=1); + subtract(%0, %r) + } + """ + ) + + def expected(): + return tvm.parser.fromtext( + """ + #[version = "0.0.5"] + def @main(%a: Tensor[(5, 7), float32], %b: Tensor[(5, 7), float32], + %c: Tensor[(5, 7), float32], %d: Tensor[(5, 7), float32], + param_device_types=[1, 1, 2, 2], result_device_type=2) { + %0 = add(%a, %b); + let %l = on_device(%0, device_type=1, is_fixed=True); + let %r = add(%c, %d); + %1 = device_copy(%l, src_dev_type=1, dst_dev_type=2); + subtract(%1, %r) + } + """ + ) + + def ref(a, b, c, d): + return np.subtract(np.add(a, b), np.add(c, d)) + + exercise(input(), expected(), ref, rands((5, 7), 4)) + + +def test_func_param_on_cpu(): + # Devices for function parameters flow to call sites. + def input(): + return tvm.parser.fromtext( + """ + #[version = "0.0.5"] + def @main(%a: Tensor[(5, 7), float32], %b: Tensor[(5, 7), float32], + %c: Tensor[(5, 7), float32], %d: Tensor[(5, 7), float32]) { + let %f = fn (%x, %y) { + %0 = add(%x, %y); + on_device(%0, device_type=1) + }; + %1 = %f(%a, %b); + %2 = add(%c, %d); + subtract(%1, %2) + } + """ + ) + + def expected(): + return tvm.parser.fromtext( + """ + #[version = "0.0.5"] + def @main(%a: Tensor[(5, 7), float32], %b: Tensor[(5, 7), float32], + %c: Tensor[(5, 7), float32], %d: Tensor[(5, 7), float32], + param_device_types=[1, 1, 1, 1], result_device_type=1) { + let %f = fn (%x, %y, param_device_types=[1, 1], result_device_type=1) { + add(%x, %y) + }; + %0 = %f(%a, %b); + %1 = add(%c, %d); + subtract(%0, %1) + } + """ + ) + + def ref(a, b, c, d): + return np.subtract(np.add(a, b), np.add(c, d)) + + exercise(input(), expected(), ref, rands((5, 7), 4)) + + +def test_func_result_on_cpu(): + # Devices for call sites flow to function results. + def input(): + return tvm.parser.fromtext( + """ + #[version = "0.0.5"] + def @main(%a: Tensor[(5, 7), float32], %b: Tensor[(5, 7), float32], + %c: Tensor[(5, 7), float32], %d: Tensor[(5, 7), float32]) { + let %f = fn (%x, %y) { + add(%x, %y) + }; + %0 = %f(%a, %b); + %1 = on_device(%0, device_type=1); + %2 = add(%c, %d); + subtract(%1, %2) + } + """ + ) + + def expected(): + return tvm.parser.fromtext( + """ + #[version = "0.0.5"] + def @main(%a: Tensor[(5, 7), float32], %b: Tensor[(5, 7), float32], + %c: Tensor[(5, 7), float32], %d: Tensor[(5, 7), float32], + param_device_types=[1, 1, 2, 2], result_device_type=2) { + %0 = fn (%x, %y, param_device_types=[1, 1], result_device_type=1) { + add(%x, %y) + }; + let %f = on_device(%0, device_type=1, is_fixed=True); + %1 = %f(%a, %b); + %2 = on_device(%1, device_type=1, is_fixed=True); + %3 = device_copy(%2, src_dev_type=1, dst_dev_type=2); + %4 = add(%c, %d); + subtract(%3, %4) + } + """ + ) + + def ref(a, b, c, d): + return np.subtract(np.add(a, b), np.add(c, d)) + + exercise(input(), expected(), ref, rands((5, 7), 4)) + + +def test_higher_order(): + # The constraint on %a flows back to %y via %f and %h + def input(): + return tvm.parser.fromtext( + """ + #[version = "0.0.5"] + def @main(%x: Tensor[(5, 7), float32], %y: Tensor[(5, 7), float32]) { + let %f = fn (%g) { + fn (%a) { + %0 = on_device(%a, device_type=1); + %1 = %g(%0); + add(%1, %x) + } + }; + let %h = fn (%b) { + negative(%b) + }; + %2 = %f(%h); + %3 = %2(%y); + subtract(%x, %3) + } + """ + ) + + def expected(): + return tvm.parser.fromtext( + """ + #[version = "0.0.5"] + def @main(%x: Tensor[(5, 7), float32], %y: Tensor[(5, 7), float32], + param_device_types=[2, 1], result_device_type=2) { + let %f = fn (%g, param_device_types=[2], result_device_type=2) { + fn (%a, param_device_types=[1], result_device_type=2) { + %0 = device_copy(%a, src_dev_type=1, dst_dev_type=2); + %1 = %g(%0); + add(%1, %x) + } + }; + let %h = fn (%b, param_device_types=[2], result_device_type=2) { + negative(%b) + }; + %2 = %f(%h); + %3 = %2(%y); + subtract(%x, %3) + } + """ + ) + + def ref(x, y): + def f(g): + return lambda a: np.add(g(a), x) + + def h(b): + return np.negative(b) + + return np.subtract(x, f(h)(y)) + + exercise(input(), expected(), ref, rands((5, 7), 2)) + + +def test_function_in_tuple(): + # Since %f ends up in a tuple its argument and result is forced to be on the CPU + def input(): + return tvm.parser.fromtext( + """ + #[version = "0.0.5"] + def @main(%x: Tensor[(5, 7), float32], %y: Tensor[(5, 7), float32]) { + let %f = fn (%a: Tensor[(5, 7), float32], %b: Tensor[(5, 7), float32]) { + %0 = on_device(%b, device_type=1); + add(%a, %0) + }; + let %t = (%f, %x); + %1 = %t.1; + %2 = %t.0; + %2(%1, %y) + } + """ + ) + + def expected(): + return tvm.parser.fromtext( + """ + #[version = "0.0.5"] + def @main(%x: Tensor[(5, 7), float32], %y: Tensor[(5, 7), float32], + param_device_types=[1, 1], result_device_type=1) { + let %f = fn (%a: Tensor[(5, 7), float32], %b: Tensor[(5, 7), float32], + param_device_types=[1, 1], result_device_type=1) { + add(%a, %b) + }; + let %t = (%f, %x); + %0 = %t.1; + %1 = %t.0; + %1(%0, %y) + } + """ + ) + + def ref(x, y): + return np.add(x, y) + + exercise(input(), expected(), ref, rands((5, 7), 2)) + + +def test_device_copy(): + const = rand((5, 7)) + metatable = {"relay.Constant": [relay.const(const)]} + + def input(): + return tvm.parser.parse( + """ + #[version = "0.0.5"] + def @main(%x: Tensor[(5, 7), float32]) { + %0 = device_copy(%x, src_dev_type=1, dst_dev_type=2); + add(%0, meta[relay.Constant][0]) + } + """, + "from_string", + None, + metatable, + ) + + def expected(): + return tvm.parser.parse( + """ + #[version = "0.0.5"] + def @main(%x: Tensor[(5, 7), float32], param_device_types=[1], result_device_type=2) { + %0 = device_copy(%x, src_dev_type=1, dst_dev_type=2); + add(%0, meta[relay.Constant][0]) + } + """, + "from_string", + None, + metatable, + ) + + def ref(x): + return np.add(x, const) + + exercise(input(), expected(), ref, rands((5, 7), 1)) + + +def test_shape_func(): + def input(): + return tvm.parser.fromtext( + """ + #[version = "0.0.5"] + def @main(%x: Tensor[(?), float32], %s: Tensor[(1), int64]) { + %0 = fn (%y: Tensor[(?), float32]) { + nn.relu(%y) + }; + let %p = on_device(%0, device_type=2, is_fixed=True); + %1 = on_device(%x, device_type=2, is_fixed=True); + %2 = vm.shape_of(%1, dtype="int64"); + %3 = (%2,); + %4 = (%s,); + vm.shape_func(%p, %3, %4, is_input=[False]) + } + """ + ) + + def expected(): + return tvm.parser.fromtext( + """ + #[version = "0.0.5"] + def @main(%x: Tensor[(?), float32], %s: Tensor[(1), int64], + param_device_types=[2, 1], result_device_type=1) { + %0 = fn (%y: Tensor[(?), float32], param_device_types=[2], result_device_type=2) { + nn.relu(%y) + }; + let %p = on_device(%0, device_type=2, is_fixed=True); + %1 = vm.shape_of(%x, dtype="int64"); + %2 = (%1,); + %3 = (%s,); + vm.shape_func(%p, %2, %3, is_input=[False]) + } + """ + ) + + # Don't try to execute, too fiddly to setup. + exercise(input(), expected(), None, None) + + +def test_shape_of(): + # We need to use is_fixed=True in the on_device call so that the tensor will be on the GPU. Otherwise the + # result defaults to the result device for @main which is the CPU, thus forcing a copy. + # TODO(mbs): Perhaps the defaulting heuristics are being too clever? + def input(): + return tvm.parser.fromtext( + """ + #[version = "0.0.5"] + def @main(%x: Tensor[(?, ?), float32]) { + %0 = on_device(%x, device_type=2, is_fixed=True); + vm.shape_of(%0, dtype="int64") + } + """ + ) + + def expected(): + return tvm.parser.fromtext( + """ + #[version = "0.0.5"] + def @main(%x: Tensor[(?, ?), float32], param_device_types=[2], result_device_type=1) { + vm.shape_of(%x, dtype="int64") + } + """ + ) + + def ref(x): + return x.shape + + exercise(input(), expected(), ref, rands((5, 7), 1)) + + +def test_alloc_storage(): + def input(): + return tvm.parser.parse( + """ + #[version = "0.0.5"] + def @main(%size: int64, %alignment: int64) { + memory.alloc_storage(%size, %alignment, device_id=0, device_type=2) + } + """, + "from_string", + core, + ) + + def expected(): + return tvm.parser.parse( + """ + #[version = "0.0.5"] + def @main(%size: int64, %alignment: int64, param_device_types=[1, 1], result_device_type=2) { + memory.alloc_storage(%size, %alignment, device_id=0, device_type=2) + } + """, + "from_string", + core, + ) + + # Don't try to execute, too fiddly to setup. + exercise(input(), expected(), None, None) + + +def test_alloc_tensor(): + shape = np.array([3, 2]) + metatable = {"relay.Constant": [relay.const(shape, dtype="int64")]} + + def input(): + return tvm.parser.parse( + """ + #[version = "0.0.5"] + def @main(%sto: Storage[]) { + memory.alloc_tensor(%sto, 0, meta[relay.Constant][0], + const_shape=meta[relay.Constant][0], assert_shape=[]) + } + """, + "from_string", + core, + metatable, + ) + + def expected(): + return tvm.parser.parse( + """ + #[version = "0.0.5"] + def @main(%sto: Storage[], param_device_types=[2], result_device_type=2) { + %0 = on_device(0, device_type=1, is_fixed=True); + %1 = on_device(meta[relay.Constant][0], device_type=1, is_fixed=True); + memory.alloc_tensor(%sto, %0, %1, const_shape=meta[relay.Constant][0], assert_shape=[]) + } + """, + "from_string", + core, + metatable, + ) + + # Don't try to execute, too fiddly to setup. + exercise(input(), expected(), None, None) + + +def test_reshape_tensor(): + newshape = [2, 4, 2] + metatable = {"relay.Constant": [relay.const(newshape, dtype="int64")]} + + def input(): + return tvm.parser.parse( + """ + #[version = "0.0.5"] + def @main(%x: Tensor[(2, 8), float32]) { + vm.reshape_tensor(%x, meta[relay.Constant][0], newshape=[2, 4, 2]) + } + """, + "from_string", + None, + metatable, + ) + + def expected(): + return tvm.parser.parse( + """ + #[version = "0.0.5"] + def @main(%x: Tensor[(2, 8), float32], param_device_types=[2], result_device_type=2) { + %0 = on_device(meta[relay.Constant][0], device_type=1, is_fixed=True); + vm.reshape_tensor(%x, %0, newshape=[2, 4, 2]) + } + """, + "from_string", + None, + metatable, + ) + + def ref(x): + return np.reshape(x, newshape) + + exercise(input(), expected(), ref, rands((2, 8), 1)) + + +def test_dynamic_input(): + # There's nothing special about inferring devices for partially unknown types. + def input(): + return tvm.parser.fromtext( + """ + #[version = "0.0.5"] + def @main(%x0: Tensor[(?, ?), float32], %x1: Tensor[(?, ?), float32]) { + add(%x0, %x1) + } + """ + ) + + def expected(): + return tvm.parser.fromtext( + """ + #[version = "0.0.5"] + def @main(%x0: Tensor[(?, ?), float32], %x1: Tensor[(?, ?), float32], + param_device_types=[2, 2], result_device_type=2) { + add(%x0, %x1) + } + """ + ) + + def ref(x0, x1): + return np.add(x0, x1) + + exercise(input(), expected(), ref, rands((5, 7), 2)) + + +def test_redundant_annotation(): + def input(): + return tvm.parser.fromtext( + """ + #[version = "0.0.5"] + def @main(%x: Tensor[(5, 7), float32], %y: Tensor[(5, 7), float32], %z: Tensor[(5, 7), float32]) { + %0 = add(%x, %y); + %1 = on_device(%0, device_type=1); + %2 = subtract(%1, %z); + %3 = on_device(%0, device_type=1); + add(%2, %3) + } + """ + ) + + def expected(): + return tvm.parser.fromtext( + """ + #[version = "0.0.5"] + def @main(%x: Tensor[(5, 7), float32], %y: Tensor[(5, 7), float32], %z: Tensor[(5, 7), float32], + param_device_types=[1, 1, 2], result_device_type=2) { + %0 = add(%x, %y); + %1 = on_device(%0, device_type=1, is_fixed=True); + %2 = device_copy(%1, src_dev_type=1, dst_dev_type=2); + %3 = on_device(%0, device_type=1, is_fixed=True); + %4 = subtract(%2, %z); + %5 = device_copy(%3, src_dev_type=1, dst_dev_type=2); + add(%4, %5) + } + """ + ) + + def ref(x, y, z): + a = np.add(x, y) + return np.add(np.subtract(a, z), a) + + exercise(input(), expected(), ref, rands((5, 7), 3)) + + +def test_annotate_expr(): + def input(): + return tvm.parser.fromtext( + """ + #[version = "0.0.5"] + def @main(%x: Tensor[(5, 7), float32], %y: Tensor[(5, 7), float32], %z: Tensor[(5, 7), float32]) { + %0 = add(%x, %y); + %1 = on_device(%0, device_type=2); + %2 = subtract(%1, %z); + on_device(%2, device_type=1) + } + """ + ) + + def expected(): + return tvm.parser.fromtext( + """ + #[version = "0.0.5"] + def @main(%x: Tensor[(5, 7), float32], %y: Tensor[(5, 7), float32], %z: Tensor[(5, 7), float32], + param_device_types=[2, 2, 1], result_device_type=1) { + %0 = add(%x, %y); + %1 = on_device(%0, device_type=2, is_fixed=True); + %2 = device_copy(%1, src_dev_type=2, dst_dev_type=1); + subtract(%2, %z) + } + """ + ) + + def ref(x, y, z): + return np.subtract(np.add(x, y), z) + + exercise(input(), expected(), ref, rands((5, 7), 3)) + + +def test_annotate_all(): + def input(): + return tvm.parser.parse( + """ + #[version = "0.0.5"] + def @main(%x: Tensor[(5, 7), float32], %y: Tensor[(5, 7), float32], %z: Tensor[(5, 7), float32]) { + %0 = add(%x, %y); + %1 = on_device(%0, device_type=1); + %2 = subtract(%1, %z); + on_device(%2, device_type=1) + } + """ + ) + + def expected(): + return tvm.parser.parse( + """ + #[version = "0.0.5"] + def @main(%x: Tensor[(5, 7), float32], %y: Tensor[(5, 7), float32], %z: Tensor[(5, 7), float32], + param_device_types=[1, 1, 1], result_device_type=1) { + %0 = add(%x, %y); + subtract(%0, %z) + } + """ + ) + + def ref(x, y, z): + return np.subtract(np.add(x, y), z) + + exercise(input(), expected(), ref, rands((5, 7), 3)) + + +def test_conv_network(): + r"""The network and devices are as follows: + data1 data2 <--- CPU + | | + conv2d conv2d <--- CPU + \ / + \ / + add <--- GPU + | + conv2d <--- CPU + | + <--- CPU + """ + + def input(): + return tvm.parser.fromtext( + """ + #[version = "0.0.5"] + def @main(%data1: Tensor[(1, 64, 56, 56), float32], %data2: Tensor[(1, 64, 56, 56), float32], + %weight: Tensor[(64, 64, 3, 3), float32]) { + %0 = nn.conv2d(%data1, %weight, padding=[1, 1, 1, 1], channels=64, kernel_size=[3, 3]); + %1 = nn.conv2d(%data2, %weight, padding=[1, 1, 1, 1], channels=64, kernel_size=[3, 3]); + %2 = on_device(%0, device_type=1); + %3 = on_device(%1, device_type=1); + %4 = add(%2, %3); + %5 = on_device(%4, device_type=2); + %6 = nn.conv2d(%5, %weight, padding=[1, 1, 1, 1], channels=64, kernel_size=[3, 3]); + on_device(%6, device_type=1) + } + """ + ) + + def expected(): + return tvm.parser.fromtext( + """ + #[version = "0.0.5"] + def @main(%data1: Tensor[(1, 64, 56, 56), float32], %data2: Tensor[(1, 64, 56, 56), float32], + %weight: Tensor[(64, 64, 3, 3), float32], param_device_types=[1, 1, 1], result_device_type=1) { + %0 = nn.conv2d(%data1, %weight, padding=[1, 1, 1, 1], channels=64, kernel_size=[3, 3]); + %1 = on_device(%0, device_type=1, is_fixed=True); + %2 = nn.conv2d(%data2, %weight, padding=[1, 1, 1, 1], channels=64, kernel_size=[3, 3]); + %3 = on_device(%2, device_type=1, is_fixed=True); + %4 = device_copy(%1, src_dev_type=1, dst_dev_type=2); + %5 = device_copy(%3, src_dev_type=1, dst_dev_type=2); + %6 = add(%4, %5); + %7 = on_device(%6, device_type=2, is_fixed=True); + %8 = device_copy(%7, src_dev_type=2, dst_dev_type=1); + nn.conv2d(%8, %weight, padding=[1, 1, 1, 1], channels=64, kernel_size=[3, 3]) + } + """ + ) + + # Don't try to execute, we don't have a reference conv2d + exercise(input(), expected(), None, None) + + +def test_tuple_get_item(): + # Note that the device copy should be placed after projection rather than before. This is handled by + # a heuristic in the pass. + def input(): + return tvm.parser.fromtext( + """ + #[version = "0.0.5"] + def @main(%x: Tensor[(3, 3, 4), float32]) { + let %t = split(%x, indices_or_sections=3); + %0 = on_device(%t, device_type=1); + %1 = on_device(%t, device_type=1); + %2 = %0.0; + %3 = %1.1; + %4 = subtract(%2, %3); + on_device(%4, device_type=2) + } + """ + ) + + def expected(): + return tvm.parser.fromtext( + """ + #[version = "0.0.5"] + def @main(%x: Tensor[(3, 3, 4), float32], param_device_types=[1], result_device_type=2) { + %0 = split(%x, indices_or_sections=3); + let %t = on_device(%0, device_type=1, is_fixed=True); + %1 = %t.0; + %2 = on_device(%1, device_type=1, is_fixed=True); + %3 = %t.1; + %4 = on_device(%3, device_type=1, is_fixed=True); + %5 = device_copy(%2, src_dev_type=1, dst_dev_type=2); + %6 = device_copy(%4, src_dev_type=1, dst_dev_type=2); + subtract(%5, %6) + } + """ + ) + + def ref(x): + t = np.split(x, 3) + return np.subtract(t[0], t[1]) + + exercise(input(), expected(), ref, rands((3, 3, 4), 1)) + + +def test_propogation(): + r""" The network and devices are as follows: + x <--- CPU + | + log <--- CPU + / \ + log2 log10 <--- GPU + \ / + add <--- GPU + | + tan <--- CPU + | + <--- CPU + """ + + def input(): + return tvm.parser.fromtext( + """ + #[version = "0.0.5"] + def @main(%x: Tensor[(5, 7), float32]) { + %0 = log(%x); + %1 = on_device(%0, device_type=1); + %2 = log2(%1); + %3 = on_device(%0, device_type=1); + %4 = log10(%3); + %5 = on_device(%2, device_type=2); + %6 = on_device(%4, device_type=2); + %7 = add(%5, %6); + %8 = on_device(%7, device_type=2); + %9 = tan(%8); + on_device(%9, device_type=1) + } + """ + ) + + def expected(): + return tvm.parser.fromtext( + """ + #[version = "0.0.5"] + def @main(%x: Tensor[(5, 7), float32], param_device_types=[1], result_device_type=1) { + %0 = log(%x); + %1 = on_device(%0, device_type=1, is_fixed=True); + %2 = device_copy(%1, src_dev_type=1, dst_dev_type=2); + %3 = on_device(%0, device_type=1, is_fixed=True); + %4 = device_copy(%3, src_dev_type=1, dst_dev_type=2); + %5 = log2(%2); + %6 = log10(%4); + %7 = add(%5, %6); + %8 = on_device(%7, device_type=2, is_fixed=True); + %9 = device_copy(%8, src_dev_type=2, dst_dev_type=1); + tan(%9) + } + """ + ) + + def ref(x): + y = np.log(x) + return np.tan(np.add(np.log2(y), np.log10(y))) + + exercise(input(), expected(), ref, rands((5, 7), 1)) + + +def test_fusible_network(): + r""" The network is as follows: + x y <--- GPU + \ / + add <--- GPU + / \ + negative \ <--- CPU + \ \ + \ negative <--- GPU + \ / + add <--- GPU + | + negative <--- CPU + | + <--- CPU + """ + + def input(): + return tvm.parser.fromtext( + """ + #[version = "0.0.5"] + def @main(%x: Tensor[(5, 7), float32], %y: Tensor[(5, 7), float32]) { + %0 = add(%x, %y); + %1 = on_device(%0, device_type=2); + %2 = negative(%1); + %3 = on_device(%2, device_type=1); + %4 = negative(%0); + %5 = add(%3, %4); + %6 = on_device(%5, device_type=2); + %7 = negative(%6); + on_device(%7, device_type=1) + } + """ + ) + + def expected(): + return tvm.parser.fromtext( + """ + #[version = "0.0.5"] + def @main(%x: Tensor[(5, 7), float32], %y: Tensor[(5, 7), float32], param_device_types=[2, 2], result_device_type=1) { + %0 = add(%x, %y); + %1 = on_device(%0, device_type=2, is_fixed=True); + %2 = device_copy(%1, src_dev_type=2, dst_dev_type=1); + %3 = negative(%2); + %4 = on_device(%3, device_type=1, is_fixed=True); + %5 = device_copy(%4, src_dev_type=1, dst_dev_type=2); + %6 = negative(%0); + %7 = add(%5, %6); + %8 = on_device(%7, device_type=2, is_fixed=True); + %9 = device_copy(%8, src_dev_type=2, dst_dev_type=1); + negative(%9) + } + """ + ) + + def ref(x, y): + z = np.add(x, y) + return np.negative(np.add(np.negative(z), np.negative(z))) + + exercise(input(), expected(), ref, rands((5, 7), 2)) + + +def test_unpropagatable_graph(): + r"""The network is as follows: + a b <--- CPU + \ / + \ / c d <--- GPU + \ / \ / + add \ / <--- CPU + \ \ / + \ multiply <--- GPU + \ / + subtract <--- CPU + | + <--- CPU + """ + + def input(): + return tvm.parser.fromtext( + """ + #[version = "0.0.5"] + def @main(%a: Tensor[(5, 7), float32], %b: Tensor[(5, 7), float32], + %c: Tensor[(5, 7), float32], %d: Tensor[(5, 7), float32]) { + %0 = add(%a, %b); + %1 = multiply(%c, %d); + %2 = on_device(%0, device_type=1); + %3 = on_device(%1, device_type=2); + %4 = subtract(%2, %3); + on_device(%4, device_type=1) + } + """ + ) + + def expected(): + return tvm.parser.fromtext( + """ + #[version = "0.0.5"] + def @main(%a: Tensor[(5, 7), float32], %b: Tensor[(5, 7), float32], + %c: Tensor[(5, 7), float32], %d: Tensor[(5, 7), float32], + param_device_types=[1, 1, 2, 2], result_device_type=1) { + %0 = multiply(%c, %d); + %1 = on_device(%0, device_type=2, is_fixed=True); + %2 = add(%a, %b); + %3 = device_copy(%1, src_dev_type=2, dst_dev_type=1); + subtract(%2, %3) + } + """ + ) + + def ref(a, b, c, d): + return np.subtract(np.add(a, b), np.multiply(c, d)) + + exercise(input(), expected(), ref, rands((5, 7), 4)) + + +def test_conditional(): + # The conditional is over a function type, thus exercising the first-order/higher-order domain handling. + def input(): + return tvm.parser.fromtext( + """ + #[version = "0.0.5"] + def @main(%x: bool, %y: Tensor[(5, 7), float32], %z: Tensor[(5, 7), float32]) { + let %f = fn (%a) { + %0 = on_device(%y, device_type=1, is_fixed=True); + add(%a, %0) + }; + let %g = fn (%a1) { + subtract(%a1, %y) + }; + let %h = if (%x) { + %f + } else { + %g + }; + %h(%z) + } + """ + ) + + def expected(): + return tvm.parser.fromtext( + """ + #[version = "0.0.5"] + def @main(%x: bool, %y: Tensor[(5, 7), float32], %z: Tensor[(5, 7), float32], + param_device_types=[1, 1, 1], result_device_type=1) { + let %f = fn (%a, param_device_types=[1], result_device_type=1) { + add(%a, %y) + }; + let %g = fn (%a1, param_device_types=[1], result_device_type=1) { + subtract(%a1, %y) + }; + let %h = if (%x) { + %f + } else { + %g + }; + %h(%z) + } + """ + ) + + def ref(x, y, z): + def f(a): + return np.add(a, y) + + def g(a): + return np.subtract(a, y) + + h = f if x else g + return h(z) + + exercise(input(), expected(), ref, [True, rand((5, 7)), rand((5, 7))]) + + +def test_global(): + def input(): + return tvm.parser.fromtext( + """ + #[version = "0.0.5"] + def @f(%a: Tensor[(5, 7), float32], %b: Tensor[(5, 7), float32]) -> Tensor[(5, 7), float32] { + %0 = on_device(%b, device_type=1); + add(%a, %0) + } + + def @main(%x: Tensor[(5, 7), float32], %y: Tensor[(5, 7), float32]) -> Tensor[(5, 7), float32] { + @f(%y, %x) + } + """ + ) + + def expected(): + return tvm.parser.fromtext( + """ + #[version = "0.0.5"] + def @f(%a: Tensor[(5, 7), float32], %b: Tensor[(5, 7), float32], + param_device_types=[2, 1], result_device_type=2) -> Tensor[(5, 7), float32] { + %0 = device_copy(%b, src_dev_type=1, dst_dev_type=2); + add(%a, %0) + } + + def @main(%x: Tensor[(5, 7), float32], %y: Tensor[(5, 7), float32], + param_device_types=[1, 2], result_device_type=2) -> Tensor[(5, 7), float32] { + @f(%y, %x) + } + """ + ) + + def ref(x, y): + def f(a, b): + return np.add(a, b) + + return f(x, y) + + exercise(input(), expected(), ref, rands((5, 7), 2)) + + +def test_ref(): + def input(): + return tvm.parser.fromtext( + """ + #[version = "0.0.5"] + def @main(%x: Tensor[(5, 7), float32], %y: Tensor[(5, 7), float32]) { + let %r = ref(%x); + %0 = on_device(%y, device_type=1); + ref_write(%r, %0); + %1 = ref_read(%r); + add(%x, %1) + } + """ + ) + + def expected(): + return tvm.parser.fromtext( + """ + #[version = "0.0.5"] + def @main(%x: Tensor[(5, 7), float32], %y: Tensor[(5, 7), float32], + param_device_types=[2, 1], result_device_type=2) { + let %r = ref(%x); + %0 = device_copy(%y, src_dev_type=1, dst_dev_type=2); + ref_write(%r, %0); + %1 = ref_read(%r); + add(%x, %1) + } + """ + ) + + def ref(x, y): + r = {"value": x} + r["value"] = y + return np.add(x, r["value"]) + + # Don't try to execute, no backend currently supports both hetrogeneous devices and references. + exercise(input(), expected(), None, None) + + +def test_adt(): + def input(): + return tvm.parser.fromtext( + """ + #[version = "0.0.5"] + type List[A] { + Cons(A, List[A]), + Nil, + } + def @main(%x : Tensor[(5, 7), float32], %y : Tensor[(5, 7), float32]) { + %0 = on_device(%y, device_type=1, is_fixed=True); + %1 = Nil; + %2 = Cons(%0, %1); + let %l = Cons(%x, %2); + match? (%l) { + Cons(%z, _) => %z + } + } + """ + ) + + def expected(): + return tvm.parser.fromtext( + """ + #[version = "0.0.5"] + type List[A] { + Cons(A, List[A]), + Nil, + } + def @main(%x : Tensor[(5, 7), float32], %y : Tensor[(5, 7), float32], + param_device_types=[1, 1], result_device_type=1) { + %0 = Nil; + %1 = Cons(%y, %0); + let %l = Cons(%x, %1); + match? (%l) { + Cons(%z, _) => %z + } + } + """ + ) + + def ref(x, y): + l = [x, y] + return l[0] + + exercise(input(), expected(), ref, rands((5, 7), 2)) + + +if __name__ == "__main__": + import sys + import pytest + + sys.exit(pytest.main([__file__] + sys.argv[1:]))