Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[RELAY] [VIRTUALDEVICE] Change syntax for device planning and store parameter virtual devices in virtual_device_ field #10352

Merged
merged 32 commits into from
Feb 25, 2022
Merged
Show file tree
Hide file tree
Changes from 27 commits
Commits
Show all changes
32 commits
Select commit Hold shift + click to select a range
dc6053b
parent 33082e0032fb57b0516ad7e3eabd11fe0203437e
electriclilies Jan 25, 2022
69acc6d
Change plan devices tests to use the new syntax for function parameters
electriclilies Feb 16, 2022
a92ab15
Fix free var problem
electriclilies Feb 18, 2022
30fff16
Fix attribute parsing if there is virtual device; most device plannin…
electriclilies Feb 18, 2022
167e4fd
fixed lambda lifting
electriclilies Feb 19, 2022
e3a2ec2
Debugging high order functions -- right now FunctionOnDevice and Bind…
electriclilies Feb 19, 2022
a95ae9d
tests pass wootgit status
electriclilies Feb 19, 2022
e62bce8
Remove FunctionOnDevice from device planner
electriclilies Feb 22, 2022
904b93a
Don't use MaybeFunctionOnDevice in VM compiler
electriclilies Feb 22, 2022
5224b64
Remove MaybeFunctionOnDevice from lambda lifter
electriclilies Feb 22, 2022
fd97879
Delete FunctionOnDevice and MaybeFunctionOnDevice!
electriclilies Feb 22, 2022
3236a2d
Reomve GetFunctionResultVirtualDevice
electriclilies Feb 22, 2022
e269437
Remove GetFunctionParamVirtualDevice
electriclilies Feb 22, 2022
e642a5a
lint
electriclilies Feb 22, 2022
6c6e5e5
lint
electriclilies Feb 22, 2022
a344e15
Python formatting
electriclilies Feb 22, 2022
231d040
Remove FunctionOnDevice python test
electriclilies Feb 23, 2022
0f5e6d2
Fix bug in binds & debug output
electriclilies Feb 23, 2022
f71601d
Fix text printer
electriclilies Feb 23, 2022
e8e5a09
lint
electriclilies Feb 23, 2022
ad4c97a
Remove function on device from fold constant tests
electriclilies Feb 23, 2022
78d063e
Mark nits
electriclilies Feb 24, 2022
1f10769
Revert behavior of bind
electriclilies Feb 24, 2022
b20dce7
clean up debug
electriclilies Feb 24, 2022
8bacaba
Make ExprBinder public interface and use instead of Bind
electriclilies Feb 24, 2022
41870f0
Fix lambda lift
electriclilies Feb 24, 2022
3d346d9
This is broken but not sure how to fix
electriclilies Feb 24, 2022
c4939e6
passes all device planning tests yay!
electriclilies Feb 24, 2022
001a063
Add substitution helper and use in device planner
electriclilies Feb 24, 2022
2075dad
Remove unnecessary check
electriclilies Feb 24, 2022
9dda0c1
Respond to comments
electriclilies Feb 25, 2022
5bf526a
Update comment
electriclilies Feb 25, 2022
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 0 additions & 10 deletions include/tvm/ir/function.h
Original file line number Diff line number Diff line change
Expand Up @@ -190,16 +190,6 @@ constexpr const char* kTarget = "target";
*/
constexpr const char* kGlobalSymbol = "global_symbol";

/*!
* \brief The \p VirtualDevice which will hold each of the functions parameters.
*
* Only supported on Relay \p Functions. Generally added by the \p PlanDevices pass, but
* may be included as an annotation on user programs.
*
* Type: Array<VirtualDevice>
*/
constexpr const char* kParamVirtualDevice = "param_virtual_devices";

} // namespace attr
} // namespace tvm
#endif // TVM_IR_FUNCTION_H_
18 changes: 18 additions & 0 deletions include/tvm/relay/transform.h
Original file line number Diff line number Diff line change
Expand Up @@ -499,6 +499,10 @@ TVM_DLL Pass PlanDevices(CompilationConfig config);
/*!
* \brief Bind the free variables to a Relay expression. This is a helper
* function usually called by other pass functions to help optimizations.
* If any free variables are introduced into a function, those are added
* to the functoin parameters.
* Additionally this may change the order of parameters if you map a variable
* to a variable.
*
* \param expr The input expression.
* \param binds The variable to expression map that will be used to help the
Expand All @@ -508,6 +512,20 @@ TVM_DLL Pass PlanDevices(CompilationConfig config);
*/
TVM_DLL Expr Bind(const Expr& expr, const tvm::Map<Var, Expr>& binds);

/*!
* \brief Bind the free variables to a Relay expression. This is a helper
* function usually called by other pass functions to help optimizations.
* Differs from Bind in that it does not implicitly add any new free variables
* to function parameters.
*
* \param expr The input expression.
* \param binds The variable to expression map that will be used to help the
* binding.
*
* \return The updated expression.
*/
TVM_DLL Expr ExprBind(const Expr& expr, const tvm::Map<Var, Expr>& binds);

/*!
* \brief Apply rewrite rules to rewrite the expr in post DFS order. This
* function is used as a helper function to rewrtie an expression in a pass.
Expand Down
2 changes: 1 addition & 1 deletion include/tvm/target/virtual_device.h
Original file line number Diff line number Diff line change
Expand Up @@ -367,7 +367,7 @@ class VirtualDeviceCache {
*
* Type: VirtualDevice
*/
constexpr const char* kVirtualDevice = "result_virtual_device";
constexpr const char* kVirtualDevice = "virtual_device";

} // namespace tvm

Expand Down
34 changes: 30 additions & 4 deletions src/parser/parser.cc
Original file line number Diff line number Diff line change
Expand Up @@ -450,9 +450,13 @@ class Parser {
*
* "x" -> Var("x"), these are needed to map from the raw string names
* to unique variable nodes.
* If a virtual device is specified, sets the virtual device of the variable.
*/
Var BindVar(const std::string& name, const relay::Type& type_annotation) {
Var BindVar(const std::string& name, const relay::Type& type_annotation,
Optional<VirtualDevice> virtual_device = Optional<VirtualDevice>()) {
auto var = Var(name, type_annotation);
var->virtual_device_ = virtual_device.value_or(VirtualDevice::FullyUnconstrained());
VLOG(1) << "Binding var named " << name << " to variable node " << PrettyPrint(var);
this->expr_scopes.Add(name, var);
return var;
}
Expand Down Expand Up @@ -1107,11 +1111,26 @@ class Parser {
[&]() {
auto token = Match(TokenType::kLocal);
auto string = token.ToString();

// The fake attributes where the virtual device is specified.
VirtualDevice virtual_device;
if (WhenMatch(TokenType::kLCurly)) {
Map<String, ObjectRef> fake_attrs = ParseAttrs();
VLOG(9) << "Fake attributes for function parameter: " << fake_attrs;
Match(TokenType::kRCurly);
if (fake_attrs.size() == 1 && fake_attrs.count(kVirtualDevice)) {
ICHECK(fake_attrs[kVirtualDevice].as<VirtualDeviceNode>())
<< "Expected the " << kVirtualDevice
<< " to have type VirtualDeviceNode, but got " << virtual_device->GetTypeKey();
virtual_device = Downcast<VirtualDevice>(fake_attrs[kVirtualDevice]);
}
}

Type type;
if (WhenMatch(TokenType::kColon)) {
type = ParseType();
}
return BindVar(string, type);
return BindVar(string, type, virtual_device);
},
[&] {
auto is_ident = Lookahead(1)->token_type == TokenType::kIdentifier;
Expand Down Expand Up @@ -1144,8 +1163,15 @@ class Parser {
ICHECK(vid.as<VirtualDeviceNode>())
<< "Expected the " << kVirtualDevice << " to have type VirtualDeviceNode, but got "
<< vid->GetTypeKey();
raw_attrs.erase(kVirtualDevice);
Function func = relay::Function(params, body, ret_type, generics, DictAttrs(raw_attrs));

DictAttrs attrs;
// Don't fill the raw_attrs in if there's nothing other than kVirtualDevice in the
// attributes
if (raw_attrs.size() > 1) {
raw_attrs.erase(kVirtualDevice);
attrs = DictAttrs(raw_attrs);
}
Function func = relay::Function(params, body, ret_type, generics, attrs);
func->virtual_device_ = vid;
return func;
} else {
Expand Down
7 changes: 5 additions & 2 deletions src/printer/relay_text_printer.cc
Original file line number Diff line number Diff line change
Expand Up @@ -220,9 +220,13 @@ Doc RelayTextPrinter::AllocVar(const Var& var) {
}
Doc val = GetUniqueName("%" + name);
memo_[var] = val;
if (!var->virtual_device()->IsFullyUnconstrained()) {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Any issue with this being used for both param- and let-bound vars even though we don't parse annots for let-bound vars?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I guess for now, let-bound variables don't have their virtual devices set so it theoretically won't be triggered.. I haven't seen any issues in CI related to this but I could split the function into two if you'd like
Eventually we will annotate let-bound variables and at that point we will have to parse the fake attrs for let bound variables

val << " {" << kVirtualDevice << "=" << PrintAttributeValue(var->virtual_device()) << "}";
}
if (var->type_annotation.defined()) {
val << ": " << Print(var->type_annotation);
}

val << PrintOptionalInfo(var);
return val;
}
Expand Down Expand Up @@ -445,7 +449,7 @@ Doc RelayTextPrinter::PrintFunc(const Doc& prefix, const relay::Function& fn) {
for (const Doc& d : PrintDictAttrs(fn->attrs)) {
params.push_back(d);
}
if (fn->virtual_device() != VirtualDevice::FullyUnconstrained()) {
if (!fn->virtual_device()->IsFullyUnconstrained()) {
Doc vid_doc;
vid_doc << kVirtualDevice << "=" << PrintAttributeValue(fn->virtual_device());
params.push_back(vid_doc);
Expand All @@ -454,7 +458,6 @@ Doc RelayTextPrinter::PrintFunc(const Doc& prefix, const relay::Function& fn) {
if (fn->ret_type.defined()) {
doc << "-> " << Print(fn->ret_type) << " ";
}

doc << PrintBody(fn->body);
return doc;
}
Expand Down
18 changes: 6 additions & 12 deletions src/relay/backend/vm/compiler.cc
Original file line number Diff line number Diff line change
Expand Up @@ -252,21 +252,16 @@ class VMFunctionCompiler : DeviceAwareExprFunctor<void(const Expr& n)> {
// Do that flattening on-the-fly here.
Function inner_func = Downcast<Function>(func->body);
std::vector<Var> params;
std::vector<VirtualDevice> param_virtual_devices;
params.reserve(func->params.size() + inner_func->params.size());
param_virtual_devices.reserve(func->params.size() + inner_func->params.size());
param_device_indexes.reserve(func->params.size() + inner_func->params.size());
for (size_t i = 0; i < func->params.size(); ++i) {
params.emplace_back(func->params[i]);
VirtualDevice param_virtual_device = GetFunctionParamVirtualDevice(func.get(), i);
param_virtual_devices.push_back(param_virtual_device);
param_device_indexes.push_back(GetDeviceIndex(param_virtual_device));
param_device_indexes.push_back(GetDeviceIndex(func->params[i]->virtual_device()));
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

we get some payoff at last!

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

yeah!!

}
for (size_t i = 0; i < inner_func->params.size(); ++i) {
params.emplace_back(inner_func->params[i]);
VirtualDevice param_virtual_device = GetFunctionParamVirtualDevice(inner_func.get(), i);
param_virtual_devices.push_back(param_virtual_device);
param_device_indexes.push_back(GetDeviceIndex(param_virtual_device));

param_device_indexes.push_back(GetDeviceIndex(inner_func->params[i]->virtual_device()));
}
std::vector<TypeVar> type_params;
type_params.reserve(func->type_params.size() + inner_func->type_params.size());
Expand All @@ -278,13 +273,12 @@ class VMFunctionCompiler : DeviceAwareExprFunctor<void(const Expr& n)> {
}
Function flattened_func = Function(params, inner_func->body, inner_func->ret_type,
type_params, func->attrs, func->span);
VisitExpr(MaybeFunctionOnDevice(flattened_func, param_virtual_devices,
GetFunctionResultVirtualDevice(inner_func.get())));
flattened_func->virtual_device_ = inner_func->virtual_device();
VisitExpr(flattened_func);
} else {
param_device_indexes.reserve(func->params.size());
for (size_t i = 0; i < func->params.size(); ++i) {
param_device_indexes.push_back(
GetDeviceIndex(GetFunctionParamVirtualDevice(func.get(), i)));
param_device_indexes.push_back(GetDeviceIndex(func->params[i]->virtual_device()));
}
VisitExpr(func);
}
Expand Down
9 changes: 5 additions & 4 deletions src/relay/backend/vm/lambda_lift.cc
Original file line number Diff line number Diff line change
Expand Up @@ -111,22 +111,21 @@ class LambdaLifter : public transform::DeviceAwareExprMutator {
auto free_type_vars = FreeTypeVars(func, module_);

Array<Var> captured_vars;
std::vector<VirtualDevice> captured_var_virtual_devices;
bool recursive = false;
for (const auto& var : free_vars) {
if (!letrec_.empty() && var == letrec_.back()) {
recursive = true;
continue;
}
captured_vars.push_back(var);
captured_var_virtual_devices.push_back(GetVirtualDevice(var));
}

// Freshen all the captured vars.
Array<Var> typed_captured_vars;
Map<Var, Expr> rebinding_map;
for (auto free_var : captured_vars) {
auto var = Var(free_var->name_hint(), free_var->checked_type());
var->virtual_device_ = GetVirtualDevice(free_var);
typed_captured_vars.push_back(var);
rebinding_map.Set(free_var, var);
}
Expand Down Expand Up @@ -173,6 +172,8 @@ class LambdaLifter : public transform::DeviceAwareExprMutator {
if (captured_vars.empty() && free_type_vars.empty()) {
lifted_func = Function(body->params, body->body, body->ret_type, body->type_params,
body->attrs, body->span);
// We also need to copy the virtual device
lifted_func->virtual_device_ = body->virtual_device();
} else {
// When a closure is locally bound in a program, we have its full type information
// avalible to us.
Expand All @@ -187,14 +188,14 @@ class LambdaLifter : public transform::DeviceAwareExprMutator {
// construct the "closure" function with fully annotated arguments, no longer relying
// on type inference.
size_t before_arity = body->params.size();
VLOG(9) << "Binding " << rebinding_map << " into\n" << PrettyPrint(body->body);
auto rebound_body = WithFields(func, func->params, Bind(body->body, rebinding_map));
size_t after_arity = rebound_body->params.size();
CHECK_EQ(before_arity, after_arity);
lifted_func =
Function(typed_captured_vars, rebound_body, /*ret_type=*/func->func_type_annotation(),
free_type_vars, /*attrs=*/{}, func->span);
lifted_func =
MaybeFunctionOnDevice(lifted_func, captured_var_virtual_devices, result_virtual_device);
lifted_func->virtual_device_ = result_virtual_device;
lifted_func = MarkClosure(lifted_func);
}

Expand Down
42 changes: 23 additions & 19 deletions src/relay/ir/expr_functor.cc
Original file line number Diff line number Diff line change
Expand Up @@ -472,45 +472,45 @@ class ExprBinder : public MixedModeMutator, PatternMutator {
const tvm::Map<Var, Expr>& args_map_;
};

// This function should be called SubstAndBind, since it assumes any variables introduced
// in the substitution right hand side should be implicitly bound in the function.
Expr Bind(const Expr& expr, const tvm::Map<Var, Expr>& args_map) {
if (const FunctionNode* func = expr.as<FunctionNode>()) {
Expr new_body = ExprBinder(args_map).VisitExpr(func->body);
Array<Var> new_params;
std::vector<VirtualDevice> new_param_virtual_devices;
for (size_t i = 0; i < func->params.size(); ++i) {
if (!args_map.count(func->params[i])) {
new_params.push_back(func->params[i]);
new_param_virtual_devices.push_back(GetFunctionParamVirtualDevice(func, i));
}
}
if (new_body.same_as(func->body) && new_params.size() == func->params.size()) {
return expr;
}

auto ret =
Function(new_params, new_body, func->ret_type, func->type_params, func->attrs, func->span);
ret =
MaybeFunctionOnDevice(ret, new_param_virtual_devices, GetFunctionResultVirtualDevice(func));
ret->virtual_device_ = func->virtual_device();

std::unordered_set<Var, ObjectPtrHash, ObjectPtrEqual> set;
for (const auto& v : FreeVars(expr)) {
set.insert(v);
}
for (const auto& v : FreeVars(ret)) {
if (set.count(v) == 0) {
new_params.push_back(v);
if (!GetFunctionResultVirtualDevice(func)->IsFullyUnconstrained()) {
// TODO(mbs): The function has been annotated with a device, which means we are supposed
// to be preserving device annotations on every transformation. However there's no
// such context for the free vars in args_map.
LOG(WARNING) << "introduced free var '" << PrettyPrint(v)
<< "' into function body but no device is known for it";
if (set.count(v) == 0) {
new_params.push_back(v);
if (!v->virtual_device()->IsFullyUnconstrained()) {
// TODO(mbs): The function has been annotated with a device, which means we are supposed
// to be preserving device annotations on every transformation. However there's no
// such context for the free vars in args_map.
LOG(WARNING) << "introduced free var '" << PrettyPrint(v)
<< "' into function body but no device is known for it";
}
}
}
new_param_virtual_devices.push_back(VirtualDevice::FullyUnconstrained());
}
}
ret =
Function(new_params, new_body, func->ret_type, func->type_params, func->attrs, func->span);
ret =
MaybeFunctionOnDevice(ret, new_param_virtual_devices, GetFunctionResultVirtualDevice(func));

VLOG(4) << "Expr:\n" << expr;
VLOG(4) << "Ret:\n" << ret;

ICHECK_EQ(FreeVars(expr).size(), FreeVars(ret).size());
return std::move(ret);
} else {
Expand All @@ -528,6 +528,10 @@ TVM_REGISTER_GLOBAL("relay.ir.Bind").set_body([](TVMArgs args, TVMRetValue* ret)
}
});

Expr ExprBind(const Expr& expr, const tvm::Map<Var, Expr>& args_map) {
return ExprBinder(args_map).VisitExpr(expr);
}

void ExpandANormalForm(const LetNode* op, std::function<void(const LetNode*)> pre_visit,
std::function<void(const LetNode*)> post_visit) {
std::stack<const LetNode*> stack;
Expand Down
44 changes: 1 addition & 43 deletions src/relay/op/memory/on_device.cc
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@
#include <tvm/relay/expr.h>
#include <tvm/relay/op.h>
#include <tvm/relay/op_attr_types.h>
#include <tvm/relay/transform.h>

#include "../../transforms/infer_layout_utils.h"
#include "../type_relations.h"
Expand Down Expand Up @@ -142,48 +143,5 @@ OnDeviceProps GetOnDeviceProps(const Expr& expr) {
return {};
}

Function FunctionOnDevice(Function function, Array<VirtualDevice> param_virtual_devices,
VirtualDevice result_virtual_device) {
auto func = WithAttr(
WithFields(std::move(function), {}, {}, {}, {}, {}, std::move(result_virtual_device)),
tvm::attr::kParamVirtualDevice, std::move(param_virtual_devices));
VLOG(1) << "Annotated func: " << PrettyPrint(func);
return func;
}

TVM_REGISTER_GLOBAL("relay.op.annotation._make.FunctionOnDevice").set_body_typed(FunctionOnDevice);

Function MaybeFunctionOnDevice(Function function, Array<VirtualDevice> param_virtual_devices,
VirtualDevice result_virtual_device) {
if (std::all_of(param_virtual_devices.begin(), param_virtual_devices.end(),
[](const VirtualDevice& virtual_device) {
return virtual_device->IsFullyUnconstrained();
}) &&
result_virtual_device->IsFullyUnconstrained()) {
// Nothing to annotate.
return function;
}
return FunctionOnDevice(function, std::move(param_virtual_devices),
std::move(result_virtual_device));
}

VirtualDevice GetFunctionResultVirtualDevice(const FunctionNode* function_node) {
return function_node->virtual_device();
}

VirtualDevice GetFunctionParamVirtualDevice(const FunctionNode* function_node, size_t i) {
ICHECK_LT(i, function_node->params.size())
<< "param index " << i << " out of range for function of arity "
<< function_node->params.size();
auto opt_array = function_node->GetAttr<Array<VirtualDevice>>(tvm::attr::kParamVirtualDevice);
if (!opt_array) {
// No annotation.
return VirtualDevice::FullyUnconstrained();
}
ICHECK_EQ(opt_array.value().size(), function_node->params.size())
<< "annotation parameters do not match function arity";
return opt_array.value()[i];
}

} // namespace relay
} // namespace tvm
Loading