[RFC] DietCode: An Auto-Scheduler for Dynamic Tensor Programs #72

Merged on May 31, 2022 (2 commits). File: rfcs/0072-dynamic-autoscheduler.md
- Feature Name: DietCode: An Auto-Scheduler for Dynamic Tensor Programs
- Start Date: (2022-05-10)
- RFC PR: [apache/tvm-rfcs#xx](https://github.com/apache/tvm-rfcs/pull/xx)
- GitHub Issue: [apache/tvm#yy](https://github.com/apache/tvm/pull/yy)

# Summary
[summary]: #summary

We propose to integrate DietCode, an auto-scheduler for dynamic tensor programs,
into AutoTIR. DietCode offers the following features:
- A shape-generic search space that covers all possible shapes of a
  dynamic-shape workload.
- A dynamic-shape-aware cost model to judge the quality of schedule candidates.
- Enhancements to the TVM CUDA codegen to support imperfect tiling.

DietCode was published at MLSys 2022; please see [the
paper](https://proceedings.mlsys.org/paper/2022/hash/fa7cdfad1a5aaf8370ebeda47a1ff1c3-Abstract.html)
for more details and evaluations. The latest DietCode codebase is also publicly
available [here](https://github.com/UofT-EcoSystem/DietCode).

# Motivation
[motivation]: #motivation

Achieving high performance for compute-intensive operators in machine learning
workloads is a crucial but challenging task. Many machine learning and system
practitioners rely on vendor libraries or auto-schedulers to do the job. While
the former requires significant engineering effort, the latter in TVM currently
supports only static-shape workloads. It is difficult, if not impractical, to
apply the existing auto-scheduler directly to **dynamic-shape workloads**, as
doing so leads to extremely long tuning times.

We observe that the key challenge existing auto-schedulers face when handling a
dynamic-shape workload is that they cannot construct a conclusive search space
for all the possible shapes of the workload, because their search space is
shape-dependent. To address this, this RFC adds dynamic-shape support to
AutoTIR by integrating the DietCode framework, which constructs **a
shape-generic search space and cost model** to auto-schedule dynamic-shape
workloads efficiently.

Our evaluation shows that DietCode has the following key strengths when
auto-scheduling an entire model end-to-end:

1. reduces auto-scheduling time by up to 5.88x relative to the current
   auto-scheduler on 8 uniformly sampled dynamic shapes, and
1. improves performance by up to 69.5% over the auto-scheduler and 18.6% over
   the vendor library.

These advantages make DietCode an efficient and practical solution for
dynamic-shape workloads.


# Guide-Level Explanation
[guide-level-explanation]: #guide-level-explanation

The existing experiments were largely conducted with the auto-scheduler.
However, after syncing with the AutoTIR team for several quarters, we plan to
integrate this RFC into MetaSchedule (AutoTIR), because it provides a more
systematic interface and a cleaner integration path with fewer hacks.

As an example of the additional information users are required to feed to the
system:

```python
# A symbolic shape variable
T = tir.ShapeVar('T')
# The candidate values of `T`
T_vals = list(range(1, 128))

task = Task(func=Dense,
            args=(16 * T, 768, 2304),
            shape_vars=(T,),
            wkl_insts=(T_vals,),
            wkl_inst_weights=([1.0 for _ in T_vals],))
```

To enable auto-scheduling for dynamic-shape workloads, users only need to:
1. Have `ShapeVar`s in the TE/TensorIR computation.
2. Specify the weight/distribution of each shape value.

Notes:
1. The symbolic constraint is additionally required in Relay, but could be
   inferred automatically once Relax is introduced;
2. The proposed interface does not change any existing functionality.
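As a sketch of step 2, the per-shape weights need not be uniform: they could, for instance, be derived from a profiled distribution of the dynamic dimension. Everything below is an illustrative assumption (the `observed_lengths` trace is made up), not part of the proposed API:

```python
# Hypothetical sketch: deriving `wkl_inst_weights` from an observed
# distribution of the dynamic dimension `T`, instead of uniform weights.
from collections import Counter

T_vals = list(range(1, 128))                  # candidate values of `T`
observed_lengths = [32, 32, 32, 48, 48, 64]   # e.g., profiled from real traffic

counts = Counter(observed_lengths)
total = sum(counts.values())
# One weight per candidate shape value; values never observed get weight 0.
weights = [counts.get(t, 0) / total for t in T_vals]
```

Such weights would bias the auto-scheduler toward the workload instances that matter most in production.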

# Reference-Level Explanation
[reference-level-explanation]: #reference-level-explanation

Here is an overview of the DietCode framework design.

<img src="https://raw.githubusercontent.com/UofT-EcoSystem/DietCode/main/docs/figures/DietCode.jpg" width="61.8%" />

- We construct **a shape-generic search space that consists of micro-kernels**,
  incomplete programs that each carry out a tile of the complete computation, to
  efficiently support dynamic-shape workloads.

  We use hardware constraints (e.g., the maximum number of threads, the amount
  of shared and local memory) rather than shape information to determine the
  micro-kernel candidates. Those candidates serve as the building blocks and
  are executed repeatedly to carry out a workload instance (defined as a
  static-shape instance of the dynamic-shape workload).
- We build a **micro-kernel-based cost model**. The key insight is that the cost
of a complete program *P* that is made up of a micro-kernel *M* can be
decomposed into two parts:

1. A shape-generic cost function *f*<sub>MK</sub> that predicts the cost of
*M*, and
1. A shape-dependent adaption cost function *f*<sub>adapt</sub> that defines
the penalty of porting *M* to *P*.

  While *f*<sub>MK</sub> is a function that has to be learned and updated from
  real hardware measurements during the auto-scheduling process,
  *f*<sub>adapt</sub> is a simple term that can be evaluated from the core
  occupancy and the padding ratio (in other words, it does not require feature
  extraction from the schedules).
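To make the decomposition concrete, here is a minimal, self-contained sketch. It is not DietCode's actual formulation: the (rows, cols) tile/shape representation, the multiplicative composition, and the `num_cores` constant are all illustrative assumptions.

```python
import math

def f_adapt(tile, shape, num_cores=108):
    """Shape-dependent adaptation penalty, evaluated analytically from the
    padding ratio and the core occupancy (no feature extraction needed).
    `tile` and `shape` are (rows, cols) pairs; `num_cores` is illustrative.
    """
    tiles = math.ceil(shape[0] / tile[0]) * math.ceil(shape[1] / tile[1])
    # Padding ratio: elements actually computed (the tiles fully cover a
    # padded shape) divided by the elements the workload needs.
    padding_ratio = (tiles * tile[0] * tile[1]) / (shape[0] * shape[1])
    # Occupancy: the last wave of tiles may leave some cores idle.
    occupancy = tiles / (math.ceil(tiles / num_cores) * num_cores)
    return padding_ratio / occupancy

def predicted_cost(f_mk, tile, shape):
    """cost(P) ~ f_MK(M) * f_adapt(M, P): the learned micro-kernel cost
    scaled by the analytically evaluated adaptation penalty."""
    return f_mk * f_adapt(tile, shape)
```

Only *f*<sub>MK</sub> (represented above by the plain number `f_mk` standing in for the learned model's prediction) has to be fitted from hardware measurements; the adaptation term is cheap to evaluate for every workload instance.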

# Drawbacks
[drawbacks]: #drawbacks

- The current compilation workflow generates one program per input shape. We
  could merge those static-shape programs into a single dynamic-shape program,
  as in the following code snippet:
```CUDA
__global__ void default_function(float* X, float* W, float* Y,
                                 const int T) {  // Note the dynamic `T` here.
  ...
}
```
  However, our evaluations indicate that such a merged program performs at
  least 5% worse than the static-shape alternatives. Hence, we decided to
  sacrifice binary size for runtime performance, which can potentially be
  problematic when hardware resources are limited.

# Rationale and Alternatives
[rationale-and-alternatives]: #rationale-and-alternatives

[Nimble](https://arxiv.org/pdf/2006.03031.pdf) proposes an approach that
partitions the range of a dynamic shape into buckets and tunes one kernel for
each bucket. We could, of course, implement this approach in the current
auto-scheduler and AutoTIR. However, as evaluated in the DietCode paper, this
approach is not guaranteed to match the performance of static-shape tuning.
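For contrast, a minimal sketch of the bucketing alternative (the bucket boundaries below are hypothetical): each runtime value of the dynamic dimension is rounded up to its bucket's boundary and served by the kernel tuned for that boundary, paying padding overhead whenever the value falls strictly inside a bucket.

```python
import bisect

# Hypothetical bucket upper bounds for the dynamic dimension `T`.
BUCKETS = [16, 32, 64, 128]

def bucketed_shape(t):
    """Round `t` up to the boundary of the bucket that serves it."""
    idx = bisect.bisect_left(BUCKETS, t)
    if idx == len(BUCKETS):
        raise ValueError(f"T={t} exceeds the largest tuned bucket")
    return BUCKETS[idx]
```

For example, `T = 20` would execute the kernel tuned for `T = 32`, wasting 12/32 of the computation; DietCode's micro-kernel approach avoids committing to such fixed boundaries.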

# Prior State-of-the-Art
[prior-sotas]: #prior-sotas

- **Reuse-based Tuner**

  Selective Tuning ([Cody Yu,
  2019](https://github.com/apache/incubator-tvm/issues/4188)) and ETO ([Jingzhi
  Fang et al., VLDB 2021](http://www.vldb.org/pvldb/vol15/p183-chen.pdf)) group
  workloads into clusters based on a set of pre-defined rules (e.g., the
  similarity ratio in Selective Tuning) and reuse the same schedule within each
  cluster.

- **Dynamic Neural Networks**

Dynamic batching is a common graph-level optimization adopted by frameworks
such as DyNet ([Graham Neubig et al. 2017](http://arxiv.org/abs/1701.03980)),
Cavs ([Shizhen Xu et al. USENIX ATC
2018](https://www.usenix.org/conference/atc18/presentation/xu-shizen)),
BatchMaker ([Pin Gao et al. EuroSys
2018](https://doi.org/10.1145/3190508.3190541)), and TensorFlow Fold ([Moshe
Looks et al. ICLR 2017](https://openreview.net/forum?id=ryrGawqex)) for cases
when the batch size is dynamic.

Nimble ([Haichen Shen et al. MLSys
2021](https://proceedings.mlsys.org/paper/2021/hash/4e732ced3463d06de0ca9a15b6153677-Abstract.html))
and DISC ([Kai Zhu et al. EuroMLSys
2021](https://dl.acm.org/doi/10.1145/3437984.3458838)) both design a compiler
to represent and execute dynamic neural networks.

Cortex ([Pratik Fegade et al. MLSys
2021](https://proceedings.mlsys.org/paper/2021/hash/182be0c5cdcd5072bb1864cdee4d3d6e-Abstract.html))
is a compiler-based framework on recursive neural networks.

  These works focus on graph-level optimizations and are therefore orthogonal
  to DietCode, which operates on each individual layer. In fact, those
  graph-level solutions can also leverage DietCode for efficient operator code
  generation.

# Unresolved Questions
[unresolved-questions]: #unresolved-questions

- The current design does not support arbitrary numbers of shape dimensions.
  For better auto-scheduling outcomes, we expect shape dimensions to be
  specified beforehand.
- The proposed approach mostly works on NVIDIA GPUs and has not been tested on
other hardware platforms.

# Future Possibilities
[future-possibilities]: #future-possibilities

- Evaluate more operator use cases.
- CPU support.

# Upstream Milestones
[upstream-milestones]: #upstream-milestones

We propose the following milestones for upstreaming, where each bullet point
corresponds to a PR with unit tests of roughly several hundred lines.

- [ ] Code Generation Support
  - Local Padding
  - Loop Partitioning
- [ ] Auto-Scheduler
  - Frontend Interface
  - Sketch Generation
  - Random Annotations
  - Program Measurer
  - Micro-Kernel Cost Model
  - Evolutionary Search
- [ ] Decision-Tree Dispatching
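As an illustration of the last milestone, the dispatcher produced at the end of auto-scheduling could conceptually look like the following. The thresholds and kernel names are made up; the real artifact would be a decision tree learned over the tuned workload instances.

```python
def dispatch(t):
    """Route a workload instance (the runtime value of `T`) to the
    micro-kernel-based program tuned for it. Thresholds are illustrative,
    as if emitted by the decision-tree construction."""
    if t <= 16:
        return "dense_16T_microkernel_a"
    elif t <= 64:
        return "dense_16T_microkernel_b"
    return "dense_16T_microkernel_c"
```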