Add a prototype of MX format training and inference
Summary:

The MX numerical formats are new low-precision formats that were
recently accepted into the OCP spec:
https://www.opencompute.org/documents/ocp-microscaling-formats-mx-v1-0-spec-final-pdf

This PR adds a reference implementation, in native PyTorch, of training
and inference primitives for MX-accelerated matrix multiplications.
Currently, we use a reference layout (scale and raw data stored
separately) and an emulated matrix multiplication.
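
To illustrate the "scale and raw data stored separately" layout and what
an emulated multiplication can look like, here is a minimal pure-Python
sketch. It is not the code added by this PR. The block size of 32 and the
shared power-of-two scale per block follow the OCP MX spec; the function
names, the signed 8-bit integer element type, and the rule used to pick
the exponent are simplifying assumptions made for this example only.

```python
import math

BLOCK_SIZE = 32  # block size from the OCP MX spec


def quantize_block(values, elem_max=127):
    """Quantize one block to (shared_exponent, integer_elements).

    Assumption: elements are signed integers in [-elem_max, elem_max];
    the spec's FP8/FP6/FP4 element formats and exact rounding rules differ.
    """
    amax = max(abs(v) for v in values)
    if amax == 0.0:
        return 0, [0] * len(values)
    # Smallest power-of-two scale such that amax / scale fits in elem_max.
    exp = math.ceil(math.log2(amax / elem_max))
    scale = 2.0 ** exp
    elems = [max(-elem_max, min(elem_max, round(v / scale))) for v in values]
    return exp, elems


def quantize(values):
    """Return two separate lists: per-block exponents and raw element data."""
    assert len(values) % BLOCK_SIZE == 0
    exps, data = [], []
    for start in range(0, len(values), BLOCK_SIZE):
        exp, elems = quantize_block(values[start:start + BLOCK_SIZE])
        exps.append(exp)
        data.extend(elems)
    return exps, data


def dequantize(exps, data):
    """Rebuild floats by applying each block's scale to its elements."""
    return [q * 2.0 ** exps[i // BLOCK_SIZE] for i, q in enumerate(data)]


def emulated_dot(a, b):
    """Emulate an MX dot product: quantize, dequantize, multiply in float."""
    a_dq = dequantize(*quantize(a))
    b_dq = dequantize(*quantize(b))
    return sum(x * y for x, y in zip(a_dq, b_dq))
```

The emulation here does the arithmetic in ordinary floating point after a
quantize/dequantize round trip, so it models the numerics of the format
without needing hardware support.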

Test Plan:

```
# tests
pytest -s test/prototype/mx_formats/*
# benchmarks
python torchao/prototype/mx_formats/benchmarks/bench_qdq.py
```

Reviewers:

Subscribers:

Tasks:

Tags:
vkuzo committed May 24, 2024
1 parent 5e28109 commit 455f148
Showing 16 changed files with 3,161 additions and 0 deletions.
1 change: 1 addition & 0 deletions README.md
@@ -99,6 +99,7 @@ To learn more try out our APIs, you can check out API examples in
3. Support for lower precision [dtypes](./torchao/dtypes) such as
- [nf4](https://github.com/pytorch/ao/blob/main/torchao/dtypes/nf4tensor.py) which was used to [implement QLoRA](https://github.com/pytorch/torchtune/blob/main/docs/source/tutorials/qlora_finetune.rst) without writing custom Triton or CUDA code
- [uint4](https://github.com/pytorch/ao/blob/main/torchao/dtypes/uint4.py)
- [MX](https://github.com/pytorch/ao/blob/main/torchao/prototype/mx_formats) implementing the [OCP MX spec](https://www.opencompute.org/documents/ocp-microscaling-formats-mx-v1-0-spec-final-pdf), prototype as the hardware support is not available yet
4. [Bleeding Edge Kernels](./torchao/prototype/) for experimental kernels without backwards compatibility guarantees
- [GaLore](https://github.com/pytorch/ao/tree/main/torchao/prototype/galore) for memory efficient finetuning
- [fused HQQ Gemm Kernel](https://github.com/pytorch/ao/tree/main/torchao/prototype/hqq) for compute bound workloads
2 changes: 2 additions & 0 deletions dev-requirements.txt
@@ -9,6 +9,8 @@ transformers
bitsandbytes #needed for testing triton quant / dequant ops for 8-bit optimizers
matplotlib
pandas
fire # QOL for commandline scripts
tabulate # QOL for printing tables to stdout

# Custom CUDA Extensions
ninja