Skip to content

Commit

Permalink
[checkpoint] Andrew's comments from apache#9038
Browse files Browse the repository at this point in the history
  • Loading branch information
mbs-octoml committed Sep 27, 2021
1 parent 4d7c798 commit 5553be2
Show file tree
Hide file tree
Showing 5 changed files with 42 additions and 13 deletions.
41 changes: 34 additions & 7 deletions include/tvm/relay/attrs/annotation.h
Original file line number Diff line number Diff line change
Expand Up @@ -32,17 +32,44 @@ namespace tvm {
namespace relay {

/*!
* \brief Attributes for the "on_device" operator.
* \brief Attributes for the "on_device" special operator.
*
* The relay call
* The Relay call (aka 'annotation'):
* \code
* on_device(expr, device_type=2)
* on_device(sub_expr, device_type=2)
* \endcode
* denotes that the result of \p expr should be stored on the device with \p DLDeviceType 2
* (i.e. \p kDLCuda). Semantically the operator is the identity function.
* constrains \p sub_expr to execute and store its result on a device with \p DLDeviceType \p 2
* (i.e. a \p kDLCuda device). However the annotation itself may appear in an expression to be
* executed and stored on a different device. If so the compiler will automatically insert a
* "device_copy" call to mediate the transition between devices.
*
* See also FunctionOnDeviceAttrs in include/relay/attrs/function.h for the function-level
* companion.
* E.g.: Assuming %x and %y reside on the GPU and %z on the CPU then:
* \code
* multiply(on_device(add(%x, %y), device_type=2), %z)
* \endcode
* indicates the \p add should execute on the GPU but the \p multiply should execute on the CPU.
* The compiler will rewrite this to:
* \code
* multiply(device_copy(add(%x, %y), src_dev_type=2, dst_dev_type=1), %z)
* \endcode
*
* The Relay call
* \code
* on_device(sub_expr, device_type=2, is_fixed=True)
* \endcode
* is similar to the above, however the annotation itself must appear in an expression on the
* same device. The compiler will check the devices are consistent, and will not insert any
* "device_copy" call. This form of annotation shouldn't be necessary in user programs. However
* it is needed by the \p PlanDevices pass to fully specify the results of device planning so that
* the pass is idempotent.
*
* E.g.: The following program is equivalent to the above:
* \code
* let %a = on_device(add(%x, %y), device_type=2, is_fixed=True)
* multiply(device_copy(%a, src_dev_type=2, dst_dev_type=1), %z)
* \endcode
* The "on_device" annotation with \p is_fixed=True indicates unambiguously that \p %a is stored
* on the GPU.
*/
struct OnDeviceAttrs : public tvm::AttrsNode<OnDeviceAttrs> {
// TODO(mbs): Replace device types with TargetDevice.
Expand Down
1 change: 0 additions & 1 deletion python/tvm/script/parser.py
Original file line number Diff line number Diff line change
Expand Up @@ -152,7 +152,6 @@ class TVMScriptParser(Transformer):
}

def __init__(self, base_lienno):
super().__init__(self)
self.context = None

self.base_lineno = base_lienno
Expand Down
2 changes: 1 addition & 1 deletion src/relay/op/annotation/annotation.cc
Original file line number Diff line number Diff line change
Expand Up @@ -53,7 +53,7 @@ Expr OnDevice(Expr expr, DLDeviceType device_type, bool is_fixed) {
return Call(OnDeviceOp(), {std::move(expr)}, Attrs(std::move(attrs)), /*type_args=*/{}, span);
}

Expr OptOnDevice(Expr expr, DLDeviceType device_type, bool is_fixed) {
Expr MaybeOnDevice(Expr expr, DLDeviceType device_type, bool is_fixed) {
if (device_type == kInvalidDeviceType) {
// Undefined signals no annotation is required.
return expr;
Expand Down
9 changes: 6 additions & 3 deletions src/relay/op/annotation/annotation.h
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,8 @@ const Op& OnDeviceOp();

/*!
* \brief Wraps \p expr in an "on_device" CallNode for \p device_type and \p is_fixed.
*
* See \p OnDeviceAttrs for an overview.
*/
Expr OnDevice(Expr expr, DLDeviceType device_type, bool is_fixed);

Expand All @@ -52,7 +54,7 @@ Expr OnDevice(Expr expr, DLDeviceType device_type, bool is_fixed);
* - \p expr is a constructor. There should probably be device polymorphic but are in an
* in-between state at the moment.
*/
Expr OptOnDevice(Expr expr, DLDeviceType device_type, bool is_fixed);
Expr MaybeOnDevice(Expr expr, DLDeviceType device_type, bool is_fixed);

/*! \brief Result of \p GetOnDeviceProps. */
struct OnDeviceProps {
Expand Down Expand Up @@ -95,8 +97,9 @@ Function FunctionOnDevice(Function function, const std::vector<DLDeviceType>& pa
* \brief As for \p FunctionOnDevice, but returns \p function unchanged if all parameters and
* result device types are \p kInvalidDeviceType.
*/
Function OptFunctionOnDevice(Function function, const std::vector<DLDeviceType>& param_device_types,
DLDeviceType result_device_type);
Function MaybeFunctionOnDevice(Function function,
const std::vector<DLDeviceType>& param_device_types,
DLDeviceType result_device_type);

/*!
* \brief Returns the device type for the resut of \p function_node, or \p kInvalidDeviceType
Expand Down
2 changes: 1 addition & 1 deletion src/relay/op/memory/device_copy.h
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,7 @@ Expr DeviceCopy(Expr expr, DLDeviceType src_dev_type, DLDeviceType dst_dev_type)
* a device of type \p src_dev_type but then copied to a device of type \p dst_dev_type.
* However, return \p expr directly if \p src_dev_type equals \p dst_dev_type.
*/
Expr OptDeviceCopy(Expr expr, DLDeviceType src_dev_type, DLDeviceType dst_dev_type);
Expr MaybeDeviceCopy(Expr expr, DLDeviceType src_dev_type, DLDeviceType dst_dev_type);

/*! \brief Result of \p GetDeviceCopyProps. */
struct DeviceCopyProps {
Expand Down

0 comments on commit 5553be2

Please sign in to comment.