diff --git a/include/tvm/relay/transform.h b/include/tvm/relay/transform.h
index cdd4c9c1dbd2..e740776d6d4f 100644
--- a/include/tvm/relay/transform.h
+++ b/include/tvm/relay/transform.h
@@ -444,6 +444,17 @@ TVM_DLL Pass RelayToTIRTargetHook();
  */
TVM_DLL Pass ManifestAlloc(Target target_host, Map<tvm::Integer, tvm::Target> targets);

+/*!
+ * \brief Uses existing "on_device" and "device_copy" CallNodes to infer the device on which
+ * every Relay sub-expression should run (and on which its result should be stored). Captures
+ * the result of that analysis using new "on_device" and "device_copy" CallNodes. See
+ * tvm::relay::transform::{LexicalOnDeviceMixin,DeviceAwareExprVisitor,DeviceAwareExprMutator}
+ * for help recovering the device for an arbitrary sub-expression in downstream transformations.
+ *
+ * \param default_device_type DLDeviceType for the default device.
+ */
+TVM_DLL Pass PlanDevices(DLDeviceType default_device_type);
+
}  // namespace transform

/*!
diff --git a/python/tvm/relay/transform/transform.py b/python/tvm/relay/transform/transform.py
index 7c79464bdd30..bb91afc06195 100644
--- a/python/tvm/relay/transform/transform.py
+++ b/python/tvm/relay/transform/transform.py
@@ -1167,6 +1167,16 @@ def SimplifyExpr():
    return _ffi_api.SimplifyExpr()

+def PlanDevices(default_device):
+    """
+    Uses existing "on_device" and "device_copy" CallNodes to infer the device on which
+    every Relay sub-expression should run (and on which its result should be stored). Captures
+    the result of that analysis using new "on_device" and "device_copy" CallNodes. Note that
+    the device_id of the default_device is ignored.
+    """
+    return _ffi_api.PlanDevices(default_device)
+
+
def FoldExplicitPadding():
    """
    FoldExplicitPadding finds explicit padding before an op that can support
diff --git a/src/relay/op/annotation/annotation.h b/src/relay/op/annotation/annotation.h
index 643a82116b5b..35f8b6bf50b6 100644
--- a/src/relay/op/annotation/annotation.h
+++ b/src/relay/op/annotation/annotation.h
@@ -81,9 +81,6 @@ OnDeviceProps GetOnDeviceProps(const CallNode* call_node);
 */
OnDeviceProps GetOnDeviceProps(const Expr& expr);

-/*! \brief Returns true if \p expr is an on_device CallNode. */
-inline bool IsOnDeviceCall(const Expr& expr) { return GetOnDeviceProps(expr).body.defined(); }
-
/*!
 * \brief Returns \p function annotated with "param_device_types" and "result_device_type"
 * attributes capturing parameter and result device types respectively.
diff --git a/src/relay/transforms/device_aware_visitors.cc b/src/relay/transforms/device_aware_visitors.cc
new file mode 100644
index 000000000000..204bce53207b
--- /dev/null
+++ b/src/relay/transforms/device_aware_visitors.cc
@@ -0,0 +1,285 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+ * \file src/relay/transforms/device_aware_visitors.cc
+ * \brief Visitors which track the device for the current Relay expression and Relay Vars.
+ */
+
+#include "./device_aware_visitors.h"
+
+namespace tvm {
+namespace relay {
+namespace transform {
+
+// TODO(mbs): We'd probably have less tedious code duplication if we redefined the memoizing
+// mutator on top of the generic Functor.
+
+DLDeviceType LexicalOnDeviceMixin::GetInScopeDeviceType(const Expr& expr) const {
+  auto props = GetOnDeviceProps(expr);
+  if (props.body.defined() && props.is_fixed) {
+    // Look through any fixed "on_device" annotations.
+    return props.device_type;
+  }
+  if (expr->IsInstance<VarNode>()) {
+    // Lookup variable binding.
+    auto itr = var_device_types_.find(Downcast<Var>(expr));
+    if (itr == var_device_types_.end()) {
+      return kInvalidDeviceType;
+    } else {
+      return itr->second;
+    }
+  }
+  // Otherwise use the currently in-scope device type.
+  if (expr_device_types_.empty()) {
+    return kInvalidDeviceType;
+  } else {
+    return expr_device_types_.back();
+  }
+}
+
+void LexicalOnDeviceMixin::EnterFunctionBody() { ++function_nesting_; }
+
+void LexicalOnDeviceMixin::ExitFunctionBody() {
+  ICHECK_GT(function_nesting_, 0);
+  --function_nesting_;
+}
+
+void LexicalOnDeviceMixin::PushDeviceType(DLDeviceType device_type) {
+  if (device_type == kInvalidDeviceType) {
+    return;
+  }
+  expr_device_types_.emplace_back(device_type);
+}
+
+void LexicalOnDeviceMixin::PopDeviceType() {
+  if (expr_device_types_.empty()) {
+    return;
+  }
+  expr_device_types_.pop_back();
+}
+
+void LexicalOnDeviceMixin::PushBoundVar(Var var, DLDeviceType device_type) {
+  if (device_type == kInvalidDeviceType) {
+    return;
+  }
+  ICHECK(var_device_types_.find(var) == var_device_types_.end());
+  var_device_types_.emplace(std::move(var), device_type);
+}
+
+void LexicalOnDeviceMixin::PopBoundVar(const Var& var) {
+  auto itr = var_device_types_.find(var);
+  if (itr == var_device_types_.end()) {
+    return;
+  }
+  var_device_types_.erase(itr);
+}
+
+void DeviceAwareExprVisitor::VisitExpr_(const FunctionNode* function_node) {
+  if (function_node->HasNonzeroAttr(attr::kPrimitive)) {
+    // No tracking inside primitive functions.
+    DeviceAwareVisitExpr_(function_node);
+  } else {
+    // Function parameters come into scope.
+    for (size_t i = 0; i < function_node->params.size(); ++i) {
+      PushBoundVar(function_node->params[i], GetFunctionParamDeviceType(function_node, i));
+    }
+    // Entering scope of function body.
+    PushDeviceType(GetFunctionResultDeviceType(function_node));
+    EnterFunctionBody();
+
+    DeviceAwareVisitExpr_(function_node);
+
+    // Leaving scope of function body.
+    ExitFunctionBody();
+    PopDeviceType();
+    // Function parameters go out of scope.
+    for (size_t i = 0; i < function_node->params.size(); ++i) {
+      PopBoundVar(function_node->params[i]);
+    }
+  }
+}
+
+void DeviceAwareExprVisitor::VisitExpr_(const LetNode* let_node) {
+  PreVisitLetBlock_(let_node);
+  std::vector<const LetNode*> bindings;
+  Expr expr = GetRef<Expr>(let_node);
+  while (const auto* inner_let_node = expr.as<LetNode>()) {
+    // Let-bound var (in pre visited version) goes into scope.
+    // (We'll just assume this is a letrec.)
+    PushBoundVar(inner_let_node->var, GetInScopeDeviceType(inner_let_node->value));
+    PreVisitLetBinding_(inner_let_node->var, inner_let_node->value);
+    bindings.emplace_back(inner_let_node);
+    expr = inner_let_node->body;
+  }
+
+  VisitExpr(expr);
+
+  for (auto itr = bindings.rbegin(); itr != bindings.rend(); ++itr) {
+    // Let-bound var goes out of scope.
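+    // (Bindings were pushed outermost-first above, so iterating in reverse unwinds the
+    // innermost scope first.)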
+    PopBoundVar((*itr)->var);
+    PostVisitLet_(*itr);
+  }
+  PostVisitLetBlock_(let_node);
+}
+
+void DeviceAwareExprVisitor::VisitExpr_(const CallNode* call_node) {
+  auto props = GetOnDeviceProps(call_node);
+  if (props.body.defined() && props.is_fixed) {
+    // Entering lexical scope of fixed "on_device" call.
+    PushDeviceType(props.device_type);
+    VisitExpr(props.body);
+    // Leaving lexical scope of "on_device" call.
+    PopDeviceType();
+  } else {
+    DeviceAwareVisitExpr_(call_node);
+  }
+}
+
+void DeviceAwareExprVisitor::DeviceAwareVisitExpr_(const FunctionNode* function_node) {
+  ExprVisitor::VisitExpr_(function_node);
+}
+
+void DeviceAwareExprVisitor::DeviceAwareVisitExpr_(const CallNode* call_node) {
+  ExprVisitor::VisitExpr_(call_node);
+}
+
+void DeviceAwareExprVisitor::PreVisitLetBlock_(const LetNode* let_node) {
+  // no-op
+}
+
+void DeviceAwareExprVisitor::PreVisitLetBinding_(const Var& var, const Expr& value) {
+  VisitExpr(var);
+  VisitExpr(value);
+}
+
+void DeviceAwareExprVisitor::PostVisitLet_(const LetNode* let_node) {
+  // no-op
+}
+
+void DeviceAwareExprVisitor::PostVisitLetBlock_(const LetNode* let_node) {
+  // no-op
+}
+
+Expr DeviceAwareExprMutator::VisitExpr_(const FunctionNode* function_node) {
+  if (function_node->HasNonzeroAttr(attr::kPrimitive)) {
+    // No tracking inside primitive functions.
+    return DeviceAwareVisitExpr_(function_node);
+  } else {
+    // Function parameters come into scope.
+    for (size_t i = 0; i < function_node->params.size(); ++i) {
+      PushBoundVar(function_node->params[i], GetFunctionParamDeviceType(function_node, i));
+    }
+    // Entering scope of function body.
+    PushDeviceType(GetFunctionResultDeviceType(function_node));
+    EnterFunctionBody();
+
+    Expr result = DeviceAwareVisitExpr_(function_node);
+
+    // Leaving scope of function body.
+    ExitFunctionBody();
+    PopDeviceType();
+    // Function parameters go out of scope.
+    for (size_t i = 0; i < function_node->params.size(); ++i) {
+      PopBoundVar(function_node->params[i]);
+    }
+
+    return result;
+  }
+}
+
+Expr DeviceAwareExprMutator::VisitExpr_(const LetNode* let_node) {
+  PreVisitLetBlock_(let_node);
+  std::vector<std::tuple<Var, Expr, Span, const LetNode*>> bindings;
+  Expr expr = GetRef<Expr>(let_node);
+  while (const auto* inner_let_node = expr.as<LetNode>()) {
+    // Let-bound var (in pre visited version) goes into scope.
+    // (We'll just assume this is a letrec.)
+    PushBoundVar(inner_let_node->var, GetInScopeDeviceType(inner_let_node->value));
+    std::pair<Var, Expr> pair = PreVisitLetBinding_(inner_let_node->var, inner_let_node->value);
+    bindings.emplace_back(pair.first, pair.second, inner_let_node->span, inner_let_node);
+    expr = inner_let_node->body;
+  }
+
+  expr = VisitExpr(expr);
+
+  for (auto itr = bindings.rbegin(); itr != bindings.rend(); ++itr) {
+    // Let-bound var goes out of scope.
+    const LetNode* pre_let_node = std::get<3>(*itr);
+    PopBoundVar(pre_let_node->var);
+    Let post_let = Let(/*var=*/std::get<0>(*itr), /*value=*/std::get<1>(*itr),
+                       /*body=*/expr, /*span=*/std::get<2>(*itr));
+    expr = PostVisitLet_(pre_let_node, post_let.get());
+  }
+  return PostVisitLetBlock_(let_node, expr.as<LetNode>());
+}
+
+Expr DeviceAwareExprMutator::VisitExpr_(const CallNode* call_node) {
+  auto props = GetOnDeviceProps(call_node);
+  if (props.body.defined() && props.is_fixed) {
+    // Entering lexical scope of fixed "on_device" call.
+    PushDeviceType(props.device_type);
+    Expr expr = VisitExpr(props.body);
+    // Leaving lexical scope of "on_device" call.
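+    // Note below we rewrap the mutated body in a fresh "on_device" annotation rather than
+    // reuse the original call, so the lexical scoping convention is preserved for any
+    // downstream device-aware passes.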
+    PopDeviceType();
+    return OnDevice(expr, props.device_type, props.is_fixed);
+  } else {
+    return DeviceAwareVisitExpr_(call_node);
+  }
+}
+
+Expr DeviceAwareExprMutator::DeviceAwareVisitExpr_(const FunctionNode* function_node) {
+  return ExprMutator::VisitExpr_(function_node);
+}
+
+Expr DeviceAwareExprMutator::DeviceAwareVisitExpr_(const CallNode* call_node) {
+  return ExprMutator::VisitExpr_(call_node);
+}
+
+void DeviceAwareExprMutator::PreVisitLetBlock_(const LetNode* let_node) { /* no-op */
+}
+
+std::pair<Var, Expr> DeviceAwareExprMutator::PreVisitLetBinding_(const Var& var,
+                                                                 const Expr& value) {
+  return std::make_pair(Downcast<Var>(VisitExpr(var)), VisitExpr(value));
+}
+
+Expr DeviceAwareExprMutator::PostVisitLet_(const LetNode* pre_let_node,
+                                           const LetNode* post_let_node) {
+  if (pre_let_node->var == post_let_node->var && pre_let_node->value == post_let_node->value &&
+      pre_let_node->body == post_let_node->body) {
+    return GetRef<Expr>(pre_let_node);
+  } else {
+    return GetRef<Expr>(post_let_node);
+  }
+}
+
+Expr DeviceAwareExprMutator::PostVisitLetBlock_(const LetNode* pre_let_node,
+                                                const LetNode* post_let_node) {
+  if (pre_let_node->var == post_let_node->var && pre_let_node->value == post_let_node->value &&
+      pre_let_node->body == post_let_node->body) {
+    return GetRef<Expr>(pre_let_node);
+  } else {
+    return GetRef<Expr>(post_let_node);
+  }
+}
+
+}  // namespace transform
+}  // namespace relay
+}  // namespace tvm
diff --git a/src/relay/transforms/device_aware_visitors.h b/src/relay/transforms/device_aware_visitors.h
new file mode 100644
index 000000000000..8611f87efa06
--- /dev/null
+++ b/src/relay/transforms/device_aware_visitors.h
@@ -0,0 +1,317 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+ * \file src/relay/transforms/device_aware_visitors.h
+ * \brief Visitors which track the device for the current Relay expression and Relay Vars.
+ */
+
+#ifndef TVM_RELAY_TRANSFORMS_DEVICE_AWARE_VISITORS_H_
+#define TVM_RELAY_TRANSFORMS_DEVICE_AWARE_VISITORS_H_
+
+#include <dlpack/dlpack.h>
+#include <tvm/relay/expr_functor.h>
+#include <tvm/relay/function.h>
+#include <tvm/relay/transform.h>
+
+#include <unordered_map>
+#include <utility>
+#include <vector>
+
+#include "../op/annotation/annotation.h"
+
+namespace tvm {
+namespace relay {
+namespace transform {
+
+/*!
+ * \brief Helper class for expression transformers which need to keep track of the device
+ * holding the results of expressions and bound variables. This is recovered from the
+ * "param_device_types" and "result_device_type" function attributes and the fixed "on_device"
+ * CallNodes added by the PlanDevices pass.
+ *
+ * \sa \p DeviceAwareExpr{Visitor,Mutator}.
+ */
+class LexicalOnDeviceMixin {
+ protected:
+  /*!
+   * \brief Returns the device type on which the result of \p expr should/will be stored, assuming
+   * Push/Pop DeviceType/BoundVar have been correctly called.
+   * Returns \p kInvalidDeviceType if the stack is empty and no bound vars have device types.
+   */
+  DLDeviceType GetInScopeDeviceType(const Expr& expr) const;
+
+  /*! \brief Indicate a function body is being entered. */
+  void EnterFunctionBody();
+
+  /*! \brief Indicate a function body has been processed. */
+  void ExitFunctionBody();
+
+  /*! \brief Push a device type onto the lexical device stack. Ignore if \p kInvalidDeviceType. */
+  void PushDeviceType(DLDeviceType device_type);
+
+  /*! \brief Pop a device type from the lexical device stack. Ignore if stack is empty. */
+  void PopDeviceType();
+
+  /*! \brief Remember that \p var will be stored on \p device_type. Ignore if \p kInvalidDeviceType.
+   *
+   * CAUTION: Despite the name we don't support re-entering the same function body.
+   */
+  void PushBoundVar(Var var, DLDeviceType device_type);
+
+  /*! \brief Remove the binding for \p var to its device type. Ignore if var is not bound. */
+  void PopBoundVar(const Var& var);
+
+  /*!
+   * \brief Returns the number of function definitions wrapping the currently visited expression.
+   */
+  int function_nesting() const { return function_nesting_; }
+
+ private:
+  /*!
+   * \brief The number of function bodies entered. Since many transforms need to distinguish global
+   * functions from local functions this supports the mixin's \p function_nesting() helper method.
+   */
+  int function_nesting_ = 0;
+
+  /*!
+   * \brief The stack of lexically enclosing "on_device" device types, from outermost to innermost.
+   * When visiting an expression other than a variable we can assume the expression result is
+   * to be stored on expr_device_types_.back().
+   */
+  std::vector<DLDeviceType> expr_device_types_;
+  /*!
+   * \brief A map from in-scope variables to their device types. We may assume the variable is only
+   * ever bound to a value stored on this device at runtime.
+   */
+  std::unordered_map<Var, DLDeviceType, runtime::ObjectPtrHash, runtime::ObjectPtrEqual>
+      var_device_types_;
+};
+
+template <typename FType>
+class DeviceAwareExprFunctor;
+
+/*!
+ * \brief ExprFunctor which tracks devices. We only support 'visitor' style implementation
+ * with no additional arguments, thus this is equivalent to \p DeviceAwareExprVisitor without
+ * any memoization.
+ */
+template <>
+class DeviceAwareExprFunctor<void(const Expr& n)> : public ExprFunctor<void(const Expr& n)>,
+                                                    public LexicalOnDeviceMixin {
+ private:
+  using TSuper = ExprFunctor<void(const Expr& n)>;
+
+ public:
+  void VisitExpr_(const FunctionNode* function_node) {
+    if (function_node->HasNonzeroAttr(attr::kPrimitive)) {
+      // No tracking inside primitive functions.
+      return DeviceAwareVisitExpr_(function_node);
+    } else {
+      // Function parameters come into scope.
+      for (size_t i = 0; i < function_node->params.size(); ++i) {
+        PushBoundVar(function_node->params[i], GetFunctionParamDeviceType(function_node, i));
+      }
+      // Entering scope of function body.
+      PushDeviceType(GetFunctionResultDeviceType(function_node));
+      EnterFunctionBody();
+
+      DeviceAwareVisitExpr_(function_node);
+
+      // Leaving scope of function body.
+      ExitFunctionBody();
+      PopDeviceType();
+      // Function parameters go out of scope.
+      for (size_t i = 0; i < function_node->params.size(); ++i) {
+        PopBoundVar(function_node->params[i]);
+      }
+    }
+  }
+
+  void VisitExpr_(const LetNode* let_node) {
+    PreVisitLetBlock_(let_node);
+    std::vector<const LetNode*> bindings;
+    Expr expr = GetRef<Expr>(let_node);
+    while (const auto* inner_let_node = expr.as<LetNode>()) {
+      // Let-bound var (in pre visited version) goes into scope.
+      // (We'll just assume this is a letrec.)
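+      // E.g. for `let %f = fn(%n) { ... %f(...) ... }; ...` the binding for %f must already
+      // be in scope while its value is visited, hence the push happens before the
+      // PreVisitLetBinding_ call below.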
+      PushBoundVar(inner_let_node->var, GetInScopeDeviceType(inner_let_node->value));
+      PreVisitLetBinding_(inner_let_node->var, inner_let_node->value);
+      bindings.emplace_back(inner_let_node);
+      expr = inner_let_node->body;
+    }
+
+    VisitExpr(expr);
+
+    for (auto itr = bindings.rbegin(); itr != bindings.rend(); ++itr) {
+      // Let-bound var goes out of scope.
+      const LetNode* visited_let_node = *itr;
+      PopBoundVar(visited_let_node->var);
+      PostVisitLet_(visited_let_node);
+    }
+    PostVisitLetBlock_(let_node);
+  }
+
+  void VisitExpr_(const CallNode* call_node) {
+    auto props = GetOnDeviceProps(call_node);
+    if (props.body.defined() && props.is_fixed) {
+      // Entering lexical scope of fixed "on_device" call.
+      PushDeviceType(props.device_type);
+      VisitExpr(props.body);
+      // Leaving lexical scope of "on_device" call.
+      PopDeviceType();
+    } else {
+      DeviceAwareVisitExpr_(call_node);
+    }
+  }
+
+  /*!
+   * \brief These are as for VisitExpr_. Devices for expressions and function parameters will be
+   * tracked automatically. Default implementation defers to ExprFunctor::VisitExpr_. For
+   * functions the function_nesting count will already include that of \p function_node.
+   */
+
+  virtual void DeviceAwareVisitExpr_(const FunctionNode* function_node) {
+    return TSuper::VisitExpr_(function_node);
+  }
+
+  virtual void DeviceAwareVisitExpr_(const CallNode* call_node) {
+    return TSuper::VisitExpr_(call_node);
+  }
+
+  /*!
+   * \brief Visit the first let in a chain of let expressions before any let bindings or final
+   * body has been visited. Default implementation is a no-op.
+   */
+  virtual void PreVisitLetBlock_(const LetNode* let_node) { /* no-op */
+  }
+
+  /*!
+   * \brief Visit a let-bound expression before the let body has been visited. Devices for the
+   * let-bound variable will be tracked automatically. Default implementation just visits var and
+   * value.
+   */
+  virtual void PreVisitLetBinding_(const Var& var, const Expr& value) {
+    VisitExpr(var);
+    VisitExpr(value);
+  }
+
+  /*!
+   * \brief Visit a let expression after the let-bound value and body have been visited.
+   * Default implementation is a no-op.
+   */
+  virtual void PostVisitLet_(const LetNode* let_node) { /* no-op */
+  }
+
+  /*!
+   * \brief Visit the first let in a chain of let expressions after all its bindings and final
+   * body have been visited. Default implementation is a no-op.
+   */
+  virtual void PostVisitLetBlock_(const LetNode* let_node) {}
+};
+
+/*! \brief ExprVisitor which tracks devices. */
+class DeviceAwareExprVisitor : public ExprVisitor, public LexicalOnDeviceMixin {
+ public:
+  using ExprVisitor::VisitExpr_;
+
+  void VisitExpr_(const FunctionNode* function_node) final;
+  void VisitExpr_(const LetNode* let_node) final;
+  void VisitExpr_(const CallNode* call_node) final;
+
+  /*!
+   * \brief These are as for VisitExpr_. Devices for expressions and function parameters will be
+   * tracked automatically. Default implementation defers to ExprVisitor::VisitExpr_. For
+   * functions the function_nesting count will already include that of \p function_node.
+   */
+  virtual void DeviceAwareVisitExpr_(const FunctionNode* function_node);
+  virtual void DeviceAwareVisitExpr_(const CallNode* call_node);
+
+  /*!
+   * \brief Visit the first let in a chain of let expressions before any let bindings or final
+   * body has been visited. Default implementation is a no-op.
+   */
+  virtual void PreVisitLetBlock_(const LetNode* let_node);
+
+  /*!
+   * \brief Visit a let-bound expression before the let body has been visited. Devices for the
+   * let-bound variable will be tracked automatically.
+   * Default implementation just visits var and value.
+   */
+  virtual void PreVisitLetBinding_(const Var& var, const Expr& value);
+
+  /*!
+   * \brief Visit a let expression after the let-bound value and body have been visited.
+   * Default implementation is a no-op.
+   */
+  virtual void PostVisitLet_(const LetNode* let_node);
+
+  /*!
+   * \brief Visit the first let in a chain of let expressions after all its bindings and final
+   * body have been visited. Default implementation is a no-op.
+   */
+  virtual void PostVisitLetBlock_(const LetNode* let_node);
+};
+
+/*! \brief ExprMutator which tracks devices. */
+class DeviceAwareExprMutator : public ExprMutator, public LexicalOnDeviceMixin {
+ public:
+  Expr VisitExpr_(const FunctionNode* function_node) final;
+  Expr VisitExpr_(const LetNode* let_node) final;
+  Expr VisitExpr_(const CallNode* call_node) final;
+
+  /*!
+   * \brief These are as for VisitExpr_. Devices for expressions and function parameters will be
+   * tracked automatically. Default implementation defers to ExprMutator::VisitExpr_. For
+   * functions the function_nesting count will already include that of \p function_node.
+   */
+  virtual Expr DeviceAwareVisitExpr_(const FunctionNode* function_node);
+  virtual Expr DeviceAwareVisitExpr_(const CallNode* call_node);
+
+  /*!
+   * \brief Visit the first let in a chain of let expressions before any let bindings or final
+   * body has been visited. Default implementation is a no-op.
+   */
+  virtual void PreVisitLetBlock_(const LetNode* let_node);
+
+  /*!
+   * \brief Visit a let-bound expression before the let body has been visited. Devices for the
+   * let-bound variable will be tracked automatically. Default implementation just visits var and
+   * value.
+   */
+  virtual std::pair<Var, Expr> PreVisitLetBinding_(const Var& var, const Expr& value);
+
+  /*!
+   * \brief Visit a let expression after the let-bound value and body have been visited.
+   * Default implementation just returns a reference to the post-visited node.
+   */
+  virtual Expr PostVisitLet_(const LetNode* pre_let_node, const LetNode* post_let_node);
+
+  /*!
+   * \brief Visit the first let in a chain of let expressions after all its bindings and final
+   * body have been visited. Default implementation returns a reference to the let node.
+   */
+  virtual Expr PostVisitLetBlock_(const LetNode* pre_let_node, const LetNode* post_let_node);
+};
+
+}  // namespace transform
+}  // namespace relay
+}  // namespace tvm
+
+#endif  // TVM_RELAY_TRANSFORMS_DEVICE_AWARE_VISITORS_H_
diff --git a/src/relay/transforms/device_domains.cc b/src/relay/transforms/device_domains.cc
new file mode 100644
index 000000000000..15784856edbf
--- /dev/null
+++ b/src/relay/transforms/device_domains.cc
@@ -0,0 +1,482 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+ * \file src/relay/transforms/device_domains.cc
+ * \brief Unification domain for the device planner.
+ */
+
+#include "./device_domains.h"
+
+#include <tvm/relay/attrs/memory.h>
+
+#include "../op/annotation/annotation.h"
+#include "../op/memory/device_copy.h"
+
+namespace tvm {
+namespace relay {
+namespace transform {
+
+namespace {
+
+// Ye olde boost hash mixer.
+constexpr size_t mix(size_t h1, size_t h2) {
+  return h1 ^ (h1 + 0x9e3779b9 + (h2 << 6) + (h2 >> 2));
+}
+
+/*!
+ * \brief As for GetDeviceCopyProps, but for the call to the lowered TIR primitives rather
+ * than the original "device_copy" operator.
+ *
+ * See te_compiler.cc for where this rewriting occurs.
+ */
+DeviceCopyProps GetPrimitiveDeviceCopyProps(const CallNode* call_node) {
+  auto tir_call_attrs = call_node->attrs.as<TIRCallAttrs>();
+  if (tir_call_attrs == nullptr) {
+    return {};
+  }
+  if (tir_call_attrs->metadata.count("source_device") != 1 ||
+      tir_call_attrs->metadata.count("dst_device") != 1) {
+    return {};
+  }
+  ICHECK_EQ(call_node->args.size(), 1) << "device_copy is of arity 1";
+  return {
+      call_node->args[0],
+      static_cast<DLDeviceType>(
+          Downcast<Integer>(tir_call_attrs->metadata["source_device"])->value),
+      static_cast<DLDeviceType>(Downcast<Integer>(tir_call_attrs->metadata["dst_device"])->value)};
+}
+
+}  // namespace
+
+// The following hash and equality helpers give each free first-order domain pointer its own
+// distinct identity.
+
+size_t DeviceDomainHash::operator()(const DeviceDomainPtr& domain) const {
+  if (domain->is_free()) {
+    // Give each free first-order domain its own identity.
+    return static_cast<size_t>(reinterpret_cast<uintptr_t>(domain.get()));
+  } else {
+    size_t h = domain->args_and_result_.size();
+    h = mix(h, std::hash<int>()(static_cast<int>(domain->device_type_)));
+    for (const auto& sub_domain_ptr : domain->args_and_result_) {
+      h = mix(h, DeviceDomainHash()(sub_domain_ptr));
+    }
+    return h;
+  }
+}
+
+bool DeviceDomainEqual::operator()(const DeviceDomainPtr& lhs, const DeviceDomainPtr& rhs) const {
+  if (lhs->args_and_result_.size() != rhs->args_and_result_.size()) {
+    // Mismatched arities are never equal.
+    // (Though we'll never ask to do such a comparison explicitly, the hash map
+    // may do so implicitly due to hash collisions.)
+    return false;
+  }
+  if (lhs->is_free() && rhs->is_free()) {
+    // Compare first-order free domains by their address.
+    return lhs.get() == rhs.get();
+  }
+  if (lhs->args_and_result_.empty()) {
+    // Compare first-order domains by their device type -- free vs bound will compare as false.
+    return lhs->device_type_ == rhs->device_type_;
+  } else {
+    // Compare higher-order domains pointwise.
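+    // E.g. fn(<1>):?a? and fn(<1>):?b? are equal only if ?a? and ?b? are the same free
+    // domain object, since free domains compare by identity.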
+    for (size_t i = 0; i < lhs->args_and_result_.size(); ++i) {
+      if (!(*this)(lhs->args_and_result_[i], rhs->args_and_result_[i])) {
+        return false;
+      }
+    }
+    return true;
+  }
+}
+
+/* static */
+DeviceDomainPtr DeviceDomains::MakeDomain(const Type& type, DLDeviceType device_type) {
+  if (const auto* func_type_node = type.as<FuncTypeNode>()) {
+    std::vector<DeviceDomainPtr> args_and_result;
+    args_and_result.reserve(func_type_node->arg_types.size() + 1);
+    for (const auto& arg_type : func_type_node->arg_types) {
+      args_and_result.emplace_back(MakeDomain(arg_type, kInvalidDeviceType));
+    }
+    args_and_result.emplace_back(MakeDomain(func_type_node->ret_type, device_type));
+    return std::make_shared<DeviceDomain>(std::move(args_and_result));
+  } else {
+    return std::make_shared<DeviceDomain>(device_type);
+  }
+}
+
+DeviceDomainPtr DeviceDomains::Lookup(DeviceDomainPtr domain) {
+  DeviceDomainPtr root = domain;
+  while (true) {
+    auto itr = domain_to_equiv_.find(root);
+    if (itr == domain_to_equiv_.end()) {
+      break;
+    }
+    ICHECK_NE(itr->second, root);
+    root = itr->second;
+    ICHECK_NOTNULL(root);
+  }
+  // Path compression.
+  while (domain != root) {
+    auto itr = domain_to_equiv_.find(domain);
+    ICHECK(itr != domain_to_equiv_.end());
+    domain = itr->second;
+    ICHECK_NOTNULL(domain);
+    itr->second = root;
+  }
+  return root;
+}
+
+DeviceDomainPtr DeviceDomains::Join(const DeviceDomainPtr& lhs, const DeviceDomainPtr& rhs) {
+  // TODO(mbs): Proper diagnostics.
+  ICHECK_EQ(lhs->args_and_result_.size(), rhs->args_and_result_.size())
+      << "Device domains:" << std::endl
+      << ToString(lhs) << std::endl
+      << "and" << std::endl
+      << ToString(rhs) << std::endl
+      << "do not have the same kind and can't be unified.";
+  if (rhs->is_free()) {
+    return lhs;
+  } else if (lhs->is_free()) {
+    return rhs;
+  } else if (lhs->args_and_result_.empty()) {
+    // Must have consistent device types for first order domains.
+    if (lhs->device_type_ != rhs->device_type_) {
+      // TODO(mbs): Proper diagnostics.
+      std::ostringstream os;
+      os << "Inconsistent device types " << lhs->device_type_ << " and " << rhs->device_type_;
+      throw Error(os.str());
+    }
+    return lhs;
+  } else {
+    // Recurse for higher-order.
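+    // E.g. joining fn(?a?):<2> with fn(<1>):?b? yields fn(<1>):<2>.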
+    std::vector<DeviceDomainPtr> args_and_result;
+    args_and_result.reserve(lhs->args_and_result_.size());
+    for (size_t i = 0; i < lhs->args_and_result_.size(); ++i) {
+      args_and_result.emplace_back(Unify(lhs->args_and_result_[i], rhs->args_and_result_[i]));
+    }
+    return MakeDomain(std::move(args_and_result));
+  }
+}
+
+DeviceDomainPtr DeviceDomains::Unify(DeviceDomainPtr lhs, DeviceDomainPtr rhs) {
+  lhs = Lookup(lhs);
+  rhs = Lookup(rhs);
+  auto joined_domain = Join(lhs, rhs);
+  if (!DeviceDomainEqual()(lhs, joined_domain)) {
+    domain_to_equiv_.emplace(lhs, joined_domain);
+  }
+  if (!DeviceDomainEqual()(rhs, joined_domain)) {
+    domain_to_equiv_.emplace(rhs, joined_domain);
+  }
+  return joined_domain;
+}
+
+void DeviceDomains::UnifyCollapsed(const DeviceDomainPtr& lhs, const DeviceDomainPtr& rhs) {
+  if (!lhs->is_higher_order() && rhs->is_higher_order()) {
+    Collapse(lhs, rhs);
+  } else {
+    Unify(lhs, rhs);
+  }
+}
+
+DeviceDomainPtr DeviceDomains::DomainFor(const Expr& expr) {
+  ICHECK(expr.defined());
+  auto itr = expr_to_domain_.find(expr.get());
+  if (itr != expr_to_domain_.end()) {
+    return Lookup(itr->second);
+  }
+  auto domain = Free(expr->checked_type());
+  expr_to_domain_.emplace(expr.get(), domain);
+  return domain;
+}
+
+DeviceDomainPtr DeviceDomains::DomainForCallee(const Call& call) {
+  auto itr = call_to_callee_domain_.find(call.get());
+  if (itr != call_to_callee_domain_.end()) {
+    return Lookup(itr->second);
+  }
+  std::vector<DeviceDomainPtr> args_and_result;
+
+  auto on_device_props = GetOnDeviceProps(call.get());
+  auto device_copy_props = GetDeviceCopyProps(call.get());
+  if (!device_copy_props.body.defined()) {
+    device_copy_props = GetPrimitiveDeviceCopyProps(call.get());
+  }
+
+  if (on_device_props.body.defined()) {
+    // on_device(expr, device_type=<t>, is_fixed=false)
+    // on_device: fn(<t>):?x?
+    //
+    // on_device(expr, device_type=<t>, is_fixed=true)
+    // on_device: fn(<t>):<t>
+    args_and_result.emplace_back(
+        ForDeviceType(on_device_props.body->checked_type(), on_device_props.device_type));
+    if (on_device_props.is_fixed) {
+      args_and_result.emplace_back(args_and_result.front());
+    } else {
+      args_and_result.emplace_back(Free(on_device_props.body->checked_type()));
+    }
+  } else if (device_copy_props.body.defined()) {
+    // device_copy(expr, src_dev_type=<s>, dst_dev_type=<d>)
+    // device_copy: fn(<s>):<d>
+    args_and_result.emplace_back(
+        ForDeviceType(device_copy_props.body->checked_type(), device_copy_props.src_dev_type));
+    args_and_result.emplace_back(
+        ForDeviceType(device_copy_props.body->checked_type(), device_copy_props.dst_dev_type));
+  } else if (call->op == alloc_storage_op) {
+    ICHECK_EQ(call->args.size(), 2U);
+    // alloc_storage(size, alignment, device_type=<t>)
+    // alloc_storage: fn(<cpu>, <cpu>):<t>
+    const auto* attrs = call->attrs.as<AllocStorageAttrs>();
+    args_and_result.emplace_back(cpu_domain_);
+    args_and_result.emplace_back(cpu_domain_);
+    args_and_result.emplace_back(
+        ForDeviceType(call->checked_type(), static_cast<DLDeviceType>(attrs->device_type)));
+  } else if (call->op == alloc_tensor_op) {
+    ICHECK_EQ(call->args.size(), 3U);
+    // alloc_tensor(storage, offset, shape)
+    // alloc_tensor: fn(?x?, <cpu>, <cpu>):?x?
+    auto free_domain = Free(call->checked_type());
+    args_and_result.emplace_back(free_domain);
+    args_and_result.emplace_back(cpu_domain_);
+    args_and_result.emplace_back(cpu_domain_);
+    args_and_result.emplace_back(free_domain);
+  } else if (call->op == shape_func_op) {
+    ICHECK_EQ(call->args.size(), 3U);
+    // shape_func(func, inputs, outputs, is_inputs=[...])
+    // shape_func: fn(..., <cpu>, <cpu>):<cpu>
+    // where ...
+    //   is a free domain appropriate for func's type.
+    args_and_result.emplace_back(Free(call->args[0]->checked_type()));
+    // TODO(mbs): I think this should be on the cpu only when is_input = [false], but
+    // what do we do when we have multiple arguments with different is_input values?
+    args_and_result.emplace_back(cpu_domain_);
+    args_and_result.emplace_back(cpu_domain_);
+    args_and_result.emplace_back(cpu_domain_);
+  } else if (call->op == shape_of_op) {
+    ICHECK_EQ(call->args.size(), 1U);
+    // shape_of(tensor)
+    // shape_of: fn(?x?):<cpu>
+    args_and_result.emplace_back(Free(call->args[0]->checked_type()));
+    args_and_result.emplace_back(cpu_domain_);
+  } else if (call->op == invoke_tvm_op) {
+    ICHECK_EQ(call->args.size(), 3U);
+    // invoke_tvm_op(op, inputs, outputs)
+    // invoke_tvm_op: fn(..., ?x?, ?x?):?x?
+    // where ... is a free domain appropriate for op's type
+    auto free_domain = Free(call->checked_type());
+    args_and_result.emplace_back(Free(call->args[0]->checked_type()));
+    args_and_result.emplace_back(free_domain);
+    args_and_result.emplace_back(free_domain);
+    args_and_result.emplace_back(free_domain);
+  } else if (call->op == reshape_tensor_op) {
+    ICHECK_EQ(call->args.size(), 2U);
+    // reshape_tensor(data, shape)
+    // reshape_tensor: fn(?x?, <cpu>):?x?
+    auto free_domain = Free(call->checked_type());
+    args_and_result.emplace_back(free_domain);
+    args_and_result.emplace_back(cpu_domain_);
+    args_and_result.emplace_back(free_domain);
+  } else if (call->op->IsInstance<OpNode>()) {
+    // <op>(arg1, ..., argn)
+    // <op>: fn(?x?, ..., ?x?):?x?
+    // (all args and result must be first-order).
+    auto free_domain = Free(arb_);
+    for (size_t i = 0; i < call->args.size(); ++i) {
+      args_and_result.emplace_back(free_domain);
+    }
+    args_and_result.emplace_back(free_domain);
+  } else if (call->op->IsInstance<ConstructorNode>()) {
+    // <constructor>(arg1, ..., argn)
+    // <constructor>: fn(?x1?, ..., ?xn?):?xr?
+    // where we force all possibly higher-order ?xi? to be collapsed to the first-order ?xr?.
+    // TODO(mbs): This assumes we've eta-expanded constructors, thus all constructors appear
+    // in callee positions.
+    const auto* func_type_node = call->op->checked_type().as<FuncTypeNode>();
+    ICHECK_NOTNULL(func_type_node);
+    ICHECK_EQ(func_type_node->arg_types.size(), call->args.size());
+    auto result_domain = Free(func_type_node->ret_type);  // first-order
+    for (const auto& arg_type : func_type_node->arg_types) {
+      auto param_domain = Free(arg_type);           // possibly higher-order
+      UnifyCollapsed(result_domain, param_domain);  // collapse if required
+      args_and_result.emplace_back(param_domain);
+    }
+    args_and_result.emplace_back(result_domain);
+  } else {
+    // Defer to normal case where op can be an arbitrary expression.
+    return DomainFor(call->op);
+  }
+  auto domain = MakeDomain(std::move(args_and_result));
+  call_to_callee_domain_.emplace(call.get(), domain);
+  return domain;
+}
+
+void DeviceDomains::UnifyExprExact(const Expr& lhs, const Expr& rhs) {
+  auto lhs_domain = DomainFor(lhs);
+  auto rhs_domain = DomainFor(rhs);
+  try {
+    Unify(lhs_domain, rhs_domain);
+  } catch (const Error& e) {
+    // TODO(mbs): Proper diagnostics.
+ LOG(FATAL) << "Incompatible devices for expressions:" << std::endl + << PrettyPrint(lhs) << std::endl + << "with device:" << std::endl + << ToString(lhs_domain) << "and:" << std::endl + << PrettyPrint(rhs) << std::endl + << "with device:" << std::endl + << ToString(rhs_domain) << std::endl + << e.what(); + } +} + +void DeviceDomains::UnifyExprExact(const Expr& expr, const DeviceDomainPtr& expected_domain) { + auto actual_domain = DomainFor(expr); + try { + Unify(actual_domain, expected_domain); + } catch (const Error& e) { + // TODO(mbs): Proper diagnostics. + LOG(FATAL) << "Incompatible devices for expression:" << std::endl + << PrettyPrint(expr) << std::endl + << "with actual device:" << std::endl + << ToString(actual_domain) << std::endl + << "and expected device:" << std::endl + << ToString(expected_domain) << std::endl + << e.what(); + } +} + +void DeviceDomains::UnifyExprCollapsed(const Expr& expr, const DeviceDomainPtr& expected_domain) { + auto actual_domain = DomainFor(expr); + try { + UnifyCollapsed(actual_domain, expected_domain); + } catch (const Error& e) { + // TODO(mbs): Proper diagnostics. + LOG(FATAL) << "Incompatible devices for expression:" << std::endl + << PrettyPrint(expr) << std::endl + << "with actual device:" << std::endl + << ToString(actual_domain) << std::endl + << "and expected device:" << std::endl + << ToString(expected_domain) << std::endl + << e.what(); + } +} + +bool DeviceDomains::AnyFree(DeviceDomainPtr domain) { + domain = Lookup(domain); + if (domain->is_free()) { + return true; + } + for (const auto& sub_domain : domain->args_and_result_) { + if (AnyFree(sub_domain)) { + return true; + } + } + return false; +} + +void DeviceDomains::Collapse(const DeviceDomainPtr& first_order_domain, + const DeviceDomainPtr& higher_order_domain) { + for (size_t i = 0; i < higher_order_domain->function_arity(); ++i) { + Unify(higher_order_domain->function_param(i), first_order_domain); + } + Unify(higher_order_domain->function_result(), first_order_domain); +} + +void DeviceDomains::SetDefault(DeviceDomainPtr domain, DLDeviceType default_device_type) { + ICHECK_NE(default_device_type, kInvalidDeviceType); + domain = Lookup(domain); + if (domain->is_free()) { + // Will never throw since lhs is free. + Unify(domain, std::make_shared(default_device_type)); + } else if (!domain->args_and_result_.empty()) { + for (const auto& sub_domain : domain->args_and_result_) { + SetDefault(sub_domain, default_device_type); + } + } +} + +void DeviceDomains::SetResultDefaultThenParams(const DeviceDomainPtr& domain, + DLDeviceType default_device_type) { + if (!domain->is_higher_order()) { + SetDefault(domain, default_device_type); + return; + } + DLDeviceType result_device_type = ResultDeviceType(domain); + if (result_device_type == kInvalidDeviceType) { + // If the function result device is still free use the given default. + result_device_type = default_device_type; + } + // Default any remaining free parameters to the function result device. + SetDefault(domain, result_device_type); +} + +std::string DeviceDomains::ToString(DeviceDomainPtr domain) { + domain = Lookup(domain); + std::ostringstream os; + if (domain->is_free()) { + // first-order free + os << "?" 
+       << static_cast<size_t>(reinterpret_cast<uintptr_t>(domain.get())) << "?";
+  } else if (domain->args_and_result_.empty()) {
+    // first-order bound
+    os << "<" << domain->device_type_ << ">";
+  } else {
+    // higher-order
+    os << "fn(";
+    for (size_t i = 0; i + 1 < domain->args_and_result_.size(); ++i) {
+      if (i > 0) {
+        os << ",";
+      }
+      os << ToString(domain->args_and_result_[i]);
+    }
+    os << "):" << ToString(domain->args_and_result_.back());
+  }
+  return os.str();
+}
+
+std::string DeviceDomains::ToString() {
+  std::ostringstream os;
+  for (const auto& pair : expr_to_domain_) {
+    os << "expression:" << std::endl
+       << PrettyPrint(GetRef<Expr>(pair.first)) << std::endl
+       << "domain:" << std::endl
+       << ToString(pair.second) << std::endl
+       << std::endl;
+  }
+  for (const auto& pair : call_to_callee_domain_) {
+    os << "call:" << std::endl
+       << PrettyPrint(GetRef<Call>(pair.first)) << std::endl
+       << "callee domain:" << std::endl
+       << ToString(pair.second) << std::endl
+       << std::endl;
+  }
+  return os.str();
+}
+
+DeviceDomainPtr DeviceDomains::ResultDomain(DeviceDomainPtr domain) {
+  domain = Lookup(domain);
+  while (!domain->args_and_result_.empty()) {
+    domain = Lookup(domain->args_and_result_.back());
+  }
+  return domain;
+}
+
+}  // namespace transform
+}  // namespace relay
+}  // namespace tvm
diff --git a/src/relay/transforms/device_domains.h b/src/relay/transforms/device_domains.h
new file mode 100644
index 000000000000..a29370a0e807
--- /dev/null
+++ b/src/relay/transforms/device_domains.h
@@ -0,0 +1,304 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+ * \file src/relay/transforms/device_domains.h
+ * \brief Unification domain for the device planner.
+ */
+
+#ifndef TVM_RELAY_TRANSFORMS_DEVICE_DOMAINS_H_
+#define TVM_RELAY_TRANSFORMS_DEVICE_DOMAINS_H_
+
+#include <dlpack/dlpack.h>
+#include <tvm/relay/expr.h>
+#include <tvm/relay/op.h>
+#include <tvm/relay/type.h>
+
+#include <memory>
+#include <string>
+#include <unordered_map>
+#include <utility>
+#include <vector>
+
+namespace tvm {
+namespace relay {
+namespace transform {
+
+class DeviceDomain;
+using DeviceDomainPtr = std::shared_ptr<DeviceDomain>;
+
+/*!
+ * \brief Represents the domain over which we collect equality constraints.
+ *
+ * \code
+ *   D ::= ?x?                  -- first order, free
+ *       | <d>                  -- first order, bound
+ *       | fn(D1, ..., Dn):Dr   -- higher order
+ * \endcode
+ *
+ * We require a function value to be on the same device as its result. To support that we need
+ * a notion of the 'result domain' of a domain:
+ * \code
+ *   result_domain(?x?) = ?x?
+ *   result_domain(<d>) = <d>
+ *   result_domain(fn(D1, ..., Dn):Dr) = result_domain(Dr)
+ * \endcode
+ */
+class DeviceDomain {
+ public:
+  /*!
+   * \brief Constructs a first-order domain of \p device_type, which may be
+   * \p kInvalidDeviceType to indicate the domain is free.
+   */
+  explicit DeviceDomain(DLDeviceType device_type) : device_type_(device_type) {}
+
+  /*!
+   * \brief Constructs a higher-order domain, where \p args_and_result contain the
+   * function argument and result domains in order.
+   */
+  explicit DeviceDomain(std::vector<DeviceDomainPtr> args_and_result)
+      : device_type_(kInvalidDeviceType), args_and_result_(std::move(args_and_result)) {}
+
+  /*! \brief Returns true if domain is first-order and free. */
+  bool is_free() const { return device_type_ == kInvalidDeviceType && args_and_result_.empty(); }
+
+  /*! \brief Returns true if domain is higher-order. */
+  bool is_higher_order() const { return !args_and_result_.empty(); }
+
+  DLDeviceType first_order_device_type() const {
+    ICHECK(args_and_result_.empty());
+    return device_type_;
+  }
+
+  size_t function_arity() const {
+    ICHECK(!args_and_result_.empty());
+    return args_and_result_.size() - 1UL;
+  }
+
+  DeviceDomainPtr function_param(size_t i) const {
+    ICHECK(!args_and_result_.empty());
+    ICHECK_LT(i + 1, args_and_result_.size());
+    return args_and_result_[i];
+  }
+
+  DeviceDomainPtr function_result() const {
+    ICHECK(!args_and_result_.empty());
+    return args_and_result_.back();
+  }
+
+ private:
+  /*!
+   * \brief If this is a function domain then always kInvalidDeviceType. Otherwise will be
+   * kInvalidDeviceType if the domain is still free, or the specific concrete device if the
+   * domain is bound.
+   */
+  const DLDeviceType device_type_;
+
+  /*!
+   * \brief If this is a function domain then the sub-domains for each of the function's
+   * arguments, and the domain for its result. Otherwise empty.
+   */
+  const std::vector<DeviceDomainPtr> args_and_result_;
+
+  friend struct DeviceDomainHash;
+  friend struct DeviceDomainEqual;
+  friend class DeviceDomains;
+};
+
+// The following hash and equality helpers give each free first-order domain pointer its own
+// distinct identity.
+struct DeviceDomainHash {
+  size_t operator()(const DeviceDomainPtr& domain) const;
+};
+
+struct DeviceDomainEqual {
+ public:
+  bool operator()(const DeviceDomainPtr& lhs, const DeviceDomainPtr& rhs) const;
+};
+
+/*!
+ * \brief Tracks the device domains for a set of expressions w.r.t. an equivalence relation
+ * built up by calls to \p Unify.
+ */
+class DeviceDomains {
+ public:
+  DeviceDomains() = default;
+
+  /*!
+   * \brief Returns a domain appropriate for \p type whose result domain is bound
+   * to \p device_type. If \p device_type is \p kInvalidDeviceType then the entire domain
+   * will be free.
+   */
+  static DeviceDomainPtr MakeDomain(const Type& type, DLDeviceType device_type);
+
+  /*!
+   * \brief Returns a higher-order domain with \p args_and_result.
+   */
+  static DeviceDomainPtr MakeDomain(std::vector<DeviceDomainPtr> args_and_result) {
+    return std::make_shared<DeviceDomain>(std::move(args_and_result));
+  }
+
+  /*! \brief Returns a domain appropriate for \p type whose result device type is \p device_type. */
+  static DeviceDomainPtr ForDeviceType(const Type& type, DLDeviceType device_type) {
+    ICHECK_NE(device_type, kInvalidDeviceType);
+    return MakeDomain(type, device_type);
+  }
+
+  /*! \brief Returns a free domain appropriate for \p type. */
+  static DeviceDomainPtr Free(const Type& type) { return MakeDomain(type, kInvalidDeviceType); }
+
+  /*! \brief Returns the domain representing the equivalence class containing \p domain. */
+  DeviceDomainPtr Lookup(DeviceDomainPtr domain);
+
+  /*!
+   * \brief Returns the domain accounting for all bound devices in \p lhs and \p rhs.
+   *
+   * Throws \p Error on failure.
+   */
+  DeviceDomainPtr Join(const DeviceDomainPtr& lhs, const DeviceDomainPtr& rhs);
+
+  /*!
+   * \brief Unifies \p lhs and \p rhs, returning the most-bound of the two.
+   * Fails if \p lhs and \p rhs disagree on bound device type.
+   *
+   * Throws \p Error on failure.
+   */
+  // TODO(mbs): I don't think we need an occurs check since the program is well-typed, but
+  // given we have refs to functions I'm prepared to be surprised.
+  DeviceDomainPtr Unify(DeviceDomainPtr lhs, DeviceDomainPtr rhs);
+
+  /*!
+   * \brief Unifies \p lhs and \p rhs. If \p lhs is first-order and \p rhs is higher-order,
+   * require all arguments and result of \p rhs to unify with \p lhs. Otherwise same as
+   * \p Unify.
+   *
+   * Throws \p Error on failure.
+   */
+  void UnifyCollapsed(const DeviceDomainPtr& lhs, const DeviceDomainPtr& rhs);
+
+  /*! \brief Returns true if a domain is known for \p expr. */
+  bool contains(const Expr& expr) const { return expr_to_domain_.count(expr.get()); }
+
+  /*! \brief Returns the domain representing \p expr. */
+  DeviceDomainPtr DomainFor(const Expr& expr);
+
+  /*!
+   * \brief Returns the domain representing the callee (i.e. 'op') in \p call expression. If the
+   * callee is a primitive or special operation we handle it specially. Otherwise defers to
+   * \p DomainFor(call->op).
+   *
+   * This special handling is needed:
+   * - To handle the "on_device" and "device_copy" ops which constrain devices to the given
+   *   devices.
+   * - To handle some special ops which constrain devices to the CPU.
+   * - To allow the same primitive to be called on different devices at different call sites.
+   *   Since each call to the op can have a different domain we index the ops by the call
+   *   expression rather than the op itself.
+   */
+  DeviceDomainPtr DomainForCallee(const Call& call);
+
+  /*! \brief Unifies the domains for expressions \p lhs and \p rhs. */
+  void UnifyExprExact(const Expr& lhs, const Expr& rhs);
+
+  /*!
+   * \brief Unifies the domain for \p expr with \p expected_domain.
+   */
+  void UnifyExprExact(const Expr& expr, const DeviceDomainPtr& expected_domain);
+
+  /*!
+   * \brief Unifies the domain for \p expr with \p expected_domain.
+   * If \p expected_domain is higher-order but \p expr is first-order, require all arguments
+   * and the result of \p expected_domain to have the same domain as for \p expr.
+   */
+  void UnifyExprCollapsed(const Expr& expr, const DeviceDomainPtr& expected_domain);
+
+  /*! \brief Returns true if \p domain contains any free sub-domains. */
+  bool AnyFree(DeviceDomainPtr domain);
+
+  /*!
+   * \brief Force all domains in \p higher_order_domain to unify with \p first_order_domain.
+   * This can be used to handle functions within tuples, references and ADTs since we don't
+   * attempt to track anything beyond 'the device' for expressions of those first-order types.
+   *
+   * Throws \p Error on failure.
+   */
+  void Collapse(const DeviceDomainPtr& first_order_domain,
+                const DeviceDomainPtr& higher_order_domain);
+
+  /*! \brief Force all free domains in \p domain to default to \p default_device_type. */
+  void SetDefault(DeviceDomainPtr domain, DLDeviceType default_device_type);
+
+  /*!
+   * \brief If \p domain is higher-order and its result domain is free, force it to
+   * \p default_device_type. Then force any remaining free domains to the result domain
+   * (freshly defaulted or original). If \p domain is first-order same as \p SetDefault.
+   */
+  void SetResultDefaultThenParams(const DeviceDomainPtr& domain, DLDeviceType default_device_type);
+
+  /*! \brief Returns one-line description of \p domain for debugging. */
+  std::string ToString(DeviceDomainPtr domain);
+
+  /*!
+   * \brief Returns description of the entire system of constraints, for debugging.
+   */
+  std::string ToString();
+
+  /*!
+   * \brief Returns the result domain for \p domain (see defn in DeviceDomain comment).
+   */
+  DeviceDomainPtr ResultDomain(DeviceDomainPtr domain);
+
+  /*!
+   * \brief Returns the result (possibly free) device type for \p domain (see defn in DeviceDomain
+   * comment).
+   */
+  DLDeviceType ResultDeviceType(const DeviceDomainPtr& domain) {
+    return ResultDomain(domain)->first_order_device_type();
+  }
+
+ private:
+  /*! \brief Intrinsics we need to handle specially. */
+  const Op& alloc_storage_op = Op::Get("memory.alloc_storage");
+  const Op& alloc_tensor_op = Op::Get("memory.alloc_tensor");
+  const Op& shape_of_op = Op::Get("vm.shape_of");
+  const Op& invoke_tvm_op = Op::Get("vm.invoke_tvm_op");
+  const Op& shape_func_op = Op::Get("vm.shape_func");
+  const Op& reshape_tensor_op = Op::Get("vm.reshape_tensor");
+  /*! \brief The CPU device type for special operators such as dynamic shape functions. */
+  const DLDeviceType cpu_device_type_ = kDLCPU;
+  /*! \brief Placeholder for any first-order type. */
+  Type arb_ = TupleType();
+  /*! \brief The domain for first-order expressions on the CPU. */
+  DeviceDomainPtr cpu_domain_ = ForDeviceType(arb_, cpu_device_type_);
+
+  /*! \brief Maps expressions to their domains as determined during analysis. */
+  std::unordered_map<const ExprNode*, DeviceDomainPtr> expr_to_domain_;
+
+  /*!
+   * \brief Maps call expressions to the domains for their callee where the callee is a primitive.
+   */
+  std::unordered_map<const CallNode*, DeviceDomainPtr> call_to_callee_domain_;
+
+  /*! \brief Maps device domains to their equivalent domains as determined during unification. */
+  std::unordered_map<DeviceDomainPtr, DeviceDomainPtr, DeviceDomainHash, DeviceDomainEqual>
+      domain_to_equiv_;
+};
+
+}  // namespace transform
+}  // namespace relay
+}  // namespace tvm
+
+#endif  // TVM_RELAY_TRANSFORMS_DEVICE_DOMAINS_H_
diff --git a/src/relay/transforms/device_planner.cc b/src/relay/transforms/device_planner.cc
new file mode 100644
index 000000000000..35bf406263e4
--- /dev/null
+++ b/src/relay/transforms/device_planner.cc
@@ -0,0 +1,1123 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+ * \file src/relay/transforms/device_planner.cc
+ * \brief Determines a unique device to hold the result of every Relay sub-expression.
+ *
+ * We say a Relay expression E is 'on device D' if the result of executing E is stored on D.
+ * Currently we only track the 'device_type' of D and not its 'device id'. We do not track the
+ * specific target associated with D (this is recovered independently via a TargetMap), and we
+ * do not track the storage scope within D (this is yet to be implemented).
+ *
+ * Note that 'stored on device D' is almost but not quite the same as 'executes on device D',
+ * see below.
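+ *
+ * For example (a sketch only -- device types are shown symbolically here, whereas the real
+ * annotations carry integer device type attributes):
+ * \code
+ *   def @main(%x, %y) {
+ *     %0 = on_device(add(%x, %y), device_type=<GPU>);  // constrain add's result to the GPU
+ *     %1 = device_copy(%0, src_dev_type=<GPU>, dst_dev_type=<CPU>);
+ *     negative(%1)  // negative then runs on the CPU
+ *   }
+ * \endcode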
+ *
+ * This pass assumes the module already contains some "on_device" and/or "device_copy" CallNodes:
+ *  - "device_copy" CallNodes (with a \p DeviceCopyAttrs attribute) specify a 'src_dev_type' and
+ *    'dst_dev_type' device type, which constrain the argument and context of the call
+ *    respectively. It is ok if source and destination devices are the same; such no-op copies
+ *    will be removed after accounting for the device preference.
+ *  - "on_device" CallNodes (with an \p OnDeviceAttrs attribute) specify a 'device_type', which
+ *    constrains the argument of the call, but (usually, see below) leaves the context
+ *    unconstrained. These are called 'annotations' in the rest of the code, have no operational
+ *    significance by themselves, but may trigger the insertion of a new "device_copy".
+ *  - In two situations the result of an "on_device" CallNode may also be constrained to the
+ *    given device:
+ *     - The "on_device" call occurs at the top-level of a function body, or occurs as an
+ *       immediately let-bound expression. In this situation the extra degree of freedom in
+ *       the function result and let-binding leads to surprising device copies, so we simply
+ *       force the function result or let-bound variable to the given device.
+ *     - The \p OnDeviceAttrs has an \p is_fixed field of \p true, which indicates we inserted
+ *       it ourselves during an earlier invocation of this pass. This helps make this pass
+ *       idempotent.
+ *
+ * We proceed in four phases:
+ *
+ * Phase 0
+ * -------
+ * We rewrite the programs to handle some special cases:
+ *  - "on_device" calls at the top-level of function bodies or immediately let-bound are rewritten
+ *    to have \code is_fixed=true \endcode.
+ *  - We wish to treat \code on_device(expr, device_type=d).0 \endcode as if it were written
+ *    \code on_device(expr.0, device_type=d) \endcode. I.e. we prefer to copy the projection from
+ *    the tuple rather than project from a copy of the tuple. We'll do this by rewriting.
+ *
+ * Phase 1
+ * -------
+ * We flow constraints from the "on_device" and "device_copy" calls (and some special ops, see
+ * below) to all other Relay sub-expressions. (For idempotence we also respect any existing
+ * "param_device_types" and "result_device_type" function attributes we introduce below.)
+ *
+ * For a primitive such as \code add(e1, e2) \endcode all arguments and results must be on the
+ * same device. However each call site can use a different device. In other words primitives are
+ * 'device polymorphic' since we compile and execute them for each required device.
+ *
+ * For most Relay expressions the device for the overall expression is the same as the device
+ * for its sub-expressions. E.g. each field of a tuple must be on the same device as the tuple
+ * itself, the condition and arms of an \p if must all be on the same device as the overall if,
+ * and so on.
+ *
+ * Some special ops (or 'dialects') are handled:
+ *  - Relay supports computing the shape of tensors and operators at runtime using "shape_of",
+ *    "shape_func", and "reshape_tensor". Shapes must only be held on the CPU, but the tensors
+ *    they describe may reside on any device.
+ *  - Explicit memory allocation is done using the "alloc_storage" and "alloc_tensor" ops. Again
+ *    shapes reside on the CPU, but the allocated tensors may reside on any device.
+ *
+ * Two Relay expressions have special handling:
+ *  - For \code let x = e1; e2 \endcode the result of \p e2 must be on the same device as the
+ *    overall let.
+ *    However the result of \p e1 may be on a different device.
+ *  - For a function \code fn(x, y) { body } \endcode the result of the function must be on the
+ *    same device as \p body. However parameters \p x and \p y may be on different devices, even
+ *    different from each other. Every call to the function must use the same choice of parameter
+ *    and result devices -- there is no 'device polymorphism' for Relay functions.
+ *
+ * Phase 2
+ * -------
+ * After flowing constraints we apply some defaulting heuristics (using a global default device)
+ * to fix the device for any as-yet unconstrained sub-expressions.
+ *  - Unconstrained function result devices default to the global default device.
+ *  - Unconstrained function parameter devices default to the device for the function result.
+ *  - Unconstrained let-bound expression devices default to the device for the overall let.
+ * TODO(mbs): I may have over-innovated here and we simply want to bind all free domains to
+ * the global default device. Worth a design doc with motivating examples I think.
+ *
+ * Phase 3
+ * -------
+ * Finally, the result of this analysis is reified into the result as:
+ *  - Additional "param_device_types" (an Array<Integer>) and "result_device_type" (Integer)
+ *    attributes for every function (both top-level and local). These describe the devices for
+ *    the function's parameters and the result.
+ *  - Additional "device_copy" CallNodes where a copy is required in order to respect the
+ *    intent of the original "on_device" CallNodes.
+ *  - Additional "on_device" CallNodes where the device type of an expression does not match
+ *    that of the lexically enclosing "on_device" CallNode or function attribute. In practice
+ *    this means "on_device" CallNodes may appear in two places:
+ *     - On a let-bound expression if its device differs from the overall let expression.
+ *     - On a call argument if its device differs from the call result. In particular, the
+ *       argument to a "device_copy" call will always be wrapped in an "on_device". (That may
+ *       seem pedantic but simplifies downstream handling.)
+ *    However since we make it easy to track devices for variables we never wrap an "on_device"
+ *    around a var or global var. These uses of "on_device" imply both the argument and result are
+ *    on the same device. We signal this by setting the 'is_fixed' OnDeviceAttrs field to true,
+ *    which helps make this pass idempotent.
+ *
+ * Helper visitors (in device_aware_visitors.h) can be used by downstream transforms to recover
+ * the device for any expression for their own use, e.g. during memory planning. All downstream
+ * passes must preserve the lexical scoping of the "on_device" CallNodes. E.g. conversion
+ * to ANF must respect the lexical scoping convention:
+ * \code
+ *   f(on_device(g(h(a, b), c), device_type=CPU))
+ *   ==>
+ *   let %x0 = on_device(h(a, b), device_type=CPU)
+ *   let %x1 = on_device(g(%x0), device_type=CPU)
+ *   f(on_device(%x1, device_type=CPU))
+ * \endcode
+ *
+ * This pass can be run before FuseOps so that it can use device-specific fusion rules.
+ *
+ * 'Stored on' vs 'Executes on'
+ * ----------------------------
+ * Obviously for a primitive call \code add(x, y) \endcode we can execute the primitive on the
+ * same device as will hold its result. Thus 'executes on' is the same as 'stored on' for
+ * primitives.
+ *
+ * But what about for arbitrary Relay expressions?
+ *
+ * Phase 2
+ * -------
+ * After flowing constraints we apply some defaulting heuristics (using a global default device)
+ * to fix the device for any as-yet unconstrained sub-expressions.
+ *  - Unconstrained function result devices default to the global default device.
+ *  - Unconstrained function parameter devices default to the device for the function result.
+ *  - Unconstrained let-bound expression devices default to the device for the overall let.
+ * TODO(mbs): I may have over-innovated here and we simply want to bind all free domains to
+ * the global default device. Worth a design doc with motivating examples I think.
+ *
+ * Phase 3
+ * -------
+ * Finally, the result of this analysis is reified back into the program as:
+ *  - Additional "param_device_types" (an \p Array of \p Integer) and "result_device_type"
+ *    (\p Integer) attributes for every function (both top-level and local). These describe the
+ *    devices for the function's parameters and the result.
+ *  - Additional "device_copy" CallNodes where a copy is required in order to respect the
+ *    intent of the original "on_device" CallNodes.
+ *  - Additional "on_device" CallNodes where the device type of an expression does not match
+ *    that of the lexically enclosing "on_device" CallNode or function attribute. In practice
+ *    this means "on_device" CallNodes may appear in two places:
+ *     - On a let-bound expression if its device differs from the overall let expression.
+ *     - On a call argument if its device differs from the call result. In particular, the
+ *       argument to a "device_copy" call will always be wrapped in an "on_device". (That may
+ *       seem pedantic but simplifies downstream handling.)
+ *    However, since we make it easy to track devices for variables, we never wrap an
+ *    "on_device" around a var or global var. These uses of "on_device" imply both the argument
+ *    and result are on the same device. We signal this by setting the 'is_fixed' OnDeviceAttrs
+ *    field to true, which helps make this pass idempotent.
+ *
+ * Helper visitors (in device_aware_visitors.h) can be used by downstream transforms to recover
+ * the device for any expression for their own use, e.g. during memory planning. All downstream
+ * passes must preserve the lexical scoping of the "on_device" CallNodes. E.g. conversion
+ * to ANF must respect the lexical scoping convention:
+ * \code
+ *   f(on_device(g(h(a, b), c), device_type=CPU))
+ *   ==>
+ *   let %x0 = on_device(h(a, b), device_type=CPU);
+ *   let %x1 = on_device(g(%x0), device_type=CPU);
+ *   f(on_device(%x1, device_type=CPU))
+ * \endcode
+ *
+ * This pass can be run before FuseOps so that it can use device-specific fusion rules.
+ *
+ * 'Stored on' vs 'Executes on'
+ * ----------------------------
+ * Obviously for a primitive call \code add(x, y) \endcode we can execute the primitive on the
+ * same device as will hold its result. Thus 'executes on' is the same as 'stored on' for
+ * primitives.
+ *
+ * But what about for arbitrary Relay expressions? Most backends (interpreter, graph, VM) are
+ * implicitly executed on the 'host' CPU, with only primitive evaluation handed off to specific
+ * devices, thus the notion of 'executes on' is moot. AOT backends on the other hand need to
+ * know exactly which device (possibly one of a number of available 'CPU'-like devices) is
+ * responsible for execution. Currently that's handled independently by the \p AnnotateTargets
+ * pass, but we'd like to fold that into device planning here to ensure everything is
+ * consistent.
+ *
+ * Obviously since tensors are passed-by-pointer it's quite possible to execute a Relay
+ * expression (e.g. an if expression) on one device even though the tensor data resides on
+ * another. But for AOT that flexibility seems excessive. So we'd like to just take 'executes
+ * on' to be 'stored on' exactly. In particular, for a Relay function, we'd like to be able to
+ * just compile the function body for the function's result device.
+ *
+ * This works after conversion to ANF provided the compilation for a let expression is prepared
+ * to make a cross-device call. However we leave it to a downstream transformation to
+ * heuristically minimize cross-device calls by moving device copies out of functions. E.g.:
+ * \code
+ *   def @f() {  // execute on CPU
+ *     let x = on_device(...GPU computation..., device_type=GPU);
+ *     device_copy(x, src_dev_type=GPU, dst_dev_type=CPU)
+ *   }
+ *   def @main() {
+ *     ... call @f() on CPU ...
+ *   }
+ * \endcode
+ * could be rewritten to:
+ * \code
+ *   def @f() {  // execute on GPU
+ *     ...GPU computation...
+ *   }
+ *   def @main() {
+ *     let x = device_copy(@f(), src_dev_type=GPU, dst_dev_type=CPU);
+ *     ... use x on CPU ...
+ *   }
+ * \endcode
+ *
+ * Higher-order shenanigans
+ * ------------------------
+ * Relay is a 'mostly' higher-order language -- we can let-bind functions, pass functions
+ * as arguments (even anonymous functions), return functions, evaluate conditional expressions
+ * over functions, and so on. We handle this during constraint solving using the domain:
+ * \code
+ *   D ::= <device_type>   -- first-order
+ *       | fn(D,...,D):D   -- higher-order
+ * \endcode
+ * In this way we can determine the device for all function parameters and results. E.g. for
+ * \code
+ *   let f = fn(x, y) { ... }
+ *   let g = fn(f, z) { f(z, z) }
+ *   g(f, on_device(..., device_type=CPU))
+ * \endcode
+ * the parameters \p x and \p y will be on the CPU.
+ *
+ * But now look closely at the call \code e1(e2, e3) \endcode. We know \p e1 must evaluate to a
+ * function. Our analysis must guarantee that the function's parameter and result devices are
+ * consistent with \p e2, \p e3, and the context of the call. But:
+ *  - Which device holds the closure result of evaluating \p e1 ?
+ *  - If \p e2 is of function type, what does that mean when we say every function parameter
+ *    is on a device?
+ *  - If \p e1 returns a function, what does that mean when we say every function result is
+ *    on a device?
+ *
+ * Since higher-order aspects are later compiled away (by 'defunctionalization'
+ * aka 'firstification') we'd prefer not to have to answer any of those questions. In
+ * particular, we really don't want our domain \p D to allow for yet another device for the
+ * function closure. So we'll just force the 'device for a function' to be the same as the
+ * device for the function's result, using the notion of the 'result domain' for a domain:
+ * \code
+ *   result_domain(<device_type>) = <device_type>
+ *   result_domain(fn(D1,...,Dn):Dr) = result_domain(Dr)
+ * \endcode
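+ *
+ * For instance (an illustrative sketch, writing <gpu> for a first-order GPU device type), for
+ * a curried function returning a function:
+ * \code
+ *   result_domain(fn(D1):fn(D2):<gpu>)
+ *     = result_domain(fn(D2):<gpu>)
+ *     = result_domain(<gpu>)
+ *     = <gpu>
+ * \endcode
+ * i.e. the closure is considered to be 'on the GPU' because its ultimate result is.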
+ *
+ * Similarly the domain does not have entries for tuples, references, or ADTs. Whenever the
+ * analysis encounters a function inside one of those it simply forces all argument and result
+ * devices for the function to match the device for the first-order expression. For example,
+ * if the tuple \code (fn(x, y) { ... }, 3) \endcode is on the GPU then the inner function
+ * parameters and result must similarly be on the GPU.
+ *
+ *  -------
+ *  | AOR |  This pass supports all of Relay.
+ *  -------
+ *     ^
+ *     |
+ *     `-- Mark's stamp of completeness :-)
+ *
+ * TODO(mbs):
+ *  * Though on_device is the identity for all types we can't wrap it around functions/
+ *    constructors taking type args (or at least not without changing type_infer.cc to see
+ *    through them). This is not currently handled generally.
+ *  * Proper diagnostics for unification failure using spans.
+ *  * Make sure the pass is idempotent even after FuseOps etc.
+ *  * Support application of constructors properly. Are they device polymorphic?
+ *  * Replace DLDeviceType with TargetDevice, and unify 'target annotation' with
+ *    'device planning'.
+ *  * Support running the pass post FuseOps (so need to understand primitive functions, both
+ *    outlined and inlined) and post the VM transforms (probably need to support more intrinsic
+ *    forms?).
+ *  * Don't hardcode the 'CPU' device for shape funcs etc, and distinguish between the default
+ *    device for primitives vs the default device for the rest of Relay.
+ *  * We'll probably need some support for partial 'device polymorphism' for functions once we
+ *    incorporate targets and memory scopes into the domain. For example it's ok for the
+ *    function body to be executed on different device ids provided they have the same target
+ *    and memory scope.
+ *  * Might be simpler to just let every type have a device annotation rather than work in
+ *    a separate domain?
+ *  * Switch to expr.CopyWith(...) form once implemented to avoid unnecessary copies.
+ *  * The original device_annotation.cc RewriteAnnotatedOps removed all "on_device" calls
+ *    in tuples at the top level of function bodies or main expression, irrespective of the
+ *    "on_device" body. What's up with that?
+ */
+
+#include <tvm/ir/transform.h>
+#include <tvm/relay/analysis.h>
+#include <tvm/relay/attrs/annotation.h>
+#include <tvm/relay/attrs/device_copy.h>
+#include <tvm/relay/attrs/memory.h>
+#include <tvm/relay/expr_functor.h>
+#include <tvm/relay/op.h>
+#include <tvm/relay/pattern_functor.h>
+#include <tvm/relay/transform.h>
+#include <tvm/relay/type.h>
+#include <tvm/runtime/c_runtime_api.h>
+#include <tvm/runtime/object.h>
+
+#include <unordered_map>
+
+#include "../op/annotation/annotation.h"
+#include "../op/memory/device_copy.h"
+#include "./device_domains.h"
+
+namespace tvm {
+namespace relay {
+namespace transform {
+
+namespace {
+
+/******
+******* Phase 0
+*******/
+
+/*!
+ * \brief Rewrites "on_device" calls to handle some special cases.
+ *
+ * \code
+ *   let %x = on_device(e, device_type=d)
+ *   ==> let %x = on_device(e, device_type=d, is_fixed=True)
+ *
+ *   fn(%x) { on_device(e, device_type=d) }
+ *   ==> fn(%x) { on_device(e, device_type=d, is_fixed=True) }
+ *
+ *   on_device(e).0
+ *   ==> on_device(e.0)
+ * \endcode
+ */
+class RewriteOnDevices : public ExprMutator {
+ public:
+  RewriteOnDevices() = default;
+
+ private:
+  Expr VisitExpr_(const TupleGetItemNode* tuple_get_item_node) final {
+    Expr tuple = VisitExpr(tuple_get_item_node->tuple);
+    // TODO(mbs): Avoid copy.
+    Expr tuple_get_item =
+        TupleGetItem(tuple, tuple_get_item_node->index, tuple_get_item_node->span);
+    auto props = GetOnDeviceProps(tuple);
+    if (props.body.defined() && !props.is_fixed) {
+      VLOG(1) << "wrapping tuple get item:" << std::endl
+              << PrettyPrint(GetRef<TupleGetItem>(tuple_get_item_node)) << std::endl
+              << "with \"on_device\" for device " << props.device_type;
+      return OnDevice(tuple_get_item, props.device_type, /*is_fixed=*/false);
+    } else {
+      return tuple_get_item;
+    }
+  }
+
+  Expr VisitExpr_(const LetNode* let_node) final {
+    auto expr = GetRef<Expr>(let_node);
+    std::vector<std::tuple<Var, Expr, Span>> bindings;
+    while (const auto* inner_let_node = expr.as<LetNode>()) {
+      Expr inner_let = GetRef<Let>(inner_let_node);
+      Expr value = VisitExpr(inner_let_node->value);
+      auto props = GetOnDeviceProps(value);
+      if (props.body.defined() && !props.is_fixed) {
+        VLOG(1) << "revising let-bound expression of let:" << std::endl
+                << PrettyPrint(expr) << std::endl
+                << "to be fixed to device " << props.device_type;
+        value = OnDevice(props.body, props.device_type, /*is_fixed=*/true);
+      }
+      bindings.emplace_back(inner_let_node->var, value, inner_let_node->span);
+      expr = inner_let_node->body;
+    }
+    expr = VisitExpr(expr);
+    // TODO(mbs): Avoid copy.
+    for (auto itr = bindings.rbegin(); itr != bindings.rend(); ++itr) {
+      expr = Let(/*var=*/std::get<0>(*itr), /*value=*/std::get<1>(*itr), expr,
+                 /*span=*/std::get<2>(*itr));
+    }
+    return expr;
+  }
+
+  Expr VisitExpr_(const FunctionNode* function_node) final {
+    Expr body = VisitExpr(function_node->body);
+    auto props = GetOnDeviceProps(body);
+    if (props.body.defined() && !props.is_fixed) {
+      VLOG(1) << "revising body of function:" << std::endl
+              << PrettyPrint(GetRef<Function>(function_node)) << std::endl
+              << "to be fixed to device " << props.device_type;
+      body = OnDevice(props.body, props.device_type, /*is_fixed=*/true);
+    }
+    // TODO(mbs): Avoid copy.
+    return Function(function_node->params, body, function_node->ret_type,
+                    function_node->type_params, function_node->attrs, function_node->span);
+  }
+};
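+
+// For example (an illustrative sketch), RewriteOnDevices transforms:
+//   def @main(%x) { on_device(negative(%x), device_type=1) }
+// into:
+//   def @main(%x) { on_device(negative(%x), device_type=1, is_fixed=True) }
+// since the annotation appears at the top-level of a function body.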
+
+/******
+******* Phase 1
+*******/
+
+/*!
+ * \brief Collects the system of device constraints for all sub-expressions in a module.
+ * It is possible some devices remain free and will need to be defaulted by \p DeviceDefaulter.
+ *
+ * E.g. from \code add(%x, %y) \endcode we know \p %x and \p %y must be on the same device.
+ * Later, from \code on_device(%x, device_type=d) \endcode we know \p %x must be on device
+ * \p d, and thus so must \p %y.
+ *
+ * Constraints can flow in interesting ways. E.g. in:
+ * \code
+ *   let %f = fn(%x, %y) { add(%x, on_device(%y, device_type=d)) }
+ *   let %g = fn(%f, %x, %y) { %f(%x, %y) }
+ *   %g(%f, %a, %b)
+ * \endcode
+ * we discover \p %b must be on device \p d.
+ */
+class DeviceAnalyzer : public ExprVisitor {
+ public:
+  explicit DeviceAnalyzer(IRModule mod)
+      : mod_(std::move(mod)), domains_(std::make_unique<DeviceDomains>()) {}
+
+  /*!
+   * \brief Returns the expression-to-device-domain map for all expressions in all the global
+   * function definitions in the module. Expressions may have free domains; these will be
+   * resolved by \p DeviceDefaulter below.
+ */ + std::unique_ptr Analyze() { + VLOG_CONTEXT << "DeviceAnalyzer"; + for (const auto& pair : mod_->functions) { + VLOG(1) << "collecting constraints for '" << PrettyPrint(pair.first) << "'"; + domains_->UnifyExprExact(pair.first, pair.second); + VisitExpr(pair.second); + } + return std::move(domains_); + } + + private: + void VisitExpr_(const CallNode* call_node) final { + auto call = GetRef(call_node); + + // Find the higher-order domain for the callee. See DomainForCallee for the special rules + // for primitives. + VisitExpr(call_node->op); + auto func_domain = domains_->DomainForCallee(call); // higher-order + + // Build the domain for the function implied by its arguments and call context. + ICHECK_EQ(func_domain->function_arity(), call_node->args.size()); + std::vector args_and_result_domains; + args_and_result_domains.reserve(call_node->args.size() + 1); + for (const auto& arg : call_node->args) { + args_and_result_domains.emplace_back(domains_->DomainFor(arg)); + VisitExpr(arg); + } + args_and_result_domains.emplace_back(domains_->DomainFor(call)); + auto implied_domain = + DeviceDomains::MakeDomain(std::move(args_and_result_domains)); // higher-order + + VLOG(1) << "initial call function domain:" << std::endl + << domains_->ToString(func_domain) << std::endl + << "and implied domain:" << std::endl + << domains_->ToString(implied_domain) << std::endl + << "for call:" << std::endl + << PrettyPrint(call); + + // The above must match. + try { + domains_->Unify(func_domain, implied_domain); // higher-order + } catch (const Error& e) { + // TODO(mbs): Proper diagnostics. + LOG(FATAL) << "Function parameters and result devices do not match those of call. Call:" + << std::endl + << PrettyPrint(call) << std::endl + << "with function devices:" << std::endl + << domains_->ToString(func_domain) << std::endl + << "and implied call devices:" << std::endl + << domains_->ToString(implied_domain) << std::endl + << e.what(); + } + + VLOG(1) << "final call function domain:" << std::endl + << domains_->ToString(func_domain) << std::endl + << "for call:" << std::endl + << PrettyPrint(call); + } + + void VisitExpr_(const LetNode* let_node) final { + Expr expr = GetRef(let_node); + // Iteratively visit let nodes to avoid stack overflow. + while (expr->IsInstance()) { + Let let = Downcast(expr); + // Let var must be same device as value it is bound to. + domains_->UnifyExprExact(let->var, let->value); // may be higher-order + // Let body must be same device as overall let. + domains_->UnifyExprExact(let, let->body); // may be higher-order + + VisitExpr(let->var); + VisitExpr(let->value); + + expr = let->body; + } + + // Visit the last body + VisitExpr(expr); + } + + void VisitExpr_(const FunctionNode* function_node) final { + // No need to step into fused primitive functions as they are lowered individually according + // to the devices of all their call sites. + if (function_node->HasNonzeroAttr(attr::kPrimitive)) { + return; + } + + auto function = GetRef(function_node); + auto func_domain = domains_->DomainFor(function); // higher-order + + // The function body domain must match the function result domain. 
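+    // (i.e. unify the body's domain with Dr in the function's own domain fn(D1,...,Dn):Dr.)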
+ domains_->UnifyExprExact(function_node->body, + func_domain->function_result()); // may be higher-order + + VLOG(1) << "initial function domain:" << std::endl + << domains_->ToString(func_domain) << std::endl + << "and function body domain:" << std::endl + << domains_->ToString(domains_->DomainFor(function_node->body)) << std::endl + << "for function:" << std::endl + << PrettyPrint(function); + + ICHECK_EQ(func_domain->function_arity(), function_node->params.size()); + for (size_t i = 0; i < function_node->params.size(); ++i) { + // The parameter domains must match the function argument domains. + domains_->UnifyExprExact(function_node->params[i], + func_domain->function_param(i)); // may be higher-order + VisitExpr(function_node->params[i]); + } + + // If the function already has device attributes then we can further constrain the + // function's domain to match them. + if (GetFunctionResultDeviceType(function_node) != kInvalidDeviceType) { + std::vector args_and_result; + for (size_t i = 0; i < function_node->params.size(); ++i) { + args_and_result.emplace_back( + domains_->ForDeviceType(function_node->params[i]->checked_type(), + GetFunctionParamDeviceType(function_node, i))); + } + args_and_result.emplace_back(domains_->ForDeviceType( + function_node->body->checked_type(), GetFunctionResultDeviceType(function_node))); + auto annotation_domain = domains_->MakeDomain(std::move(args_and_result)); + try { + domains_->Unify(func_domain, annotation_domain); // higher-order + } catch (const Error& e) { + // TODO(mbs): Proper diagnostics. + LOG(FATAL) + << "Function devices are incompatible with its \"on_device\" annotation. Function:" + << std::endl + << PrettyPrint(function) << std::endl + << "with function devices:" << std::endl + << domains_->ToString(func_domain) << std::endl + << "and annotation devices:" << std::endl + << domains_->ToString(annotation_domain) << std::endl + << e.what(); + } + } + + VisitExpr(function_node->body); + + VLOG(1) << "final function domain:" << std::endl + << domains_->ToString(func_domain) << std::endl + << "and function body domain:" << std::endl + << domains_->ToString(domains_->DomainFor(function_node->body)) << std::endl + << "for function:" << std::endl + << PrettyPrint(function); + } + + void VisitExpr_(const TupleNode* tuple_node) final { + Tuple tuple = GetRef(tuple_node); + for (size_t i = 0; i < tuple->fields.size(); i++) { + auto domain = domains_->DomainFor(tuple->fields[i]); // may be higher-order + domains_->UnifyExprCollapsed(tuple, domain); // collapse to first-order if needed + VisitExpr(tuple->fields[i]); + } + } + + void VisitExpr_(const TupleGetItemNode* tuple_get_item_node) final { + TupleGetItem tuple_get_item = GetRef(tuple_get_item_node); + auto domain = domains_->DomainFor(tuple_get_item); // may be higher-order + domains_->UnifyExprCollapsed(tuple_get_item_node->tuple, + domain); // collapse to first-order if needed + VisitExpr(tuple_get_item_node->tuple); + } + + class DevicePatternAnalyzer : public PatternVisitor { + public: + DevicePatternAnalyzer(DeviceDomains* domains, const ExprNode* adt_node) + : domains_(domains), adt_node_(adt_node) {} + + private: + void VisitPattern_(const PatternVarNode* pattern_var_node) final { + auto var_domain = domains_->DomainFor(pattern_var_node->var); // may be higher order + domains_->UnifyExprCollapsed(GetRef(adt_node_), + var_domain); // collapse to first-order if needed + } + + /*! \brief (Mutable borrow of) the domains for all expressions processed so far. 
*/ + DeviceDomains* domains_; + /*! \brief The expression for the ADT we are matching over. */ + const ExprNode* adt_node_; + }; + + void VisitPattern(const Pattern& pattern) final {} + + void VisitExpr_(const MatchNode* match_node) final { + // For match node, we unify the value and the rhs of each clause + Match match = GetRef(match_node); + auto match_domain = domains_->DomainFor(match); // may be higher-order + DevicePatternAnalyzer pattern_analyzer(domains_.get(), match->data.get()); + domains_->UnifyExprCollapsed(match->data, match_domain); // collapse to first-order if needed + for (const auto& clause : match->clauses) { + pattern_analyzer.VisitPattern(clause->lhs); + domains_->UnifyExprExact(clause->rhs, match_domain); + VisitExpr(clause->rhs); + } + VisitExpr(match_node->data); + } + + void VisitExpr_(const GlobalVarNode* global_var_node) final { + domains_->DomainFor(GetRef(global_var_node)); + } + + void VisitExpr_(const VarNode* var_node) final { domains_->DomainFor(GetRef(var_node)); } + + void VisitExpr_(const ConstantNode* constant_node) final { + domains_->DomainFor(GetRef(constant_node)); + } + + void VisitExpr_(const ConstructorNode* constructor_node) final { + // no-op, constructors are handled at their call-sites. + // TODO(mbs): Assumes eta-expansion + } + + void VisitExpr_(const IfNode* if_node) final { + auto ife = GetRef(if_node); + auto domain = domains_->DomainFor(ife); // may be higher-order + domains_->UnifyExprCollapsed(if_node->cond, domain); // collapse to first-order if needed + domains_->UnifyExprExact(if_node->true_branch, domain); + domains_->UnifyExprExact(if_node->false_branch, domain); + VisitExpr(if_node->cond); + VisitExpr(if_node->true_branch); + VisitExpr(if_node->false_branch); + } + + void VisitExpr_(const OpNode* op) final { + // no-op, primitive operators are handled at their call-sites. + } + + void VisitExpr_(const RefCreateNode* ref_create_node) final { + auto ref_create = GetRef(ref_create_node); + auto domain = domains_->DomainFor(ref_create_node->value); // may be higher-order + domains_->UnifyExprCollapsed(ref_create, domain); // collapse to first-order if needed + VisitExpr(ref_create_node->value); + } + + void VisitExpr_(const RefReadNode* ref_read_node) final { + auto ref_read = GetRef(ref_read_node); + auto domain = domains_->DomainFor(ref_read); // may be higher-order + domains_->UnifyExprCollapsed(ref_read_node->ref, domain); // collapse to first-order if needed + VisitExpr(ref_read_node->ref); + } + + void VisitExpr_(const RefWriteNode* ref_write_node) final { + auto ref_write = GetRef(ref_write_node); + auto domain = domains_->DomainFor(ref_write->value); // may be higher-order + domains_->UnifyExprCollapsed(ref_write->ref, domain); // collapse to first-order if needed + domains_->UnifyExprCollapsed(ref_write, domain); // collapse to first-order if needed + VisitExpr(ref_write_node->ref); + VisitExpr(ref_write_node->value); + } + + /*! \brief The module we are analyzing. */ + IRModule mod_; + /*! \brief The domains for all expressions processed so far. */ + std::unique_ptr domains_; +}; + +/****** +******* Phase 2 +*******/ + +/*! + * \brief Ensures every sub-expression in a module has a device type, using both the global + * default and some local heuristics to avoid unnecessary additional "device_copy" CallNodes. + * + * E.g. 
in: + * \code + * def @main(%x, %y, %z) { + * let %a = add(%x, %y); + * multiply(%a, on_device(%z, device_type=d)) + * \endcode + * we know the parameter \p %z must be on device \p d, but the devices for \p %x and \p %y, + * and the device for the function result, are still 'free'. The global 'default' device type + * is first used to 'fix' \p @main's result type, which in turn 'fixes' \p %x and \p %y, which + * in turn 'fixes' the device on which the \p add and \p multiply are executed. + * + * TODO(mbs): I think this is deterministic? We do however visit the top-level defs in hashmap + * order. + */ +class DeviceDefaulter : public ExprVisitor { + public: + DeviceDefaulter(IRModule mod, std::unique_ptr domains, + DLDeviceType default_device_type) + : mod_(std::move(mod)), + domains_(std::move(domains)), + default_device_type_(default_device_type) {} + + std::unique_ptr Default() { + VLOG_CONTEXT << "DeviceDefaulter"; + for (const auto& pair : mod_->functions) { + VLOG(1) << "defaulting devices for '" << PrettyPrint(pair.first) << "'"; + VisitExpr(pair.second); + } + return std::move(domains_); + } + + private: + void VisitExpr_(const FunctionNode* function_node) final { + if (function_node->HasNonzeroAttr(attr::kPrimitive)) { + return; + } + + auto function = GetRef(function_node); + auto func_domain = domains_->DomainFor(function); // higher-order + ICHECK_EQ(func_domain->function_arity(), function_node->params.size()); + if (domains_->AnyFree(func_domain)) { + VLOG(1) << "before defaulting function:" << std::endl << domains_->ToString(func_domain); + domains_->SetResultDefaultThenParams(func_domain, default_device_type_); + VLOG(1) << "after defaulting function:" << std::endl << domains_->ToString(func_domain); + } + VisitExpr(function_node->body); + } + + void VisitExpr_(const CallNode* call_node) final { + auto call = GetRef(call_node); + auto func_domain = domains_->DomainForCallee(call); // higher-order + ICHECK_EQ(func_domain->function_arity(), call_node->args.size()); + if (domains_->AnyFree(func_domain)) { + // For calls to Relay functions this step is identical to that for VisitExpr_(FunctionNode*) + // above. But for calls to primitives we may still need to force free domains to be + // defaulted. + VLOG(1) << "before defaulting callee:" << std::endl << domains_->ToString(func_domain); + domains_->SetResultDefaultThenParams(func_domain, default_device_type_); + VLOG(1) << "after defaulting callee:" << std::endl << domains_->ToString(func_domain); + } + return ExprVisitor::VisitExpr_(call_node); + } + + void VisitExpr_(const LetNode* let_node) final { + Expr expr = GetRef(let_node); + // Iteratively visit let nodes to avoid stack overflow. + while (expr->IsInstance()) { + Let let = Downcast(expr); + // If the let-var device is still free force it to match the overall let. + auto let_domain = domains_->DomainFor(let); // may be higher-order + DLDeviceType let_device_type = domains_->ResultDeviceType(let_domain); + ICHECK_NE(let_device_type, kInvalidDeviceType); + auto let_var_domain = domains_->DomainFor(let->var); // may be higher-order + if (domains_->AnyFree(let_var_domain)) { + VLOG(1) << "before defaulting let-var:" << std::endl << domains_->ToString(let_var_domain); + domains_->SetDefault(let_var_domain, let_device_type); + VLOG(1) << "after defaulting let-var:" << std::endl << domains_->ToString(let_var_domain); + } + VisitExpr(let->var); + VisitExpr(let->value); + expr = let->body; + } + VisitExpr(expr); + } + + /*! \brief The module we are processing. 
*/ + IRModule mod_; + /*! \brief The domains for all expressions. */ + std::unique_ptr domains_; + /*! \brief The default device type. */ + DLDeviceType default_device_type_; +}; + +/****** +******* Phase 3 +*******/ + +/*! + * \brief Inserts missing "device_copy" CallNodes, and ensures the device type of every + * sub-expression in a module can be easily recovered by a later transformation using simple + * lexical scoping rules (e.g. for memory planning). + * + * - Discard any existing "on_device" CallNodes since their job is done. Similarly, discard + * any existing "device_copy" CallNodes which are no-ops. + * + * - Functions are given "param_device_types" and "result_device_type" attributes to capture + * the device type for its parameters and result. + * + * - Additional "device_copy" CallNodes are inserted wherever there's a transition between + * storage device types. Since the DeviceAnalyzer phase succeeded this can only happen + * where the original program explicitly allowed a transition using an "on_device" CallNode. + * That is, we do not not try to 'fix' a program with inconsistent devices. + * + * - Additional "on_device" CallNodes are inserted so that a later transform can discover + * the device for an arbitrary sub-expression by looking only for the lexically enclosing + * "on_device" CallNode or "on_device" function attribute. In particular, since function + * arguments and let-bound expressions can be on a device different from the function + * or let body itself we will insert "on_device" CallNodes to spell out any differences. This + * applies even to the argument to a "device_copy" CallNode, which may look pedantic but + * keeps downstream processing simple. The "on_device" calls should be removed before code gen, + * which is easily done on-the-fly. + * + * For example, we'll end up with programs that look like: + * \code + * def @main(%x, %y, param_device_types=[...], result_device_type=...) 
{ + * let %a = on_device(..., device_type=..., is_fixed=True) + * @f(%a, device_copy(on_device(..., device_type=..., is_fixed=True), + * src_device_type=..., dst_device_type=...)) + * } + * \endcode + */ +class DeviceCapturer : public ExprMutator { + public: + DeviceCapturer(IRModule mod, std::unique_ptr domains) + : mod_(std::move(mod)), domains_(std::move(domains)) {} + + IRModule Capture() { + VLOG_CONTEXT << "CaptureDevices"; + IRModule result(/*functions=*/{}, mod_->type_definitions, mod_->Imports(), mod_->source_map); + for (const auto& pair : mod_->functions) { + VLOG(1) << "capturing devices for '" << PrettyPrint(pair.first) << "'"; + result->Add(pair.first, Downcast(Mutate(pair.second))); + } + return result; + } + + private: + // Nothing interesting for VarNode, ConstantNode, GlobalVarNode, OpNode and ConstructorNode + + Expr VisitExpr_(const TupleNode* tuple_node) final { + auto tuple = GetRef(tuple_node); + Array fields; + fields.reserve(tuple_node->fields.size()); + for (const auto& field : tuple_node->fields) { + fields.push_back(VisitChild(tuple, field)); + } + // TODO(mbs): Avoid copy + return Tuple(std::move(fields), tuple_node->span); + } + + Expr VisitExpr_(const FunctionNode* function_node) final { + if (function_node->HasNonzeroAttr(attr::kPrimitive)) { + return GetRef(function_node); + } + + auto function = GetRef(function_node); + auto func_domain = domains_->DomainFor(function); // higher-order + VLOG(1) << "capturing function:" << std::endl + << PrettyPrint(function) << std::endl + << "with domain:" << std::endl + << domains_->ToString(func_domain); + + // Gather the parameter and result device types for the function attributes. + ICHECK_EQ(func_domain->function_arity(), function_node->params.size()); + DLDeviceType result_device_type = domains_->ResultDeviceType(func_domain); + ICHECK_NE(result_device_type, kInvalidDeviceType); + Array param_device_types; + param_device_types.reserve(function_node->params.size()); + for (size_t i = 0; i < function_node->params.size(); ++i) { + DLDeviceType param_device_type = domains_->ResultDeviceType(func_domain->function_param(i)); + ICHECK_NE(param_device_type, kInvalidDeviceType); + param_device_types.push_back(param_device_type); + } + + // Rewrite the body. Note that the body may have begun with an "on_device" so + // be prepared to insert a "device_copy". + Expr body = VisitChild( + /*lexical_device_type=*/result_device_type, + /*expected_device_type=*/result_device_type, + /*child_device_type=*/GetDeviceType(function_node->body), function_node->body); + + // TODO(mbs): Avoid copy + Function func = Function(function_node->params, body, function_node->ret_type, + function_node->type_params, function_node->attrs, function_node->span); + return FunctionOnDevice(func, param_device_types, result_device_type); + } + + Expr VisitExpr_(const CallNode* call_node) final { + auto call = GetRef(call_node); + DLDeviceType call_device_type = GetDeviceType(call); + + auto on_device_props = GetOnDeviceProps(call_node); + if (on_device_props.body.defined()) { + // We're done with the original "on_device" calls and can pinch them out. + // Note that this step has already been simulated by GetDeviceType. 
+ return VisitExpr(on_device_props.body); + } + + auto device_copy_props = GetDeviceCopyProps(call_node); + if (device_copy_props.body.defined()) { + DLDeviceType src_device_type = device_copy_props.src_dev_type; + ICHECK_EQ(call_device_type, device_copy_props.dst_dev_type); + if (call_device_type == src_device_type) { + // We can pinch out existing "device_copy" CallNodes if their source and destinations + // match. + return VisitExpr(device_copy_props.body); + } + // else: handle as for any other call. + } + + auto func_domain = domains_->DomainForCallee(call); // higher-order + VLOG(1) << "considering call:" << std::endl + << PrettyPrint(call) << std::endl + << "on device " << call_device_type << " with function domain:" << std::endl + << domains_->ToString(func_domain); + DLDeviceType result_device_type = domains_->ResultDeviceType(func_domain); + ICHECK_NE(result_device_type, kInvalidDeviceType); + + // The callee is on the current device. + Expr op = VisitChild( + /*lexical_device_type=*/call_device_type, + /*expected_device_type=*/call_device_type, + /*child_device_type=*/result_device_type, call_node->op); + + // Each argument can be on the device for the corresponding function parameter. However if + // any of those differ from the overall call device then wrap them in an "on_device" to + // help downstream transforms track devices lexically. + Array args; + args.reserve(call_node->args.size()); + ICHECK_EQ(func_domain->function_arity(), call->args.size()); + for (size_t i = 0; i < call_node->args.size(); ++i) { + DLDeviceType param_device_type = domains_->ResultDeviceType(func_domain->function_param(i)); + ICHECK_NE(param_device_type, kInvalidDeviceType) + << "for parameter " << i << " for call:" << std::endl + << PrettyPrint(call); + args.push_back(VisitChild(/*lexical_device_type=*/call_device_type, + /*expected_device_type=*/param_device_type, + /*child_device_type=*/GetDeviceType(call_node->args[i]), + call_node->args[i])); + } + // TODO(mbs): Avoid copy + return Call(std::move(op), std::move(args), call_node->attrs, call_node->type_args, + call_node->span); + } + + Expr VisitExpr_(const LetNode* let_node) final { + Expr expr = GetRef(let_node); + // Iterate through chained lets, provided they all agree on their device type. + DLDeviceType let_device_type = GetDeviceType(expr); + std::vector> bindings; + while (const auto* inner_let_node = expr.as()) { + Expr inner_let = GetRef(inner_let_node); + if (GetDeviceType(inner_let) != let_device_type) { + // We have a device transition which needs to be handled. + break; + } + // The let-bound value can be on a different device than the overall let. However if those + // devices don't agree wrap the let-bound value in an "on_device" to help downstream + // transforms track devices lexically. 
+ Expr value = VisitChild(/*lexical_device_type=*/let_device_type, + /*expected_device_type=*/GetDeviceType(inner_let_node->var), + /*child_device_type=*/GetDeviceType(inner_let_node->value), + inner_let_node->value); + bindings.emplace_back(inner_let_node->var, value, inner_let_node->span); + expr = inner_let_node->body; + } + Expr body = VisitChild(/*lexical_device_type=*/let_device_type, + /*expected_device_type=*/let_device_type, + /*child_device_type=*/GetDeviceType(expr), expr); + for (auto itr = bindings.rbegin(); itr != bindings.rend(); ++itr) { + body = Let(/*var=*/std::get<0>(*itr), /*value=*/std::get<1>(*itr), body, + /*span=*/std::get<2>(*itr)); + } + return body; + } + + Expr VisitExpr_(const IfNode* if_node) final { + auto ife = GetRef(if_node); + Expr cond = VisitChild(ife, if_node->cond); + Expr true_branch = VisitChild(ife, if_node->true_branch); + Expr false_branch = VisitChild(ife, if_node->false_branch); + // TODO(mbs): Avoid copy + return If(cond, true_branch, false_branch, if_node->span); + } + + Expr VisitExpr_(const TupleGetItemNode* tuple_get_item_node) final { + auto tuple_get_item = GetRef(tuple_get_item_node); + Expr tuple = VisitChild(tuple_get_item, tuple_get_item_node->tuple); + // TODO(mbs): Avoid copy + return TupleGetItem(tuple, tuple_get_item_node->index, tuple_get_item_node->span); + } + + Expr VisitExpr_(const RefCreateNode* ref_create_node) final { + auto ref_create = GetRef(ref_create_node); + Expr value = VisitChild(ref_create, ref_create_node->value); + // TODO(mbs): Avoid copy + return RefCreate(value, ref_create_node->span); + } + + Expr VisitExpr_(const RefReadNode* ref_read_node) final { + auto ref_read = GetRef(ref_read_node); + Expr ref = VisitChild(ref_read, ref_read_node->ref); + // TODO(mbs): Avoid copy + return RefRead(ref, ref_read_node->span); + } + + Expr VisitExpr_(const RefWriteNode* ref_write_node) final { + auto ref_write = GetRef(ref_write_node); + Expr ref = VisitChild(ref_write, ref_write_node->ref); + Expr value = VisitChild(ref_write, ref_write_node->value); + // TODO(mbs): Avoid copy + return RefWrite(ref, value, ref_write_node->span); + } + + Expr VisitExpr_(const MatchNode* match_node) final { + auto match = GetRef(match_node); + Expr data = VisitChild(match, match_node->data); + Array clauses; + clauses.reserve(match_node->clauses.size()); + for (const auto& clause : match_node->clauses) { + Pattern lhs = VisitPattern(clause->lhs); // actually a no-op, so we're not checking vars + Expr rhs = VisitChild(match, clause->rhs); + clauses.push_back(Clause(lhs, rhs)); + } + // TODO(mbs): Avoid copy + return Match(data, std::move(clauses), match_node->complete, match_node->span); + } + + DLDeviceType GetDeviceType(const Expr& expr) { + // Look through any "on_device" CallNodes, to mimic how we will be pinching them out. + auto props = GetOnDeviceProps(expr); + Expr true_expr = props.body.defined() ? props.body : expr; + ICHECK(domains_->contains(true_expr)); + // If expr is higher order we'll return only the result domain's device type. + DLDeviceType device_type = domains_->ResultDeviceType(domains_->DomainFor(true_expr)); + ICHECK_NE(device_type, kInvalidDeviceType) + << "no device type was determined for expression:" << std::endl + << PrettyPrint(true_expr); + return device_type; + } + + /*! 
+   * \brief Reconcile the \p child_device_type for \p child with both the
+   * \p expected_device_type (as required by the expression context the \p child is in) and
+   * the \p lexical_device_type (as a downstream transform would infer based only on lexically
+   * enclosing "on_device" CallNodes and function attributes). Generally \p lexical_device_type
+   * and \p expected_device_type are the same by definition, but may differ in arguments to
+   * functions and let-bound expressions.
+   *
+   * If \p child_device_type differs from \p expected_device_type, wrap it as:
+   * \code
+   *   device_copy(on_device(child', device_type=child_device_type),
+   *               src_dev_type=child_device_type, dst_dev_type=expected_device_type)
+   * \endcode
+   * (where child is rewritten to child'). Note the pedantic spelling out of "on_device" on
+   * the child.
+   *
+   * If \p expected_device_type differs from \p lexical_device_type, then (also) wrap
+   * the expression as:
+   * \code
+   *   on_device(..., device_type=expected_device_type)
+   * \endcode
+   *
+   * TODO(mbs): There's no attempt at sharing here. Each use of the child node could end up
+   * wrapped by its own "device_copy", even though those copies will generally all be to the
+   * same destination device.
+   */
+  Expr VisitChild(DLDeviceType lexical_device_type, DLDeviceType expected_device_type,
+                  DLDeviceType child_device_type, const Expr& child) {
+    ICHECK_NE(lexical_device_type, kInvalidDeviceType);
+    ICHECK_NE(expected_device_type, kInvalidDeviceType);
+    if (child->IsInstance<OpNode>()) {
+      // Primitive operators don't need to be rewritten and can have a different domain for
+      // each call site.
+      return child;
+    }
+    Expr result = VisitExpr(child);
+    if (child_device_type != expected_device_type) {
+      VLOG(1) << "creating " << DeviceCopyOp()->name << " from device type " << child_device_type
+              << " to device type " << expected_device_type << " for:" << std::endl
+              << PrettyPrint(result);
+      // Also wrap the child in an "on_device" so downstream transforms can track devices
+      // lexically.
+      result = MaybeOnDevice(result, child_device_type, /*is_fixed=*/true);
+      result = DeviceCopy(result, child_device_type, expected_device_type);
+    }
+    if (expected_device_type != lexical_device_type) {
+      VLOG(1) << "creating " << OnDeviceOp()->name << " for device type " << expected_device_type
+              << " for:" << std::endl
+              << PrettyPrint(result);
+      result = MaybeOnDevice(result, expected_device_type, /*is_fixed=*/true);
+    }
+    return result;
+  }
+
+  /*!
+   * \brief Common case of visiting a direct \p child of \p parent where by default the
+   * \p child is expected to be on the same device as the \p parent.
+   */
+  Expr VisitChild(const Expr& parent, const Expr& child) {
+    DLDeviceType expected_device_type = GetDeviceType(parent);
+    DLDeviceType child_device_type = GetDeviceType(child);
+    return VisitChild(expected_device_type, expected_device_type, child_device_type, child);
+  }
+
+  /*! \brief Module we are rewriting, so we can lookup global variables. */
+  IRModule mod_;
+  /*! \brief Device domain for every expression from DeviceAnalyzer. */
+  std::unique_ptr<DeviceDomains> domains_;
+};
+
+/*! \brief Rewrite the "on_device" calls (and implicitly re-type-check). */
+tvm::transform::Pass Rewrite() {
+  auto pass_func = [](Function f, IRModule m, transform::PassContext ctxt) {
+    return Downcast<Function>(RewriteOnDevices().Mutate(f));
+  };
+  return tvm::relay::transform::CreateFunctionPass(pass_func, 0, "PlanDevicesRewrite", {});
+}
*/ +tvm::transform::Pass PlanDevicesCore(DLDeviceType default_device_type) { + return tvm::transform::CreateModulePass( + [=](IRModule mod, tvm::transform::PassContext pass_cnxt) -> IRModule { + // Collect the system of constraints for every sub-expression using existing "on_device" + // and "device_copy" calls. + std::unique_ptr domains = DeviceAnalyzer(mod).Analyze(); + VLOG(1) << "Domains after analysis:" << std::endl << domains->ToString(); + + // Choose sensible default devices for every sub-expression if otherwise unconstrained + // by existing "on_device" or "device_copy" calls. + domains = DeviceDefaulter(mod, std::move(domains), default_device_type).Default(); + VLOG(1) << "Domains after defaulting: " << std::endl << domains->ToString(); + + // Insert "device_copy" and "on_device" CallNodes where needed to unambiguously capture + // the above map, and attach additional "param_device_types" and "result_device_type" + // attributes to all function definitions. + return DeviceCapturer(mod, std::move(domains)).Capture(); + }, + /*opt_level=*/0, "PlanDevicesCore", {}); +} + +} // namespace + +/****** +******* Overall composite Pass +*******/ + +// This function is declared in the public . +TVM_DLL tvm::transform::Pass PlanDevices(DLDeviceType default_device_type) { + std::vector passes; + passes.emplace_back(Rewrite()); + passes.emplace_back(PlanDevicesCore(default_device_type)); + return tvm::transform::Sequential(std::move(passes), "PlanDevices"); +} + +TVM_REGISTER_GLOBAL("relay._transform.PlanDevices") + .set_body_typed([](const Device& default_device) { + return PlanDevices(default_device.device_type); + }); + +} // namespace transform +} // namespace relay +} // namespace tvm diff --git a/tests/cpp/relay/relay/transforms/device_domains_test.cc b/tests/cpp/relay/relay/transforms/device_domains_test.cc new file mode 100644 index 000000000000..8f263c3b3273 --- /dev/null +++ b/tests/cpp/relay/relay/transforms/device_domains_test.cc @@ -0,0 +1,71 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/* + * Just a smoke test for the device planner's unification domain, mostly to tease out how we'd + * like to organize our cpp unit tests for functionality that's not obviously a Pass or should + * be exposed via FFI. 
diff --git a/tests/cpp/relay/relay/transforms/device_domains_test.cc b/tests/cpp/relay/relay/transforms/device_domains_test.cc
new file mode 100644
index 000000000000..8f263c3b3273
--- /dev/null
+++ b/tests/cpp/relay/relay/transforms/device_domains_test.cc
@@ -0,0 +1,71 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*
+ * Just a smoke test for the device planner's unification domain, mostly to tease out how we'd
+ * like to organize our cpp unit tests for functionality that's not obviously a Pass or should
+ * be exposed via FFI.
+ */
+
+// TODO(mbs): Revisit cpp unit test layout or setup include dir at root of src/
+#include "../../../src/relay/transforms/device_domains.h"
+
+#include <gtest/gtest.h>
+#include <tvm/parser/parser.h>
+#include <tvm/relay/transform.h>
+
+namespace tvm {
+namespace relay {
+namespace transform {
+namespace {
+
+IRModule TestModule() {
+  return InferType()(tvm::parser::ParseModule("test", R"(
+    #[version = "0.0.5"]
+    def @f(%x : Tensor[(3, 7), float32], %y : Tensor[(3, 7), float32]) {
+      add(%x, %y)
+    }
+  )"));
+}
+
+TEST(DeviceDomains, SmokeTest) {
+  DeviceDomains domains;
+  IRModule mod = TestModule();
+  Function f = Downcast<Function>(mod->Lookup("f"));
+
+  DeviceDomainPtr actual_add_domain = domains.DomainForCallee(Downcast<Call>(f->body));
+  DeviceDomainPtr x_domain = domains.DomainFor(f->params[0]);
+  DeviceDomainPtr y_domain = domains.DomainFor(f->params[1]);
+  DeviceDomainPtr result_domain = DeviceDomains::Free(f->ret_type);
+  std::vector<DeviceDomainPtr> arg_and_results;
+  arg_and_results.push_back(x_domain);
+  arg_and_results.push_back(y_domain);
+  arg_and_results.push_back(result_domain);
+  DeviceDomainPtr implied_add_domain = DeviceDomains::MakeDomain(std::move(arg_and_results));
+  domains.Unify(actual_add_domain, implied_add_domain);
+  domains.Unify(x_domain, DeviceDomains::ForDeviceType(f->params[0]->checked_type(), kDLCUDA));
+
+  EXPECT_EQ(domains.ResultDeviceType(y_domain), kDLCUDA);
+  EXPECT_EQ(domains.ResultDeviceType(result_domain), kDLCUDA);
+}
+
+}  // namespace
+}  // namespace transform
+}  // namespace relay
+}  // namespace tvm
diff --git a/tests/python/relay/test_pass_plan_devices.py b/tests/python/relay/test_pass_plan_devices.py
new file mode 100644
index 000000000000..2252d8a235c9
--- /dev/null
+++ b/tests/python/relay/test_pass_plan_devices.py
@@ -0,0 +1,1320 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License
+
+
+"""Unit tests for the PlanDevices pass. We check:
+    - The pass alone gives the expected AST, though we need to manually run InferType.
+    - The pass is idempotent.
+    - Execution on the VM backend yields the correct result."""
+
+import tvm
+from tvm import relay
+import tvm.testing
+import numpy as np
+
+CPU = tvm.device("cpu")  # device_type=1
+GPU = tvm.device("cuda")  # device_type=2
+DEFAULT = GPU
+
+core = tvm.IRModule()
+core.import_from_std("core.rly")
+
+
+def rewrite_and_assert(in_mod, expected_mod):
+    """Manually run the pass and assert the result is structurally equal to the expected
+    module."""
+    actual_mod = relay.transform.InferType()(in_mod)
+    actual_mod = relay.transform.PlanDevices(DEFAULT)(actual_mod)
+    actual_mod = relay.transform.InferType()(actual_mod)
+    expected_mod = relay.transform.InferType()(expected_mod)
+    if not tvm.ir.structural_equal(actual_mod, expected_mod, True):
+        # Print everything in full so we can see what's going on when things fail.
+ print("Input module:") + print(in_mod) + print("Expected module:") + print(expected_mod) + print("Actual module:") + print(actual_mod) + # Assert again so as to see the actual disagreeing sub-expressions. + tvm.ir.assert_structural_equal(actual_mod, expected_mod, True) + + +def eval_and_assert(in_mod: tvm.IRModule, reference_func, args): + """Test the standard compilation flow gives us a function which agrees with the Numpy + reference implementation.""" + if not tvm.runtime.enabled("cuda"): + print("Not evaluating since GPU is not available") + return + with tvm.transform.PassContext(opt_level=3): + compiled = relay.create_executor("vm", mod=in_mod, device=GPU, target="cuda").evaluate() + actual = compiled(*args).numpy() + expected = reference_func(*args) + tvm.testing.assert_allclose(actual, expected) + + +def rand(shape): + return np.random.rand(*shape).astype("float32") + + +def rands(shape, n): + return [rand(shape) for i in range(n)] + + +def exercise(in_mod: tvm.IRModule, expected_mod: tvm.IRModule, reference_func, args): + """Test in_mod against expected_mod and reference_func using args.""" + # Correctness + rewrite_and_assert(in_mod, expected_mod) + # Idempotence + rewrite_and_assert(expected_mod, expected_mod) + # The VM can compile and possibly even run the module + # TODO(mbs): Disabled until VM supports new device planning. + # if not (reference_func is None) and not (args is None): + # eval_and_assert(in_mod, reference_func, args) + + +def test_plain(): + # Everything defaults to GPU + def input(): + return tvm.parser.fromtext( + """ + #[version = "0.0.5"] + def @main(%a: Tensor[(5, 7), float32], %b: Tensor[(5, 7), float32], + %c: Tensor[(5, 7), float32], %d: Tensor[(5, 7), float32]) { + %0 = add(%a, %b); + %1 = add(%c, %d); + subtract(%0, %1) + } + """ + ) + + def expected(): + return tvm.parser.fromtext( + """ + #[version = "0.0.5"] + def @main(%a: Tensor[(5, 7), float32], %b: Tensor[(5, 7), float32], + %c: Tensor[(5, 7), float32], %d: Tensor[(5, 7), float32], + param_device_types=[2, 2, 2, 2], result_device_type=2) { + %0 = add(%a, %b); + %1 = add(%c, %d); + subtract(%0, %1) + } + """ + ) + + def ref(a, b, c, d): + return np.subtract(np.add(a, b), np.add(c, d)) + + exercise(input(), expected(), ref, rands((5, 7), 4)) + + +def test_left_add_on_cpu(): + # Force some args to be on CPU, rest default to GPU. + def input(): + return tvm.parser.fromtext( + """ + #[version = "0.0.5"] + def @main(%a: Tensor[(5, 7), float32], %b: Tensor[(5, 7), float32], + %c: Tensor[(5, 7), float32], %d: Tensor[(5, 7), float32]) { + %0 = add(%a, %b); + %1 = on_device(%0, device_type=1); + %2 = add(%c, %d); + subtract(%1, %2) + } + """ + ) + + def expected(): + return tvm.parser.fromtext( + """ + #[version = "0.0.5"] + def @main(%a: Tensor[(5, 7), float32], %b: Tensor[(5, 7), float32], + %c: Tensor[(5, 7), float32], %d: Tensor[(5, 7), float32], + param_device_types=[1, 1, 2, 2], result_device_type=2) { + %0 = add(%a, %b); + %1 = on_device(%0, device_type=1, is_fixed=True); + %2 = device_copy(%1, src_dev_type=1, dst_dev_type=2); + %3 = add(%c, %d); + subtract(%2, %3) + } + """ + ) + + def ref(a, b, c, d): + return np.subtract(np.add(a, b), np.add(c, d)) + + exercise(input(), expected(), ref, rands((5, 7), 4)) + + +def test_left_add_on_cpu_via_copy(): + # As for test_left_add_on_cpu, but with an explicit device_copy. 
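+    # The pass keeps the explicit copy and simply spells out the device of its argument with a
+    # fixed "on_device" annotation.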
+ def input(): + return tvm.parser.fromtext( + """ + #[version = "0.0.5"] + def @main(%a: Tensor[(5, 7), float32], %b: Tensor[(5, 7), float32], + %c: Tensor[(5, 7), float32], %d: Tensor[(5, 7), float32]) { + %0 = add(%a, %b); + %1 = device_copy(%0, src_dev_type=1, dst_dev_type=2); + %2 = add(%c, %d); + subtract(%1, %2) + } + """ + ) + + def expected(): + return tvm.parser.fromtext( + """ + #[version = "0.0.5"] + def @main(%a: Tensor[(5, 7), float32], %b: Tensor[(5, 7), float32], + %c: Tensor[(5, 7), float32], %d: Tensor[(5, 7), float32], + param_device_types=[1, 1, 2, 2], result_device_type=2) { + %0 = add(%a, %b); + %1 = on_device(%0, device_type=1, is_fixed=True); + %2 = device_copy(%1, src_dev_type=1, dst_dev_type=2); + %3 = add(%c, %d); + subtract(%2, %3) + } + """ + ) + + def ref(a, b, c, d): + return np.subtract(np.add(a, b), np.add(c, d)) + + exercise(input(), expected(), ref, rands((5, 7), 4)) + + +def test_both_adds_on_cpu(): + def input(): + return tvm.parser.fromtext( + """ + #[version = "0.0.5"] + def @main(%a: Tensor[(5, 7), float32], %b: Tensor[(5, 7), float32], + %c: Tensor[(5, 7), float32], %d: Tensor[(5, 7), float32]) { + %0 = add(%a, %b); + %1 = add(%c, %d); + %2 = on_device(%0, device_type=1); + %3 = on_device(%1, device_type=1); + subtract(%2, %3) + } + """ + ) + + def expected(): + return tvm.parser.fromtext( + """ + #[version = "0.0.5"] + def @main(%a: Tensor[(5, 7), float32], %b: Tensor[(5, 7), float32], + %c: Tensor[(5, 7), float32], %d: Tensor[(5, 7), float32], + param_device_types=[1, 1, 1, 1], result_device_type=2) { + %0 = add(%a, %b); + %1 = on_device(%0, device_type=1, is_fixed=True); + %2 = add(%c, %d); + %3 = on_device(%2, device_type=1, is_fixed=True); + %4 = device_copy(%1, src_dev_type=1, dst_dev_type=2); + %5 = device_copy(%3, src_dev_type=1, dst_dev_type=2); + subtract(%4, %5) + } + """ + ) + + def ref(a, b, c, d): + return np.subtract(np.add(a, b), np.add(c, d)) + + exercise(input(), expected(), ref, rands((5, 7), 4)) + + +def test_sharing(): + # The same add sub-expression is annotated twice. + def input(): + return tvm.parser.fromtext( + """ + #[version = "0.0.5"] + def @main(%a: Tensor[(5, 7), float32], %b: Tensor[(5, 7), float32]) { + %0 = add(%a, %b); + %1 = on_device(%0, device_type=1); + %2 = on_device(%0, device_type=1); + subtract(%1, %2) + } + """ + ) + + def expected(): + return tvm.parser.fromtext( + """ + #[version = "0.0.5"] + def @main(%a: Tensor[(5, 7), float32], %b: Tensor[(5, 7), float32], + param_device_types=[1, 1], result_device_type=2) { + %0 = add(%a, %b); + %1 = on_device(%0, device_type=1, is_fixed=True); + %2 = on_device(%0, device_type=1, is_fixed=True); + %3 = device_copy(%1, src_dev_type=1, dst_dev_type=2); + %4 = device_copy(%2, src_dev_type=1, dst_dev_type=2); + subtract(%3, %4) + } + """ + ) + + def ref(a, b): + x = np.add(a, b) + return np.subtract(x, x) + + exercise(input(), expected(), ref, rands((5, 7), 2)) + + +def test_let_on_cpu(): + # The device for a let-bound expression can flow from uses of the let-bound var. 
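+    # Below, the annotation on the use %0 of %l pins the let-bound add(%a, %b) to the CPU,
+    # while %r stays on the default GPU and a device_copy bridges into the subtract.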
+ def input(): + return tvm.parser.fromtext( + """ + #[version = "0.0.5"] + def @main(%a: Tensor[(5, 7), float32], %b: Tensor[(5, 7), float32], + %c: Tensor[(5, 7), float32], %d: Tensor[(5, 7), float32]) { + let %l = add(%a, %b); + let %r = add(%c, %d); + %0 = on_device(%l, device_type=1); + subtract(%0, %r) + } + """ + ) + + def expected(): + return tvm.parser.fromtext( + """ + #[version = "0.0.5"] + def @main(%a: Tensor[(5, 7), float32], %b: Tensor[(5, 7), float32], + %c: Tensor[(5, 7), float32], %d: Tensor[(5, 7), float32], + param_device_types=[1, 1, 2, 2], result_device_type=2) { + %0 = add(%a, %b); + let %l = on_device(%0, device_type=1, is_fixed=True); + let %r = add(%c, %d); + %1 = device_copy(%l, src_dev_type=1, dst_dev_type=2); + subtract(%1, %r) + } + """ + ) + + def ref(a, b, c, d): + return np.subtract(np.add(a, b), np.add(c, d)) + + exercise(input(), expected(), ref, rands((5, 7), 4)) + + +def test_func_param_on_cpu(): + # Devices for function parameters flow to call sites. + def input(): + return tvm.parser.fromtext( + """ + #[version = "0.0.5"] + def @main(%a: Tensor[(5, 7), float32], %b: Tensor[(5, 7), float32], + %c: Tensor[(5, 7), float32], %d: Tensor[(5, 7), float32]) { + let %f = fn (%x, %y) { + %0 = add(%x, %y); + on_device(%0, device_type=1) + }; + %1 = %f(%a, %b); + %2 = add(%c, %d); + subtract(%1, %2) + } + """ + ) + + def expected(): + return tvm.parser.fromtext( + """ + #[version = "0.0.5"] + def @main(%a: Tensor[(5, 7), float32], %b: Tensor[(5, 7), float32], + %c: Tensor[(5, 7), float32], %d: Tensor[(5, 7), float32], + param_device_types=[1, 1, 1, 1], result_device_type=1) { + let %f = fn (%x, %y, param_device_types=[1, 1], result_device_type=1) { + add(%x, %y) + }; + %0 = %f(%a, %b); + %1 = add(%c, %d); + subtract(%0, %1) + } + """ + ) + + def ref(a, b, c, d): + return np.subtract(np.add(a, b), np.add(c, d)) + + exercise(input(), expected(), ref, rands((5, 7), 4)) + + +def test_func_result_on_cpu(): + # Devices for call sites flow to function results. 
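+    # Below, the annotation on %f's call result forces the local function's result (and hence
+    # its body and parameters) onto the CPU, so @main's %a and %b end up CPU parameters too.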
+ def input(): + return tvm.parser.fromtext( + """ + #[version = "0.0.5"] + def @main(%a: Tensor[(5, 7), float32], %b: Tensor[(5, 7), float32], + %c: Tensor[(5, 7), float32], %d: Tensor[(5, 7), float32]) { + let %f = fn (%x, %y) { + add(%x, %y) + }; + %0 = %f(%a, %b); + %1 = on_device(%0, device_type=1); + %2 = add(%c, %d); + subtract(%1, %2) + } + """ + ) + + def expected(): + return tvm.parser.fromtext( + """ + #[version = "0.0.5"] + def @main(%a: Tensor[(5, 7), float32], %b: Tensor[(5, 7), float32], + %c: Tensor[(5, 7), float32], %d: Tensor[(5, 7), float32], + param_device_types=[1, 1, 2, 2], result_device_type=2) { + %0 = fn (%x, %y, param_device_types=[1, 1], result_device_type=1) { + add(%x, %y) + }; + let %f = on_device(%0, device_type=1, is_fixed=True); + %1 = %f(%a, %b); + %2 = on_device(%1, device_type=1, is_fixed=True); + %3 = device_copy(%2, src_dev_type=1, dst_dev_type=2); + %4 = add(%c, %d); + subtract(%3, %4) + } + """ + ) + + def ref(a, b, c, d): + return np.subtract(np.add(a, b), np.add(c, d)) + + exercise(input(), expected(), ref, rands((5, 7), 4)) + + +def test_higher_order(): + # The constraint on %a flows back to %y via %f and %h + def input(): + return tvm.parser.fromtext( + """ + #[version = "0.0.5"] + def @main(%x: Tensor[(5, 7), float32], %y: Tensor[(5, 7), float32]) { + let %f = fn (%g) { + fn (%a) { + %0 = on_device(%a, device_type=1); + %1 = %g(%0); + add(%1, %x) + } + }; + let %h = fn (%b) { + negative(%b) + }; + %2 = %f(%h); + %3 = %2(%y); + subtract(%x, %3) + } + """ + ) + + def expected(): + return tvm.parser.fromtext( + """ + #[version = "0.0.5"] + def @main(%x: Tensor[(5, 7), float32], %y: Tensor[(5, 7), float32], + param_device_types=[2, 1], result_device_type=2) { + let %f = fn (%g, param_device_types=[2], result_device_type=2) { + fn (%a, param_device_types=[1], result_device_type=2) { + %0 = device_copy(%a, src_dev_type=1, dst_dev_type=2); + %1 = %g(%0); + add(%1, %x) + } + }; + let %h = fn (%b, param_device_types=[2], result_device_type=2) { + negative(%b) + }; + %2 = %f(%h); + %3 = %2(%y); + subtract(%x, %3) + } + """ + ) + + def ref(x, y): + def f(g): + return lambda a: np.add(g(a), x) + + def h(b): + return np.negative(b) + + return np.subtract(x, f(h)(y)) + + exercise(input(), expected(), ref, rands((5, 7), 2)) + + +def test_function_in_tuple(): + # Since %f ends up in a tuple its argument and result is forced to be on the CPU + def input(): + return tvm.parser.fromtext( + """ + #[version = "0.0.5"] + def @main(%x: Tensor[(5, 7), float32], %y: Tensor[(5, 7), float32]) { + let %f = fn (%a: Tensor[(5, 7), float32], %b: Tensor[(5, 7), float32]) { + %0 = on_device(%b, device_type=1); + add(%a, %0) + }; + let %t = (%f, %x); + %1 = %t.1; + %2 = %t.0; + %2(%1, %y) + } + """ + ) + + def expected(): + return tvm.parser.fromtext( + """ + #[version = "0.0.5"] + def @main(%x: Tensor[(5, 7), float32], %y: Tensor[(5, 7), float32], + param_device_types=[1, 1], result_device_type=1) { + let %f = fn (%a: Tensor[(5, 7), float32], %b: Tensor[(5, 7), float32], + param_device_types=[1, 1], result_device_type=1) { + add(%a, %b) + }; + let %t = (%f, %x); + %0 = %t.1; + %1 = %t.0; + %1(%0, %y) + } + """ + ) + + def ref(x, y): + return np.add(x, y) + + exercise(input(), expected(), ref, rands((5, 7), 2)) + + +def test_device_copy(): + const = rand((5, 7)) + metatable = {"relay.Constant": [relay.const(const)]} + + def input(): + return tvm.parser.parse( + """ + #[version = "0.0.5"] + def @main(%x: Tensor[(5, 7), float32]) { + %0 = device_copy(%x, src_dev_type=1, 
dst_dev_type=2); + add(%0, meta[relay.Constant][0]) + } + """, + "from_string", + None, + metatable, + ) + + def expected(): + return tvm.parser.parse( + """ + #[version = "0.0.5"] + def @main(%x: Tensor[(5, 7), float32], param_device_types=[1], result_device_type=2) { + %0 = device_copy(%x, src_dev_type=1, dst_dev_type=2); + add(%0, meta[relay.Constant][0]) + } + """, + "from_string", + None, + metatable, + ) + + def ref(x): + return np.add(x, const) + + exercise(input(), expected(), ref, rands((5, 7), 1)) + + +def test_shape_func(): + def input(): + return tvm.parser.fromtext( + """ + #[version = "0.0.5"] + def @main(%x: Tensor[(?), float32], %s: Tensor[(1), int64]) { + %0 = fn (%y: Tensor[(?), float32]) { + nn.relu(%y) + }; + let %p = on_device(%0, device_type=2, is_fixed=True); + %1 = on_device(%x, device_type=2, is_fixed=True); + %2 = vm.shape_of(%1, dtype="int64"); + %3 = (%2,); + %4 = (%s,); + vm.shape_func(%p, %3, %4, is_input=[False]) + } + """ + ) + + def expected(): + return tvm.parser.fromtext( + """ + #[version = "0.0.5"] + def @main(%x: Tensor[(?), float32], %s: Tensor[(1), int64], + param_device_types=[2, 1], result_device_type=1) { + %0 = fn (%y: Tensor[(?), float32], param_device_types=[2], result_device_type=2) { + nn.relu(%y) + }; + let %p = on_device(%0, device_type=2, is_fixed=True); + %1 = vm.shape_of(%x, dtype="int64"); + %2 = (%1,); + %3 = (%s,); + vm.shape_func(%p, %2, %3, is_input=[False]) + } + """ + ) + + # Don't try to execute, too fiddly to setup. + exercise(input(), expected(), None, None) + + +def test_shape_of(): + # We need to use is_fixed=True in the on_device call so that the tensor will be on the GPU. Otherwise the + # result defaults to the result device for @main which is the CPU, thus forcing a copy. + # TODO(mbs): Perhaps the defaulting heuristics are being too clever? + def input(): + return tvm.parser.fromtext( + """ + #[version = "0.0.5"] + def @main(%x: Tensor[(?, ?), float32]) { + %0 = on_device(%x, device_type=2, is_fixed=True); + vm.shape_of(%0, dtype="int64") + } + """ + ) + + def expected(): + return tvm.parser.fromtext( + """ + #[version = "0.0.5"] + def @main(%x: Tensor[(?, ?), float32], param_device_types=[2], result_device_type=1) { + vm.shape_of(%x, dtype="int64") + } + """ + ) + + def ref(x): + return x.shape + + exercise(input(), expected(), ref, rands((5, 7), 1)) + + +def test_alloc_storage(): + def input(): + return tvm.parser.parse( + """ + #[version = "0.0.5"] + def @main(%size: int64, %alignment: int64) { + memory.alloc_storage(%size, %alignment, device_id=0, device_type=2) + } + """, + "from_string", + core, + ) + + def expected(): + return tvm.parser.parse( + """ + #[version = "0.0.5"] + def @main(%size: int64, %alignment: int64, param_device_types=[1, 1], result_device_type=2) { + memory.alloc_storage(%size, %alignment, device_id=0, device_type=2) + } + """, + "from_string", + core, + ) + + # Don't try to execute, too fiddly to setup. 
+ exercise(input(), expected(), None, None)
+
+
+def test_alloc_tensor():
+ shape = np.array([3, 2])
+ metatable = {"relay.Constant": [relay.const(shape, dtype="int64")]}
+
+ def input():
+ return tvm.parser.parse(
+ """
+ #[version = "0.0.5"]
+ def @main(%sto: Storage[]) {
+ memory.alloc_tensor(%sto, 0, meta[relay.Constant][0],
+ const_shape=meta[relay.Constant][0], assert_shape=[])
+ }
+ """,
+ "from_string",
+ core,
+ metatable,
+ )
+
+ def expected():
+ return tvm.parser.parse(
+ """
+ #[version = "0.0.5"]
+ def @main(%sto: Storage[], param_device_types=[2], result_device_type=2) {
+ %0 = on_device(0, device_type=1, is_fixed=True);
+ %1 = on_device(meta[relay.Constant][0], device_type=1, is_fixed=True);
+ memory.alloc_tensor(%sto, %0, %1, const_shape=meta[relay.Constant][0], assert_shape=[])
+ }
+ """,
+ "from_string",
+ core,
+ metatable,
+ )
+
+ # Don't try to execute, too fiddly to set up.
+ exercise(input(), expected(), None, None)
+
+
+def test_reshape_tensor():
+ newshape = [2, 4, 2]
+ metatable = {"relay.Constant": [relay.const(newshape, dtype="int64")]}
+
+ def input():
+ return tvm.parser.parse(
+ """
+ #[version = "0.0.5"]
+ def @main(%x: Tensor[(2, 8), float32]) {
+ vm.reshape_tensor(%x, meta[relay.Constant][0], newshape=[2, 4, 2])
+ }
+ """,
+ "from_string",
+ None,
+ metatable,
+ )
+
+ def expected():
+ return tvm.parser.parse(
+ """
+ #[version = "0.0.5"]
+ def @main(%x: Tensor[(2, 8), float32], param_device_types=[2], result_device_type=2) {
+ %0 = on_device(meta[relay.Constant][0], device_type=1, is_fixed=True);
+ vm.reshape_tensor(%x, %0, newshape=[2, 4, 2])
+ }
+ """,
+ "from_string",
+ None,
+ metatable,
+ )
+
+ def ref(x):
+ return np.reshape(x, newshape)
+
+ exercise(input(), expected(), ref, rands((2, 8), 1))
+
+
+def test_dynamic_input():
+ # There's nothing special about inferring devices for partially unknown types.
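+ # For orientation, the machinery behind `exercise` can be approximated directly.
+ # A minimal sketch (assuming the helpers defined at the top of this file, and
+ # that PlanDevices takes the default tvm.runtime.Device, whose device_id is
+ # ignored):
+ #
+ #   mod = tvm.relay.transform.InferType()(input())
+ #   mod = tvm.relay.transform.PlanDevices(tvm.device("cuda"))(mod)
+ #   tvm.ir.assert_structural_equal(mod, tvm.relay.transform.InferType()(expected()))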
+ def input():
+ return tvm.parser.fromtext(
+ """
+ #[version = "0.0.5"]
+ def @main(%x0: Tensor[(?, ?), float32], %x1: Tensor[(?, ?), float32]) {
+ add(%x0, %x1)
+ }
+ """
+ )
+
+ def expected():
+ return tvm.parser.fromtext(
+ """
+ #[version = "0.0.5"]
+ def @main(%x0: Tensor[(?, ?), float32], %x1: Tensor[(?, ?), float32],
+ param_device_types=[2, 2], result_device_type=2) {
+ add(%x0, %x1)
+ }
+ """
+ )
+
+ def ref(x0, x1):
+ return np.add(x0, x1)
+
+ exercise(input(), expected(), ref, rands((5, 7), 2))
+
+
+def test_redundant_annotation():
+ def input():
+ return tvm.parser.fromtext(
+ """
+ #[version = "0.0.5"]
+ def @main(%x: Tensor[(5, 7), float32], %y: Tensor[(5, 7), float32], %z: Tensor[(5, 7), float32]) {
+ %0 = add(%x, %y);
+ %1 = on_device(%0, device_type=1);
+ %2 = subtract(%1, %z);
+ %3 = on_device(%0, device_type=1);
+ add(%2, %3)
+ }
+ """
+ )
+
+ def expected():
+ return tvm.parser.fromtext(
+ """
+ #[version = "0.0.5"]
+ def @main(%x: Tensor[(5, 7), float32], %y: Tensor[(5, 7), float32], %z: Tensor[(5, 7), float32],
+ param_device_types=[1, 1, 2], result_device_type=2) {
+ %0 = add(%x, %y);
+ %1 = on_device(%0, device_type=1, is_fixed=True);
+ %2 = device_copy(%1, src_dev_type=1, dst_dev_type=2);
+ %3 = on_device(%0, device_type=1, is_fixed=True);
+ %4 = subtract(%2, %z);
+ %5 = device_copy(%3, src_dev_type=1, dst_dev_type=2);
+ add(%4, %5)
+ }
+ """
+ )
+
+ def ref(x, y, z):
+ a = np.add(x, y)
+ return np.add(np.subtract(a, z), a)
+
+ exercise(input(), expected(), ref, rands((5, 7), 3))
+
+
+def test_annotate_expr():
+ def input():
+ return tvm.parser.fromtext(
+ """
+ #[version = "0.0.5"]
+ def @main(%x: Tensor[(5, 7), float32], %y: Tensor[(5, 7), float32], %z: Tensor[(5, 7), float32]) {
+ %0 = add(%x, %y);
+ %1 = on_device(%0, device_type=2);
+ %2 = subtract(%1, %z);
+ on_device(%2, device_type=1)
+ }
+ """
+ )
+
+ def expected():
+ return tvm.parser.fromtext(
+ """
+ #[version = "0.0.5"]
+ def @main(%x: Tensor[(5, 7), float32], %y: Tensor[(5, 7), float32], %z: Tensor[(5, 7), float32],
+ param_device_types=[2, 2, 1], result_device_type=1) {
+ %0 = add(%x, %y);
+ %1 = on_device(%0, device_type=2, is_fixed=True);
+ %2 = device_copy(%1, src_dev_type=2, dst_dev_type=1);
+ subtract(%2, %z)
+ }
+ """
+ )
+
+ def ref(x, y, z):
+ return np.subtract(np.add(x, y), z)
+
+ exercise(input(), expected(), ref, rands((5, 7), 3))
+
+
+def test_annotate_all():
+ def input():
+ return tvm.parser.fromtext(
+ """
+ #[version = "0.0.5"]
+ def @main(%x: Tensor[(5, 7), float32], %y: Tensor[(5, 7), float32], %z: Tensor[(5, 7), float32]) {
+ %0 = add(%x, %y);
+ %1 = on_device(%0, device_type=1);
+ %2 = subtract(%1, %z);
+ on_device(%2, device_type=1)
+ }
+ """
+ )
+
+ def expected():
+ return tvm.parser.fromtext(
+ """
+ #[version = "0.0.5"]
+ def @main(%x: Tensor[(5, 7), float32], %y: Tensor[(5, 7), float32], %z: Tensor[(5, 7), float32],
+ param_device_types=[1, 1, 1], result_device_type=1) {
+ %0 = add(%x, %y);
+ subtract(%0, %z)
+ }
+ """
+ )
+
+ def ref(x, y, z):
+ return np.subtract(np.add(x, y), z)
+
+ exercise(input(), expected(), ref, rands((5, 7), 3))
+
+
+def test_conv_network():
+ r"""The network and devices are as follows:
+ data1 data2 <--- CPU
+ | |
+ conv2d conv2d <--- CPU
+ \ /
+ \ /
+ add <--- GPU
+ |
+ conv2d <--- CPU
+ |
+ <--- CPU
+ """
+
+ def input():
+ return tvm.parser.fromtext(
+ """
+ #[version = "0.0.5"]
+ def @main(%data1: Tensor[(1, 64, 56, 56), float32], %data2: Tensor[(1, 64, 56, 56), float32],
+ %weight: Tensor[(64, 64, 3, 3), float32]) {
+ %0 = nn.conv2d(%data1, %weight, padding=[1, 1, 1, 1], channels=64, kernel_size=[3, 3]);
+ %1 = nn.conv2d(%data2, %weight, padding=[1, 1, 1, 1], channels=64, kernel_size=[3, 3]);
+ %2 = on_device(%0, device_type=1);
+ %3 = on_device(%1, device_type=1);
+ %4 = add(%2, %3);
+ %5 = on_device(%4, device_type=2);
+ %6 = nn.conv2d(%5, %weight, padding=[1, 1, 1, 1], channels=64, kernel_size=[3, 3]);
+ on_device(%6, device_type=1)
+ }
+ """
+ )
+
+ def expected():
+ return tvm.parser.fromtext(
+ """
+ #[version = "0.0.5"]
+ def @main(%data1: Tensor[(1, 64, 56, 56), float32], %data2: Tensor[(1, 64, 56, 56), float32],
+ %weight: Tensor[(64, 64, 3, 3), float32], param_device_types=[1, 1, 1], result_device_type=1) {
+ %0 = nn.conv2d(%data1, %weight, padding=[1, 1, 1, 1], channels=64, kernel_size=[3, 3]);
+ %1 = on_device(%0, device_type=1, is_fixed=True);
+ %2 = nn.conv2d(%data2, %weight, padding=[1, 1, 1, 1], channels=64, kernel_size=[3, 3]);
+ %3 = on_device(%2, device_type=1, is_fixed=True);
+ %4 = device_copy(%1, src_dev_type=1, dst_dev_type=2);
+ %5 = device_copy(%3, src_dev_type=1, dst_dev_type=2);
+ %6 = add(%4, %5);
+ %7 = on_device(%6, device_type=2, is_fixed=True);
+ %8 = device_copy(%7, src_dev_type=2, dst_dev_type=1);
+ nn.conv2d(%8, %weight, padding=[1, 1, 1, 1], channels=64, kernel_size=[3, 3])
+ }
+ """
+ )
+
+ # Don't try to execute; we don't have a reference conv2d.
+ exercise(input(), expected(), None, None)
+
+
+def test_tuple_get_item():
+ # Note that the device copy should be placed after projection rather than before. This is handled by
+ # a heuristic in the pass.
+ def input():
+ return tvm.parser.fromtext(
+ """
+ #[version = "0.0.5"]
+ def @main(%x: Tensor[(3, 3, 4), float32]) {
+ let %t = split(%x, indices_or_sections=3);
+ %0 = on_device(%t, device_type=1);
+ %1 = on_device(%t, device_type=1);
+ %2 = %0.0;
+ %3 = %1.1;
+ %4 = subtract(%2, %3);
+ on_device(%4, device_type=2)
+ }
+ """
+ )
+
+ def expected():
+ return tvm.parser.fromtext(
+ """
+ #[version = "0.0.5"]
+ def @main(%x: Tensor[(3, 3, 4), float32], param_device_types=[1], result_device_type=2) {
+ %0 = split(%x, indices_or_sections=3);
+ let %t = on_device(%0, device_type=1, is_fixed=True);
+ %1 = %t.0;
+ %2 = on_device(%1, device_type=1, is_fixed=True);
+ %3 = %t.1;
+ %4 = on_device(%3, device_type=1, is_fixed=True);
+ %5 = device_copy(%2, src_dev_type=1, dst_dev_type=2);
+ %6 = device_copy(%4, src_dev_type=1, dst_dev_type=2);
+ subtract(%5, %6)
+ }
+ """
+ )
+
+ def ref(x):
+ t = np.split(x, 3)
+ return np.subtract(t[0], t[1])
+
+ exercise(input(), expected(), ref, rands((3, 3, 4), 1))
+
+
+def test_propagation():
+ r"""The network and devices are as follows:
+ x <--- CPU
+ |
+ log <--- CPU
+ / \
+ log2 log10 <--- GPU
+ \ /
+ add <--- GPU
+ |
+ tan <--- CPU
+ |
+ <--- CPU
+ """
+
+ def input():
+ return tvm.parser.fromtext(
+ """
+ #[version = "0.0.5"]
+ def @main(%x: Tensor[(5, 7), float32]) {
+ %0 = log(%x);
+ %1 = on_device(%0, device_type=1);
+ %2 = log2(%1);
+ %3 = on_device(%0, device_type=1);
+ %4 = log10(%3);
+ %5 = on_device(%2, device_type=2);
+ %6 = on_device(%4, device_type=2);
+ %7 = add(%5, %6);
+ %8 = on_device(%7, device_type=2);
+ %9 = tan(%8);
+ on_device(%9, device_type=1)
+ }
+ """
+ )
+
+ def expected():
+ return tvm.parser.fromtext(
+ """
+ #[version = "0.0.5"]
+ def @main(%x: Tensor[(5, 7), float32], param_device_types=[1], result_device_type=1) {
+ %0 = log(%x);
+ %1 = on_device(%0, device_type=1, is_fixed=True);
+ %2 = device_copy(%1, src_dev_type=1, dst_dev_type=2);
+ %3 = on_device(%0, device_type=1, is_fixed=True);
+ %4 = device_copy(%3, src_dev_type=1, dst_dev_type=2);
+ %5 = log2(%2);
+ %6 = log10(%4);
+ %7 = add(%5, %6);
+ %8 = on_device(%7, device_type=2, is_fixed=True);
+ %9 = device_copy(%8, src_dev_type=2, dst_dev_type=1);
+ tan(%9)
+ }
+ """
+ )
+
+ def ref(x):
+ y = np.log(x)
+ return np.tan(np.add(np.log2(y), np.log10(y)))
+
+ exercise(input(), expected(), ref, rands((5, 7), 1))
+
+
+def test_fusible_network():
+ r"""The network is as follows:
+ x y <--- GPU
+ \ /
+ add <--- GPU
+ / \
+ negative \ <--- CPU
+ \ \
+ \ negative <--- GPU
+ \ /
+ add <--- GPU
+ |
+ negative <--- CPU
+ |
+ <--- CPU
+ """
+
+ def input():
+ return tvm.parser.fromtext(
+ """
+ #[version = "0.0.5"]
+ def @main(%x: Tensor[(5, 7), float32], %y: Tensor[(5, 7), float32]) {
+ %0 = add(%x, %y);
+ %1 = on_device(%0, device_type=2);
+ %2 = negative(%1);
+ %3 = on_device(%2, device_type=1);
+ %4 = negative(%0);
+ %5 = add(%3, %4);
+ %6 = on_device(%5, device_type=2);
+ %7 = negative(%6);
+ on_device(%7, device_type=1)
+ }
+ """
+ )
+
+ def expected():
+ return tvm.parser.fromtext(
+ """
+ #[version = "0.0.5"]
+ def @main(%x: Tensor[(5, 7), float32], %y: Tensor[(5, 7), float32], param_device_types=[2, 2], result_device_type=1) {
+ %0 = add(%x, %y);
+ %1 = on_device(%0, device_type=2, is_fixed=True);
+ %2 = device_copy(%1, src_dev_type=2, dst_dev_type=1);
+ %3 = negative(%2);
+ %4 = on_device(%3, device_type=1, is_fixed=True);
+ %5 = device_copy(%4, src_dev_type=1, dst_dev_type=2);
+ %6 = negative(%0);
+ %7 = add(%5, %6);
+ %8 = on_device(%7, device_type=2, is_fixed=True);
+ %9 = device_copy(%8, src_dev_type=2, dst_dev_type=1);
+ negative(%9)
+ }
+ """
+ )
+
+ def ref(x, y):
+ z = np.add(x, y)
+ return np.negative(np.add(np.negative(z), np.negative(z)))
+
+ exercise(input(), expected(), ref, rands((5, 7), 2))
+
+
+def test_unpropagatable_graph():
+ r"""The network is as follows:
+ a b <--- CPU
+ \ /
+ \ / c d <--- GPU
+ \ / \ /
+ add \ / <--- CPU
+ \ \ /
+ \ multiply <--- GPU
+ \ /
+ subtract <--- CPU
+ |
+ <--- CPU
+ """
+
+ def input():
+ return tvm.parser.fromtext(
+ """
+ #[version = "0.0.5"]
+ def @main(%a: Tensor[(5, 7), float32], %b: Tensor[(5, 7), float32],
+ %c: Tensor[(5, 7), float32], %d: Tensor[(5, 7), float32]) {
+ %0 = add(%a, %b);
+ %1 = multiply(%c, %d);
+ %2 = on_device(%0, device_type=1);
+ %3 = on_device(%1, device_type=2);
+ %4 = subtract(%2, %3);
+ on_device(%4, device_type=1)
+ }
+ """
+ )
+
+ def expected():
+ return tvm.parser.fromtext(
+ """
+ #[version = "0.0.5"]
+ def @main(%a: Tensor[(5, 7), float32], %b: Tensor[(5, 7), float32],
+ %c: Tensor[(5, 7), float32], %d: Tensor[(5, 7), float32],
+ param_device_types=[1, 1, 2, 2], result_device_type=1) {
+ %0 = multiply(%c, %d);
+ %1 = on_device(%0, device_type=2, is_fixed=True);
+ %2 = add(%a, %b);
+ %3 = device_copy(%1, src_dev_type=2, dst_dev_type=1);
+ subtract(%2, %3)
+ }
+ """
+ )
+
+ def ref(a, b, c, d):
+ return np.subtract(np.add(a, b), np.multiply(c, d))
+
+ exercise(input(), expected(), ref, rands((5, 7), 4))
+
+
+def test_conditional():
+ # The conditional is over a function type, thus exercising the first-order/higher-order domain handling.
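+ # Both arms of the `if` must inhabit the same function domain, so %f and %g
+ # receive identical param_device_types/result_device_type below, and with
+ # everything resolved to the CPU the is_fixed=True annotation on %y inside %f
+ # is discharged without needing a device_copy.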
+ def input():
+ return tvm.parser.fromtext(
+ """
+ #[version = "0.0.5"]
+ def @main(%x: bool, %y: Tensor[(5, 7), float32], %z: Tensor[(5, 7), float32]) {
+ let %f = fn (%a) {
+ %0 = on_device(%y, device_type=1, is_fixed=True);
+ add(%a, %0)
+ };
+ let %g = fn (%a1) {
+ subtract(%a1, %y)
+ };
+ let %h = if (%x) {
+ %f
+ } else {
+ %g
+ };
+ %h(%z)
+ }
+ """
+ )
+
+ def expected():
+ return tvm.parser.fromtext(
+ """
+ #[version = "0.0.5"]
+ def @main(%x: bool, %y: Tensor[(5, 7), float32], %z: Tensor[(5, 7), float32],
+ param_device_types=[1, 1, 1], result_device_type=1) {
+ let %f = fn (%a, param_device_types=[1], result_device_type=1) {
+ add(%a, %y)
+ };
+ let %g = fn (%a1, param_device_types=[1], result_device_type=1) {
+ subtract(%a1, %y)
+ };
+ let %h = if (%x) {
+ %f
+ } else {
+ %g
+ };
+ %h(%z)
+ }
+ """
+ )
+
+ def ref(x, y, z):
+ def f(a):
+ return np.add(a, y)
+
+ def g(a):
+ return np.subtract(a, y)
+
+ h = f if x else g
+ return h(z)
+
+ exercise(input(), expected(), ref, [True, rand((5, 7)), rand((5, 7))])
+
+
+def test_global():
+ def input():
+ return tvm.parser.fromtext(
+ """
+ #[version = "0.0.5"]
+ def @f(%a: Tensor[(5, 7), float32], %b: Tensor[(5, 7), float32]) -> Tensor[(5, 7), float32] {
+ %0 = on_device(%b, device_type=1);
+ add(%a, %0)
+ }
+
+ def @main(%x: Tensor[(5, 7), float32], %y: Tensor[(5, 7), float32]) -> Tensor[(5, 7), float32] {
+ @f(%y, %x)
+ }
+ """
+ )
+
+ def expected():
+ return tvm.parser.fromtext(
+ """
+ #[version = "0.0.5"]
+ def @f(%a: Tensor[(5, 7), float32], %b: Tensor[(5, 7), float32],
+ param_device_types=[2, 1], result_device_type=2) -> Tensor[(5, 7), float32] {
+ %0 = device_copy(%b, src_dev_type=1, dst_dev_type=2);
+ add(%a, %0)
+ }
+
+ def @main(%x: Tensor[(5, 7), float32], %y: Tensor[(5, 7), float32],
+ param_device_types=[1, 2], result_device_type=2) -> Tensor[(5, 7), float32] {
+ @f(%y, %x)
+ }
+ """
+ )
+
+ def ref(x, y):
+ def f(a, b):
+ return np.add(a, b)
+
+ return f(x, y)
+
+ exercise(input(), expected(), ref, rands((5, 7), 2))
+
+
+def test_ref():
+ def input():
+ return tvm.parser.fromtext(
+ """
+ #[version = "0.0.5"]
+ def @main(%x: Tensor[(5, 7), float32], %y: Tensor[(5, 7), float32]) {
+ let %r = ref(%x);
+ %0 = on_device(%y, device_type=1);
+ ref_write(%r, %0);
+ %1 = ref_read(%r);
+ add(%x, %1)
+ }
+ """
+ )
+
+ def expected():
+ return tvm.parser.fromtext(
+ """
+ #[version = "0.0.5"]
+ def @main(%x: Tensor[(5, 7), float32], %y: Tensor[(5, 7), float32],
+ param_device_types=[2, 1], result_device_type=2) {
+ let %r = ref(%x);
+ %0 = device_copy(%y, src_dev_type=1, dst_dev_type=2);
+ ref_write(%r, %0);
+ %1 = ref_read(%r);
+ add(%x, %1)
+ }
+ """
+ )
+
+ def ref(x, y):
+ r = {"value": x}
+ r["value"] = y
+ return np.add(x, r["value"])
+
+ # Don't try to execute, no backend currently supports both heterogeneous devices and references.
+ exercise(input(), expected(), None, None)
+
+
+def test_adt():
+ def input():
+ return tvm.parser.fromtext(
+ """
+ #[version = "0.0.5"]
+ type List[A] {
+ Cons(A, List[A]),
+ Nil,
+ }
+ def @main(%x : Tensor[(5, 7), float32], %y : Tensor[(5, 7), float32]) {
+ %0 = on_device(%y, device_type=1, is_fixed=True);
+ %1 = Nil;
+ %2 = Cons(%0, %1);
+ let %l = Cons(%x, %2);
+ match? (%l) {
+ Cons(%z, _) => %z
+ }
+ }
+ """
+ )
+
+ def expected():
+ return tvm.parser.fromtext(
+ """
+ #[version = "0.0.5"]
+ type List[A] {
+ Cons(A, List[A]),
+ Nil,
+ }
+ def @main(%x : Tensor[(5, 7), float32], %y : Tensor[(5, 7), float32],
+ param_device_types=[1, 1], result_device_type=1) {
+ %0 = Nil;
+ %1 = Cons(%y, %0);
+ let %l = Cons(%x, %1);
+ match? (%l) {
+ Cons(%z, _) => %z
+ }
+ }
+ """
+ )
+
+ def ref(x, y):
+ l = [x, y]
+ return l[0]
+
+ exercise(input(), expected(), ref, rands((5, 7), 2))
+
+
+if __name__ == "__main__":
+ import sys
+ import pytest
+
+ sys.exit(pytest.main([__file__] + sys.argv[1:]))